diff --git a/src/DelegateDecompiler.Tests/EnumTests.cs b/src/DelegateDecompiler.Tests/EnumTests.cs index ac6b2f8..e07ccb9 100644 --- a/src/DelegateDecompiler.Tests/EnumTests.cs +++ b/src/DelegateDecompiler.Tests/EnumTests.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Linq.Expressions; using NUnit.Framework; @@ -293,15 +294,23 @@ public void Issue98B() Test(expected, compiled); } - [Test, Ignore("Not fixed yet")] + [Test] public void Issue160() { Expression> expected1 = x => (TestEnum?) x == TestEnum.Bar; - Expression> expected2 = x => (int?) x == (int?) TestEnum.Bar; + Expression> expected2 = x => (x.HasValue ? (TestEnum?) (x ?? 0) : null) == TestEnum.Bar; Func compiled = x => (TestEnum?) x == TestEnum.Bar; Test(expected1, expected2, compiled); } + [Test] + public void Issue176Array() + { + Expression> expected = x => new [] {TestEnum.Foo, TestEnum.Bar}.Contains(x); + Func compiled = x => new[] {TestEnum.Foo, TestEnum.Bar}.Contains(x); + Test(expected, compiled); + } + private static bool TestEnumMethod(TestEnum p0) { throw new NotImplementedException(); diff --git a/src/DelegateDecompiler/ExpressionHelper.cs b/src/DelegateDecompiler/ExpressionHelper.cs index efd2693..5ec669f 100644 --- a/src/DelegateDecompiler/ExpressionHelper.cs +++ b/src/DelegateDecompiler/ExpressionHelper.cs @@ -8,6 +8,11 @@ internal static class ExpressionHelper internal static Expression Default(Type type) => // LINQ to entities and possibly other providers don't support Expression.Default, so this gets the default // value and then uses an Expression.Constant instead - Expression.Constant(type.IsValueType ? Activator.CreateInstance(type) : null, type); + Expression.Constant(GetDefaultValue(type), type); + + internal static object GetDefaultValue(Type type) + { + return type.IsValueType ? Activator.CreateInstance(type) : null; + } } } diff --git a/src/DelegateDecompiler/OptimizeExpressionVisitor.cs b/src/DelegateDecompiler/OptimizeExpressionVisitor.cs index badf08c..7ce0434 100644 --- a/src/DelegateDecompiler/OptimizeExpressionVisitor.cs +++ b/src/DelegateDecompiler/OptimizeExpressionVisitor.cs @@ -10,18 +10,13 @@ class OptimizeExpressionVisitor : ExpressionVisitor protected override Expression VisitNew(NewExpression node) { // Test if this is a nullable type - if (IsNullable(node.Type) && node.Arguments.Count == 1) + if (node.Type.IsNullableType() && node.Arguments.Count == 1) { return Expression.Convert(Visit(node.Arguments[0]), node.Type); } return base.VisitNew(node); } - private static bool IsNullable(Type type) - { - return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); - } - readonly Dictionary expressionsCache = new Dictionary(); @@ -45,8 +40,7 @@ protected override Expression VisitConditional(ConditionalExpression node) var ifTrue = Visit(node.IfTrue); var ifFalse = Visit(node.IfFalse); - Expression expression; - if (IsCoalesce(test, ifTrue, out expression)) + if (IsCoalesce(test, ifTrue, out var expression)) { return Expression.Coalesce(expression, ifFalse); } @@ -215,7 +209,7 @@ private static bool TryConvert1(Expression hasValue, BinaryExpression getValueOr static Expression ConvertToNullable(Expression expression) { - if (!expression.Type.IsValueType || IsNullable(expression.Type)) return expression; + if (!expression.Type.IsValueType || expression.Type.IsNullableType()) return expression; var operand = expression.NodeType == ExpressionType.Convert ? ((UnaryExpression) expression).Operand @@ -227,7 +221,7 @@ static Expression ConvertToNullable(Expression expression) static Expression UnwrapConvertToNullable(Expression expression) { var unary = expression as UnaryExpression; - if (unary != null && expression.NodeType == ExpressionType.Convert && IsNullable(expression.Type)) + if (unary != null && expression.NodeType == ExpressionType.Convert && expression.Type.IsNullableType()) { return unary.Operand; } @@ -266,7 +260,7 @@ static bool IsCoalesce(Expression hasValue, Expression getValueOrDefault, out Ex static bool IsHasValue(Expression expression, out MemberExpression property) { property = expression as MemberExpression; - return property != null && property.Member.Name == "HasValue" && property.Expression != null && IsNullable(property.Expression.Type); + return property != null && property.Member.Name == "HasValue" && property.Expression != null && property.Expression.Type.IsNullableType(); } static bool IsGetValueOrDefault(Expression expression, out MethodCallExpression method) @@ -277,12 +271,13 @@ static bool IsGetValueOrDefault(Expression expression, out MethodCallExpression static bool IsGetValueOrDefault(MethodCallExpression method) { - return method.Method.Name == "GetValueOrDefault" && method.Object != null && IsNullable(method.Object.Type); + return method.Method.Name == "GetValueOrDefault" && method.Object != null && method.Object.Type.IsNullableType(); } protected override Expression VisitBinary(BinaryExpression node) { var left = Visit(node.Left); + var right = Visit(node.Right); if (node.Right is ConstantExpression rightConstant) { if (rightConstant.Value as bool? == false) @@ -310,7 +305,7 @@ left is MethodCallExpression expression && if (node.NodeType == ExpressionType.And) { - if (ExtractNullableArgument(node.Right, left, out var result)) + if (ExtractNullableArgument(right, left, out var result)) { return Visit(result); } diff --git a/src/DelegateDecompiler/Processor.cs b/src/DelegateDecompiler/Processor.cs index aec60ab..2e00f99 100644 --- a/src/DelegateDecompiler/Processor.cs +++ b/src/DelegateDecompiler/Processor.cs @@ -652,8 +652,7 @@ Expression Process() { var operand = (Type) state.Instruction.Operand; var expression = state.Stack.Pop(); - var size = expression.Expression as ConstantExpression; - if (size != null && (int) size.Value == 0) // optimization + if (expression.Expression is ConstantExpression size && (int) size.Value == 0) // optimization state.Stack.Push(Expression.NewArrayInit(operand)); else state.Stack.Push(Expression.NewArrayBounds(operand, expression)); @@ -664,8 +663,15 @@ Expression Process() } else if (state.Instruction.OpCode == OpCodes.Newobj) { - var constructor = (ConstructorInfo)state.Instruction.Operand; - state.Stack.Push(Expression.New(constructor, GetArguments(state, constructor))); + var constructor = (ConstructorInfo) state.Instruction.Operand; + if (constructor.DeclaringType.IsNullableType() && constructor.GetParameters().Length == 1) + { + state.Stack.Push(Expression.Convert(state.Stack.Pop(), constructor.DeclaringType)); + } + else + { + state.Stack.Push(Expression.New(constructor, GetArguments(state, constructor))); + } } else if (state.Instruction.OpCode == OpCodes.Call || state.Instruction.OpCode == OpCodes.Callvirt) { @@ -1018,12 +1024,11 @@ static void StElem(ProcessorState state) var index = state.Stack.Pop(); var array = state.Stack.Pop(); - var newArray = array.Expression as NewArrayExpression; - if (newArray != null) + if (array.Expression is NewArrayExpression newArray) { - var expressions = CreateArrayInitExpressions(newArray, value, index); - var newArrayInit = Expression.NewArrayInit(array.Type.GetElementType(), expressions); - array.Expression = newArrayInit; + var elementType = array.Type.GetElementType(); + var expressions = CreateArrayInitExpressions(elementType, newArray, value, index); + array.Expression = Expression.NewArrayInit(elementType, expressions); } else { @@ -1031,25 +1036,37 @@ static void StElem(ProcessorState state) } } - static IEnumerable CreateArrayInitExpressions(NewArrayExpression newArray, Expression valueExpression, Expression indexExpression) + static IEnumerable CreateArrayInitExpressions( + Type elementType, NewArrayExpression newArray, Expression valueExpression, Expression indexExpression) { + var indexGetter = (Func) Expression.Lambda(indexExpression).Compile(); + var index = indexGetter(); + + Expression[] expressions; if (newArray.NodeType == ExpressionType.NewArrayInit) { - var indexGetter = (Func) Expression.Lambda(indexExpression).Compile(); - var index = indexGetter(); - var expressions = newArray.Expressions.ToArray(); - + expressions = newArray.Expressions.ToArray(); if (index >= newArray.Expressions.Count) { Array.Resize(ref expressions, index + 1); } - expressions[index] = valueExpression; + } + else if (newArray.NodeType == ExpressionType.NewArrayBounds) + { + var sizeExpression = newArray.Expressions.Single(); + var sizeGetter = (Func) Expression.Lambda(sizeExpression).Compile(); + var getter = sizeGetter(); - return expressions; + expressions = Enumerable.Repeat(ExpressionHelper.Default(elementType), getter).ToArray(); + } + else + { + throw new NotSupportedException(); } - return new[] {valueExpression}; + expressions[index] = AdjustType(valueExpression, elementType); + return expressions; } static void LdC(ProcessorState state, int i) diff --git a/src/DelegateDecompiler/TypeExtensions.cs b/src/DelegateDecompiler/TypeExtensions.cs new file mode 100644 index 0000000..ecd310a --- /dev/null +++ b/src/DelegateDecompiler/TypeExtensions.cs @@ -0,0 +1,12 @@ +using System; + +namespace DelegateDecompiler +{ + static class TypeExtensions + { + public static bool IsNullableType(this Type type) + { + return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); + } + } +}