diff --git a/src/DelegateDecompiler.Tests/EnumTests.cs b/src/DelegateDecompiler.Tests/EnumTests.cs index 739df74..8d0a0f1 100644 --- a/src/DelegateDecompiler.Tests/EnumTests.cs +++ b/src/DelegateDecompiler.Tests/EnumTests.cs @@ -286,11 +286,11 @@ public void Issue98B() Test(expected, compiled); } - [Test, Ignore("Not fixed yet")] + [Test] public void Issue160() { Expression> expected1 = x => (TestEnum?) x == TestEnum.Bar; - Expression> expected2 = x => (int?) x == (int?) TestEnum.Bar; + Expression> expected2 = x => (x.HasValue ? (TestEnum?) (x ?? 0) : null) == TestEnum.Bar; Func compiled = x => (TestEnum?) x == TestEnum.Bar; Test(expected1, expected2, compiled); } diff --git a/src/DelegateDecompiler/ExpressionHelper.cs b/src/DelegateDecompiler/ExpressionHelper.cs index efd2693..5ec669f 100644 --- a/src/DelegateDecompiler/ExpressionHelper.cs +++ b/src/DelegateDecompiler/ExpressionHelper.cs @@ -8,6 +8,11 @@ internal static class ExpressionHelper internal static Expression Default(Type type) => // LINQ to entities and possibly other providers don't support Expression.Default, so this gets the default // value and then uses an Expression.Constant instead - Expression.Constant(type.IsValueType ? Activator.CreateInstance(type) : null, type); + Expression.Constant(GetDefaultValue(type), type); + + internal static object GetDefaultValue(Type type) + { + return type.IsValueType ? Activator.CreateInstance(type) : null; + } } } diff --git a/src/DelegateDecompiler/OptimizeExpressionVisitor.cs b/src/DelegateDecompiler/OptimizeExpressionVisitor.cs index badf08c..7ce0434 100644 --- a/src/DelegateDecompiler/OptimizeExpressionVisitor.cs +++ b/src/DelegateDecompiler/OptimizeExpressionVisitor.cs @@ -10,18 +10,13 @@ class OptimizeExpressionVisitor : ExpressionVisitor protected override Expression VisitNew(NewExpression node) { // Test if this is a nullable type - if (IsNullable(node.Type) && node.Arguments.Count == 1) + if (node.Type.IsNullableType() && node.Arguments.Count == 1) { return Expression.Convert(Visit(node.Arguments[0]), node.Type); } return base.VisitNew(node); } - private static bool IsNullable(Type type) - { - return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); - } - readonly Dictionary expressionsCache = new Dictionary(); @@ -45,8 +40,7 @@ protected override Expression VisitConditional(ConditionalExpression node) var ifTrue = Visit(node.IfTrue); var ifFalse = Visit(node.IfFalse); - Expression expression; - if (IsCoalesce(test, ifTrue, out expression)) + if (IsCoalesce(test, ifTrue, out var expression)) { return Expression.Coalesce(expression, ifFalse); } @@ -215,7 +209,7 @@ private static bool TryConvert1(Expression hasValue, BinaryExpression getValueOr static Expression ConvertToNullable(Expression expression) { - if (!expression.Type.IsValueType || IsNullable(expression.Type)) return expression; + if (!expression.Type.IsValueType || expression.Type.IsNullableType()) return expression; var operand = expression.NodeType == ExpressionType.Convert ? ((UnaryExpression) expression).Operand @@ -227,7 +221,7 @@ static Expression ConvertToNullable(Expression expression) static Expression UnwrapConvertToNullable(Expression expression) { var unary = expression as UnaryExpression; - if (unary != null && expression.NodeType == ExpressionType.Convert && IsNullable(expression.Type)) + if (unary != null && expression.NodeType == ExpressionType.Convert && expression.Type.IsNullableType()) { return unary.Operand; } @@ -266,7 +260,7 @@ static bool IsCoalesce(Expression hasValue, Expression getValueOrDefault, out Ex static bool IsHasValue(Expression expression, out MemberExpression property) { property = expression as MemberExpression; - return property != null && property.Member.Name == "HasValue" && property.Expression != null && IsNullable(property.Expression.Type); + return property != null && property.Member.Name == "HasValue" && property.Expression != null && property.Expression.Type.IsNullableType(); } static bool IsGetValueOrDefault(Expression expression, out MethodCallExpression method) @@ -277,12 +271,13 @@ static bool IsGetValueOrDefault(Expression expression, out MethodCallExpression static bool IsGetValueOrDefault(MethodCallExpression method) { - return method.Method.Name == "GetValueOrDefault" && method.Object != null && IsNullable(method.Object.Type); + return method.Method.Name == "GetValueOrDefault" && method.Object != null && method.Object.Type.IsNullableType(); } protected override Expression VisitBinary(BinaryExpression node) { var left = Visit(node.Left); + var right = Visit(node.Right); if (node.Right is ConstantExpression rightConstant) { if (rightConstant.Value as bool? == false) @@ -310,7 +305,7 @@ left is MethodCallExpression expression && if (node.NodeType == ExpressionType.And) { - if (ExtractNullableArgument(node.Right, left, out var result)) + if (ExtractNullableArgument(right, left, out var result)) { return Visit(result); } diff --git a/src/DelegateDecompiler/Processor.cs b/src/DelegateDecompiler/Processor.cs index 8d2681f..2e00f99 100644 --- a/src/DelegateDecompiler/Processor.cs +++ b/src/DelegateDecompiler/Processor.cs @@ -663,8 +663,15 @@ Expression Process() } else if (state.Instruction.OpCode == OpCodes.Newobj) { - var constructor = (ConstructorInfo)state.Instruction.Operand; - state.Stack.Push(Expression.New(constructor, GetArguments(state, constructor))); + var constructor = (ConstructorInfo) state.Instruction.Operand; + if (constructor.DeclaringType.IsNullableType() && constructor.GetParameters().Length == 1) + { + state.Stack.Push(Expression.Convert(state.Stack.Pop(), constructor.DeclaringType)); + } + else + { + state.Stack.Push(Expression.New(constructor, GetArguments(state, constructor))); + } } else if (state.Instruction.OpCode == OpCodes.Call || state.Instruction.OpCode == OpCodes.Callvirt) { diff --git a/src/DelegateDecompiler/TypeExtensions.cs b/src/DelegateDecompiler/TypeExtensions.cs new file mode 100644 index 0000000..ecd310a --- /dev/null +++ b/src/DelegateDecompiler/TypeExtensions.cs @@ -0,0 +1,12 @@ +using System; + +namespace DelegateDecompiler +{ + static class TypeExtensions + { + public static bool IsNullableType(this Type type) + { + return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); + } + } +}