From f53d089f80aa56c070aa650eb2efd4e6af683ad5 Mon Sep 17 00:00:00 2001 From: Matthias Vill Date: Sat, 23 Sep 2023 23:29:04 +0200 Subject: [PATCH] Improve invocation support --- src/LinqKit.Core/ExpressionExpander.cs | 37 ++++++++--- .../ExpressionExpanderTests.cs | 62 +++++++++++++++++++ 2 files changed, 91 insertions(+), 8 deletions(-) diff --git a/src/LinqKit.Core/ExpressionExpander.cs b/src/LinqKit.Core/ExpressionExpander.cs index 13b962c..1d77275 100644 --- a/src/LinqKit.Core/ExpressionExpander.cs +++ b/src/LinqKit.Core/ExpressionExpander.cs @@ -34,14 +34,7 @@ protected LambdaExpression EvaluateTarget(Expression target) } } - var lambda = target.EvaluateExpression() as LambdaExpression; - - if (lambda == null) - { - throw new InvalidOperationException($"Invoke cannot evaluate LambdaExpression from '{target}'. Ensure that your function/property/member returns LambdaExpression"); - } - - return lambda; + return target.EvaluateExpression() as LambdaExpression; } /// @@ -53,6 +46,10 @@ protected override Expression VisitInvocation(InvocationExpression iv) var target = iv.Expression; var lambda = EvaluateTarget(target); + if (lambda == null) + { + return base.VisitInvocation(iv); + } var body = ExpressionReplacer.GetBody(lambda, iv.Arguments); @@ -118,6 +115,11 @@ protected override Expression VisitMethodCall(MethodCallExpression m) var target = m.Arguments[0]; var lambda = EvaluateTarget(target); + if (lambda == null) + { + throw new InvalidOperationException($"Invoke cannot evaluate LambdaExpression from '{target}'. Ensure that your function/property/member returns LambdaExpression"); + } + var replaceVars = new Dictionary(); for (int i = 0; i < lambda.Parameters.Count; i++) { @@ -129,6 +131,25 @@ protected override Expression VisitMethodCall(MethodCallExpression m) return Visit(body); } + if (m.Method.Name == nameof(Action.Invoke) + && m.Method.DeclaringType.GetTypeInfo().IsSubclassOf(typeof(Delegate))) + { + var lambda = EvaluateTarget(m.Object); + + if (lambda != null) + { + var replaceVars = new Dictionary(); + for (int i = 0; i < lambda.Parameters.Count; i++) + { + replaceVars.Add(lambda.Parameters[i], Visit(m.Arguments[i])); + } + + var body = ExpressionReplacer.Replace(lambda.Body, replaceVars); + + return Visit(body); + } + } + if (GetExpandLambda(m.Method, out var methodLambda)) { var replaceVars = new Dictionary(); diff --git a/tests/LinqKit.Tests.Net452/ExpressionExpanderTests.cs b/tests/LinqKit.Tests.Net452/ExpressionExpanderTests.cs index 3d80140..867bb12 100644 --- a/tests/LinqKit.Tests.Net452/ExpressionExpanderTests.cs +++ b/tests/LinqKit.Tests.Net452/ExpressionExpanderTests.cs @@ -64,6 +64,68 @@ public void ExpressionExpander_Expression_Block() Assert.Equal(lambda.Invoke(42), expandedLambda.Invoke(42)); } + [Fact] + public void ExpressionExpander_Expression_InvokeExpressionRemoved() + { + Expression> lambda = o => o; + + var expandedLambda = Linq.Expr((object o) => lambda.Invoke(o)) + .Expand(); + Assert.Equal(ExpressionType.Parameter, expandedLambda.Body.NodeType); + Assert.Equal(lambda.ToString(), expandedLambda.ToString()); + Assert.Equal(lambda.Invoke(42), expandedLambda.Invoke(42)); + } + + [Fact] + public void ExpressionExpander_Expression_CompileAndInvokeExpressionRemoved() + { + Expression> lambda = o => o; + + var expandedLambda = Linq.Expr((object o) => lambda.Compile()(o)) + .Expand(); + Assert.Equal(ExpressionType.Parameter, expandedLambda.Body.NodeType); + Assert.Equal(lambda.ToString(), expandedLambda.ToString()); + Assert.Equal(lambda.Invoke(42), expandedLambda.Invoke(42)); + } + + [Fact] + public void ExpressionExpander_Expression_CompileAndInvokeOnExpressionRemoved() + { + Expression> lambda = o => o; + + var expandedLambda = Linq.Expr((object o) => lambda.Compile().Invoke(o)) + .Expand(); + Assert.Equal(ExpressionType.Parameter, expandedLambda.Body.NodeType); + Assert.Equal(lambda.ToString(), expandedLambda.ToString()); + Assert.Equal(lambda.Invoke(42), expandedLambda.Invoke(42)); + } + + [Fact] + public void ExpressionExpander_Expression_InvokeDelegate() + { + Func func = o => o.ToString(); + + Expression> lambda = o => func(o); + + var expandedLambda = Linq.Expr((object o) => lambda.Invoke(o)) + .Expand(); + Assert.Equal(lambda.ToString(), expandedLambda.ToString()); + Assert.Equal(lambda.Invoke(42), expandedLambda.Invoke(42)); + } + + [Fact] + public void ExpressionExpander_Expression_InvokeOnDelegate() + { + Func func = o => o.ToString(); + + Expression> lambda = o => func.Invoke(o); + + var expandedLambda = Linq.Expr((object o) => lambda.Invoke(o)) + .Expand(); + Assert.Equal(lambda.ToString(), expandedLambda.ToString()); + Assert.Equal(lambda.Invoke(42), expandedLambda.Invoke(42)); + } + [Fact] public void ExpressionExpander_Expression_Throw() {