diff --git a/src/Hyperion.Tests/SurrogateTests.cs b/src/Hyperion.Tests/SurrogateTests.cs index 1fe5e6e5..66640b7b 100644 --- a/src/Hyperion.Tests/SurrogateTests.cs +++ b/src/Hyperion.Tests/SurrogateTests.cs @@ -7,6 +7,7 @@ // ----------------------------------------------------------------------- #endregion +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -15,7 +16,7 @@ namespace Hyperion.Tests { #region test support classes - + public interface IOriginal { ISurrogate ToSurrogate(); @@ -30,13 +31,30 @@ public class Foo : IOriginal { public string Bar { get; set; } + public Foo() { } + + public Foo(string bar) + { + Bar = bar; + } + public ISurrogate ToSurrogate() => new FooSurrogate { Bar = Bar }; } - public class FooSurrogate : ISurrogate + public class ChildFoo : Foo + { + public ChildFoo() { } + + public ChildFoo(string bar) + { + Bar = bar; + } + } + + public class FooSurrogate : ISurrogate, IEquatable, IEquatable { public string Bar { get; set; } @@ -60,6 +78,33 @@ public Foo Restore() Bar = Bar }; } + + public bool Equals(FooSurrogate other) + { + if (ReferenceEquals(null, other)) return false; + if (ReferenceEquals(this, other)) return true; + return string.Equals(Bar, other.Bar); + } + + public bool Equals(Foo other) + { + if (other == null) return false; + return Equals(other.ToSurrogate()); + } + + public override bool Equals(object obj) + { + if (ReferenceEquals(null, obj)) return false; + if (ReferenceEquals(this, obj)) return true; + var foo = obj as Foo; + if (foo != null) return Equals(foo); + return Equals(obj as FooSurrogate); + } + + public override int GetHashCode() + { + return Bar.GetHashCode(); + } } public class SurrogatedKey : IOriginal @@ -72,7 +117,7 @@ public SurrogatedKey(string key) } public ISurrogate ToSurrogate() => new KeySurrogate(Key); - public override bool Equals(object obj) => obj is SurrogatedKey && Key == ((SurrogatedKey) obj).Key; + public override bool Equals(object obj) => obj is SurrogatedKey && Key == ((SurrogatedKey)obj).Key; public override int GetHashCode() => Key?.GetHashCode() ?? 0; } @@ -172,11 +217,54 @@ public void CanSerializeWithSurrogateInCollection() serializer.Serialize(dictionary, stream); stream.Position = 0; - var actual = serializer.Deserialize> (stream); + var actual = serializer.Deserialize>(stream); Assert.Equal(key, actual.Keys.First()); Assert.Equal(foo.Bar, actual[key].Bar); Assert.Equal(2, invoked.Count); } + + [Fact] + public void CanDeserializeSurrogateWithIEquatableInsideArrays() + { + var surrogating = 0; + var desurrogate = new List(); + var serializer = new Serializer(new SerializerOptions( + preserveObjectReferences: true, + surrogates: new[] + { + Surrogate.Create( + from => + { + surrogating++; + return from.ToSurrogate(); + }, + to => + { + desurrogate.Add(to); + return to.FromSurrogate(); + } + ), + })); + + var stream = new MemoryStream(); + var expected = new Foo[] + { + new ChildFoo("one"), + new ChildFoo("two"), + new ChildFoo("one"), + }; + + serializer.Serialize(expected, stream); + stream.Position = 0; + + var actual = (Foo[])serializer.Deserialize(stream); + + Assert.Equal(expected.Length, actual.Length); + Assert.Equal(expected[0].Bar, actual[0].Bar); + Assert.Equal(3, desurrogate.Count); + Assert.Equal(3, surrogating); + } + } }