From eaf347310599116da78e497b150d90c40409250e Mon Sep 17 00:00:00 2001 From: Nikolaj Pedersen Date: Mon, 16 Dec 2024 15:41:25 +0100 Subject: [PATCH] Add more code --- src/Pose/IL/MethodRewriter.cs | 472 ++++++++++++++++++ src/Pose/Pose.csproj | 6 +- src/Sandbox/Program.cs | 71 ++- .../Pose.Tests/IL/AsyncMethodRewriterTests.cs | 121 +++++ test/Pose.Tests/Pose.Tests.csproj | 1 + 5 files changed, 629 insertions(+), 42 deletions(-) create mode 100644 test/Pose.Tests/IL/AsyncMethodRewriterTests.cs diff --git a/src/Pose/IL/MethodRewriter.cs b/src/Pose/IL/MethodRewriter.cs index 1cbcc35..ba495f1 100644 --- a/src/Pose/IL/MethodRewriter.cs +++ b/src/Pose/IL/MethodRewriter.cs @@ -197,6 +197,478 @@ public MethodBase Rewrite() return dynamicMethod; } + private static Type GetStateMachineType(MethodBase method) + { + var stateMachineType = method + ?.GetCustomAttribute() + ?.StateMachineType; + + return stateMachineType; + } + + private static (MethodInfo StartMethod, MethodInfo CreateMethod, PropertyInfo TaskProperty, MethodInfo OriginalMethod) GetMethods(MethodInfo method) + { + var originalMethod = method; + var originalMethodReturnType = + originalMethod.ReturnType.IsGenericType + ? originalMethod.ReturnType.GetGenericArguments()[0] + : typeof(void); + + const string startMethodName = nameof(AsyncTaskMethodBuilder.Start); + var startMethod = (originalMethodReturnType == typeof(void) + ? typeof(AsyncTaskMethodBuilder).GetMethod(startMethodName) + : typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetMethod(startMethodName)) ?? throw new Exception($"Cannot get {startMethodName} method"); + + const string taskPropertyName = nameof(AsyncTaskMethodBuilder.Task); + var taskProperty = (originalMethodReturnType == typeof(void) + ? typeof(AsyncTaskMethodBuilder).GetProperty(taskPropertyName) + : typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetProperty(taskPropertyName)) ?? throw new Exception($"Cannot get {taskPropertyName} property"); + + const string createMethodName = nameof(AsyncTaskMethodBuilder.Create); + var createMethod = (originalMethodReturnType == typeof(void) + ? typeof(AsyncTaskMethodBuilder).GetMethod(createMethodName) + : typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetMethod(createMethodName)) ?? throw new Exception($"Cannot get {createMethodName} method"); + + return (startMethod, createMethod, taskProperty, originalMethod); + } + + public MethodBase RewriteAsync() + { + var (startMethod, createMethod, taskProperty, originalMethod) = GetMethods((MethodInfo)_method); + + var stateMachine = GetStateMachineType((MethodInfo)_method); + var typeWithRewrittenMoveNext = RewriteMoveNext(stateMachine); + + var moveNextMethodInfo = typeWithRewrittenMoveNext.GetMethod(nameof(IAsyncStateMachine.MoveNext)); + + var rewrittenOriginalMethod = new DynamicMethod( + name: StubHelper.CreateStubNameFromMethod("impl", originalMethod), + returnType: originalMethod.ReturnType, + parameterTypes: originalMethod.GetParameters().Select(p => p.ParameterType).ToArray(), + m: originalMethod.Module, + skipVisibility: true + ); + + var methodBody = originalMethod.GetMethodBody() + ?? throw new MethodRewriteException($"Method {moveNextMethodInfo.Name} does not have a body"); + var locals = methodBody.LocalVariables; + + var ilGenerator = rewrittenOriginalMethod.GetILGenerator(); + + foreach (var local in locals) + { + if (locals[0].LocalType == stateMachine) + { + // References to the original state machine must be re-targeted to the rewritten state machine + ilGenerator.DeclareLocal(typeWithRewrittenMoveNext, local.IsPinned); + } + else + { + ilGenerator.DeclareLocal(local.LocalType, local.IsPinned); + } + } + + var constructorInfo = typeWithRewrittenMoveNext.GetConstructors()[0]; + ilGenerator.Emit(OpCodes.Newobj, constructorInfo); + ilGenerator.Emit(OpCodes.Stloc_0); + ilGenerator.Emit(OpCodes.Ldloc_0); + + ilGenerator.Emit(OpCodes.Call, createMethod); + + var builderField = typeWithRewrittenMoveNext.GetField("<>t__builder") ?? throw new Exception("Cannot get builder field"); + ilGenerator.Emit(OpCodes.Stfld, builderField); + + ilGenerator.Emit(OpCodes.Ldloc_0); + ilGenerator.Emit(OpCodes.Ldc_I4_M1); + var stateField = typeWithRewrittenMoveNext.GetField("<>1__state") ?? throw new Exception("Cannot get state field"); + ilGenerator.Emit(OpCodes.Stfld, stateField); + + ilGenerator.Emit(OpCodes.Ldloc_0); + ilGenerator.Emit(OpCodes.Ldflda, builderField); + ilGenerator.Emit(OpCodes.Ldloca_S, 0); + + var genericMethod = startMethod.MakeGenericMethod(typeWithRewrittenMoveNext); + ilGenerator.Emit(OpCodes.Call, genericMethod); + + ilGenerator.Emit(OpCodes.Ldloc_0); + ilGenerator.Emit(OpCodes.Ldflda, builderField); + + ilGenerator.Emit(OpCodes.Call, taskProperty.GetMethod); + + ilGenerator.Emit(OpCodes.Ret); + +#if TRACE + var ilBytes = ilGenerator.GetILBytes(); + var browsableDynamicMethod = new BrowsableDynamicMethod(rewrittenOriginalMethod, new DynamicMethodBody(ilBytes, locals)); + Console.WriteLine("\n" + rewrittenOriginalMethod); + + foreach (var instruction in browsableDynamicMethod.GetInstructions()) + { + Console.WriteLine(instruction); + } +#endif + + return rewrittenOriginalMethod; + } + + public static Type RewriteMoveNext(Type stateMachine) + { + var ab = AssemblyBuilder.DefineDynamicAssembly(new AssemblyName("AsyncAssembly"), AssemblyBuilderAccess.RunAndCollect); + var mb = ab.DefineDynamicModule("AsyncModule"); + var tb = mb.DefineType($"{stateMachine.Name}__Rewrite", TypeAttributes.Class | TypeAttributes.Public | TypeAttributes.Sealed); + tb.AddInterfaceImplementation(typeof(IAsyncStateMachine)); + + var fields = stateMachine.GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + .ToList() + .Select(f => tb.DefineField(f.Name, f.FieldType, FieldAttributes.Public)) + .ToArray(); + + var fieldDict = fields.ToDictionary(f => f.Name); + + stateMachine.GetMethods(BindingFlags.NonPublic | BindingFlags.Instance) + .ToList() + .ForEach(m => + { + // Console.WriteLine(m.Name); + var _exceptionBlockLevel = 0; + TypeInfo _constrainedType = null; + + var parameters = m.GetParameters().Select(p => p.ParameterType).ToArray(); + var meth = tb.DefineMethod(m.Name, MethodAttributes.Public | MethodAttributes.Virtual, m.ReturnType, parameters); + + var methodBody = m.GetMethodBody() ?? throw new MethodRewriteException($"Method {m.Name} does not have a body"); + var locals = methodBody.LocalVariables; + var targetInstructions = new Dictionary(); + var handlers = new List(); + + var ilGenerator = meth.GetILGenerator(); + var instructions = m.GetInstructions(); + + foreach (var clause in methodBody.ExceptionHandlingClauses) + { + var handler = new ExceptionHandler + { + Flags = clause.Flags, + CatchType = clause.Flags == ExceptionHandlingClauseOptions.Clause ? clause.CatchType : null, + TryStart = clause.TryOffset, + TryEnd = clause.TryOffset + clause.TryLength, + FilterStart = clause.Flags == ExceptionHandlingClauseOptions.Filter ? clause.FilterOffset : -1, + HandlerStart = clause.HandlerOffset, + HandlerEnd = clause.HandlerOffset + clause.HandlerLength + }; + handlers.Add(handler); + } + + foreach (var local in locals) + { + ilGenerator.DeclareLocal(local.LocalType, local.IsPinned); + } + + var ifTargets = instructions + .Where(i => i.Operand is Instruction) + .Select(i => i.Operand as Instruction); + + foreach (var ifInstruction in ifTargets) + { + if (ifInstruction == null) throw new Exception("The impossible happened"); + + targetInstructions.TryAdd(ifInstruction.Offset, ilGenerator.DefineLabel()); + } + + var switchTargets = instructions + .Where(i => i.Operand is Instruction[]) + .Select(i => i.Operand as Instruction[]); + + foreach (var switchInstructions in switchTargets) + { + if (switchInstructions == null) throw new Exception("The impossible happened"); + + foreach (var instruction in switchInstructions) + targetInstructions.TryAdd(instruction.Offset, ilGenerator.DefineLabel()); + } + + foreach (var instruction in instructions) + { + #if TRACE + Console.WriteLine(instruction); + #endif + + // EmitILForExceptionHandlers(ref _exceptionBlockLevel, ilGenerator, instruction, handlers); + + if (targetInstructions.TryGetValue(instruction.Offset, out var label)) + ilGenerator.MarkLabel(label); + + if (new []{ OpCodes.Endfilter, OpCodes.Endfinally }.Contains(instruction.OpCode)) continue; + + switch (instruction.OpCode.OperandType) + { + case OperandType.InlineNone: + ilGenerator.Emit(instruction.OpCode); + break; + case OperandType.InlineI: + ilGenerator.Emit(instruction.OpCode, (int)instruction.Operand); + break; + case OperandType.InlineI8: + ilGenerator.Emit(instruction.OpCode, (long)instruction.Operand); + break; + case OperandType.ShortInlineI: + if (instruction.OpCode == OpCodes.Ldc_I4_S) + ilGenerator.Emit(instruction.OpCode, (sbyte)instruction.Operand); + else + ilGenerator.Emit(instruction.OpCode, (byte)instruction.Operand); + break; + case OperandType.InlineR: + ilGenerator.Emit(instruction.OpCode, (double)instruction.Operand); + break; + case OperandType.ShortInlineR: + ilGenerator.Emit(instruction.OpCode, (float)instruction.Operand); + break; + case OperandType.InlineString: + ilGenerator.Emit(instruction.OpCode, (string)instruction.Operand); + break; + case OperandType.ShortInlineBrTarget: + case OperandType.InlineBrTarget: + var targetLabel = targetInstructions[(instruction.Operand as Instruction).Offset]; + + var opCode = instruction.OpCode; + + // Offset values could change and not be short form anymore + if (opCode == OpCodes.Br_S) opCode = OpCodes.Br; + else if (opCode == OpCodes.Brfalse_S) opCode = OpCodes.Brfalse; + else if (opCode == OpCodes.Brtrue_S) opCode = OpCodes.Brtrue; + else if (opCode == OpCodes.Beq_S) opCode = OpCodes.Beq; + else if (opCode == OpCodes.Bge_S) opCode = OpCodes.Bge; + else if (opCode == OpCodes.Bgt_S) opCode = OpCodes.Bgt; + else if (opCode == OpCodes.Ble_S) opCode = OpCodes.Ble; + else if (opCode == OpCodes.Blt_S) opCode = OpCodes.Blt; + else if (opCode == OpCodes.Bne_Un_S) opCode = OpCodes.Bne_Un; + else if (opCode == OpCodes.Bge_Un_S) opCode = OpCodes.Bge_Un; + else if (opCode == OpCodes.Bgt_Un_S) opCode = OpCodes.Bgt_Un; + else if (opCode == OpCodes.Ble_Un_S) opCode = OpCodes.Ble_Un; + else if (opCode == OpCodes.Blt_Un_S) opCode = OpCodes.Blt_Un; + else if (opCode == OpCodes.Leave_S) opCode = OpCodes.Leave; + + // 'Leave' instructions must be emitted if we are rewriting an async method. + // Otherwise the rewritten method will always start from the beginning every time. + if (opCode == OpCodes.Leave) + { + ilGenerator.Emit(opCode, targetLabel); + continue; + } + + // Check if 'Leave' opcode is being used in an exception block, + // only emit it if that's not the case + if (opCode == OpCodes.Leave && _exceptionBlockLevel > 0) continue; + + ilGenerator.Emit(opCode, targetLabel); + break; + case OperandType.InlineSwitch: + var switchInstructions = (Instruction[])instruction.Operand; + var targetLabels = new Label[switchInstructions.Length]; + for (var i = 0; i < switchInstructions.Length; i++) + targetLabels[i] = targetInstructions[switchInstructions[i].Offset]; + ilGenerator.Emit(instruction.OpCode, targetLabels); + break; + case OperandType.ShortInlineVar: + case OperandType.InlineVar: + var index = 0; + if (instruction.OpCode.Name.Contains("loc")) + { + index = ((LocalVariableInfo)instruction.Operand).LocalIndex; + } + else + { + index = ((ParameterInfo)instruction.Operand).Position; + index += 1; + } + + if (instruction.OpCode.OperandType == OperandType.ShortInlineVar) + ilGenerator.Emit(instruction.OpCode, (byte)index); + else + ilGenerator.Emit(instruction.OpCode, (ushort)index); + break; + case OperandType.InlineTok: + case OperandType.InlineType: + case OperandType.InlineField: + case OperandType.InlineMethod: + var memberInfo = (MemberInfo)instruction.Operand; + if (memberInfo.MemberType == MemberTypes.Field) + { + if (instruction.OpCode == OpCodes.Ldflda && ((FieldInfo)instruction.Operand).DeclaringType.Name == stateMachine.Name) + { + var name = ((FieldInfo) instruction.Operand).Name; + + if (fieldDict.TryGetValue(name, out var field)) + { + ilGenerator.Emit(OpCodes.Ldflda, field); + continue; + } + else + { + throw new Exception($"Cannot find field {name}"); + } + } + + if (instruction.OpCode == OpCodes.Stfld && ((FieldInfo) instruction.Operand).DeclaringType.Name == stateMachine.Name) + { + var name = ((FieldInfo) instruction.Operand).Name; + + if (fieldDict.TryGetValue(name, out var field)) + { + ilGenerator.Emit(OpCodes.Stfld, field); + continue; + } + else + { + throw new Exception($"Cannot find field {name}"); + } + } + + if (instruction.OpCode == OpCodes.Ldfld && ((FieldInfo) instruction.Operand).DeclaringType.Name == stateMachine.Name) + { + var name = ((FieldInfo) instruction.Operand).Name; + + if (fieldDict.TryGetValue(name, out var field)) + { + ilGenerator.Emit(OpCodes.Ldfld, field); + continue; + } + else + { + throw new Exception($"Cannot find field {name}"); + } + } + + ilGenerator.Emit(instruction.OpCode, memberInfo as FieldInfo); + } + else if (memberInfo.MemberType == MemberTypes.TypeInfo + || memberInfo.MemberType == MemberTypes.NestedType) + { + if (instruction.OpCode == OpCodes.Constrained) + { + _constrainedType = memberInfo as TypeInfo; + continue; + } + + ilGenerator.Emit(instruction.OpCode, memberInfo as TypeInfo); + } + else if (memberInfo.MemberType == MemberTypes.Constructor) + { + throw new NotSupportedException(); + // var constructorInfo = memberInfo as ConstructorInfo; + // + // if (constructorInfo.InCoreLibrary()) + // { + // // Don't attempt to rewrite inaccessible constructors in System.Private.CoreLib/mscorlib + // if (ShouldForward(constructorInfo)) goto forward; + // } + // + // if (instruction.OpCode == OpCodes.Call) + // { + // ilGenerator.Emit(OpCodes.Ldtoken, (ConstructorInfo)memberInfo); + // ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForDirectCall(constructorInfo)); + // return; + // } + // + // if (instruction.OpCode == OpCodes.Newobj) + // { + // //ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForConstructor(constructorInfo, instruction.OpCode, constructorInfo.IsForValueType())); + // ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForObjectInitialization(constructorInfo)); + // return; + // } + // + // if (instruction.OpCode == OpCodes.Ldftn) + // { + // //ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForConstructor(constructorInfo, instruction.OpCode, constructorInfo.IsForValueType())); + // ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForDirectLoad(constructorInfo)); + // return; + // } + // + // // If we get here, then we haven't accounted for an opcode. + // // Throw exception to make this obvious. + // throw new NotSupportedException(instruction.OpCode.Name); + // + // forward: + // ilGenerator.Emit(instruction.OpCode, constructorInfo); + } + else if (memberInfo.MemberType == MemberTypes.Method) + { + var methodInfo = memberInfo as MethodInfo; + + if (methodInfo.InCoreLibrary()) + { + // Don't attempt to rewrite inaccessible methods in System.Private.CoreLib/mscorlib + if (ShouldForward(methodInfo)) goto forward; + } + + if (instruction.OpCode == OpCodes.Call) + { + if (methodInfo.DeclaringType.Name == nameof(AsyncTaskMethodBuilder) && methodInfo.Name == nameof(AsyncTaskMethodBuilder.AwaitUnsafeOnCompleted)) + { + // The call is to AwaitUnsafeOnCompleted which must have the correct generic arguments + var taskAwaiterArgument = methodInfo.GetGenericArguments()[0]; + methodInfo = methodInfo.GetGenericMethodDefinition().MakeGenericMethod(taskAwaiterArgument, tb); + } + else if (methodInfo.IsGenericMethod + && methodInfo.DeclaringType.IsGenericType + && methodInfo.DeclaringType.GetGenericTypeDefinition() == typeof(AsyncTaskMethodBuilder<>) + && methodInfo.Name == "AwaitUnsafeOnCompleted") + { + // The call is to AwaitUnsafeOnCompleted which must have the correct generic arguments + var taskAwaiterArgument = methodInfo.GetGenericArguments()[0]; + methodInfo = methodInfo.GetGenericMethodDefinition().MakeGenericMethod(taskAwaiterArgument, tb); + } + + ilGenerator.Emit(OpCodes.Call, methodInfo); + // ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForDirectCall(methodInfo)); + continue; + } + + if (instruction.OpCode == OpCodes.Callvirt) + { + if (_constrainedType != null) + { + ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForVirtualCall(methodInfo, _constrainedType)); + _constrainedType = null; + continue; + } + + ilGenerator.Emit(OpCodes.Callvirt, methodInfo); + continue; + } + + if (instruction.OpCode == OpCodes.Ldftn) + { + ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForDirectLoad(methodInfo)); + continue; + } + + if (instruction.OpCode == OpCodes.Ldvirtftn) + { + ilGenerator.Emit(OpCodes.Call, Stubs.GenerateStubForVirtualLoad(methodInfo)); + continue; + } + + forward: + ilGenerator.Emit(instruction.OpCode, methodInfo); + } + else + { + throw new NotSupportedException(); + } + break; + default: + throw new NotSupportedException(instruction.OpCode.OperandType.ToString()); + } + } + + + ilGenerator.Emit(OpCodes.Ret); + }); + + return tb.CreateTypeInfo(); + } + private void EmitILForExceptionHandlers(ILGenerator ilGenerator, Instruction instruction, IReadOnlyCollection handlers) { var tryBlocks = handlers.Where(h => h.TryStart == instruction.Offset).GroupBy(h => h.TryEnd); diff --git a/src/Pose/Pose.csproj b/src/Pose/Pose.csproj index d37bf2d..05c189e 100644 --- a/src/Pose/Pose.csproj +++ b/src/Pose/Pose.csproj @@ -7,11 +7,15 @@ false full - + TRACE + + + + \ No newline at end of file diff --git a/src/Sandbox/Program.cs b/src/Sandbox/Program.cs index 2bb05b0..07dfe52 100644 --- a/src/Sandbox/Program.cs +++ b/src/Sandbox/Program.cs @@ -45,9 +45,9 @@ public static async Task DoWork2Async() public static async Task DoWork3Async() { - Console.WriteLine("Here"); - await Task.Delay(1000); - Console.WriteLine("Here 2"); + Console.WriteLine("Here 3.1"); + await Task.Delay(10); + Console.WriteLine("Here 3.2"); } public static async Task DoWork1Async() @@ -55,19 +55,18 @@ public static async Task DoWork1Async() return GetInt(); } - private static Type GetStateMachineType(string methodName) + private static Type GetStateMachineType(MethodBase method) { - var stateMachineType = typeof(TOwningType) - .GetMethod(methodName) + var stateMachineType = method ?.GetCustomAttribute() ?.StateMachineType; return stateMachineType; } - private static (MethodInfo StartMethod, MethodInfo CreateMethod, PropertyInfo TaskProperty, MethodInfo OriginalMethod) GetMethods(string methodName) + private static (MethodInfo StartMethod, MethodInfo CreateMethod, PropertyInfo TaskProperty, MethodInfo OriginalMethod) GetMethods(MethodInfo method) { - var originalMethod = typeof(TOwningType).GetMethod(methodName) ?? throw new Exception("Cannot get original method"); + var originalMethod = method; var originalMethodReturnType = originalMethod.ReturnType.IsGenericType ? originalMethod.ReturnType.GetGenericArguments()[0] @@ -91,11 +90,11 @@ private static (MethodInfo StartMethod, MethodInfo CreateMethod, PropertyInfo Ta return (startMethod, createMethod, taskProperty, originalMethod); } - private static void RunAsync(string methodName) where TReturnType : class + private static void RunAsync(Type owningType, MethodInfo method) where TReturnType : class { - var (startMethod, createMethod, taskProperty, _) = GetMethods(methodName); + var (startMethod, createMethod, taskProperty, _) = GetMethods(method); - var stateMachineType = GetStateMachineType(methodName); + var stateMachineType = GetStateMachineType(method); var rewrittenStateMachine = RewriteMoveNext(stateMachineType); var stateMachineInstance = Activator.CreateInstance(rewrittenStateMachine); @@ -113,11 +112,11 @@ private static void RunAsync(string methodName) where var task = taskProperty.GetValue(builder) as TReturnType ?? throw new Exception("Cannot get task"); } - private static MethodBase RewriteAsync(string methodName) + private static MethodBase RewriteAsync(Type owningType, MethodInfo method) { - var (startMethod, createMethod, taskProperty, originalMethod) = GetMethods(methodName); + var (startMethod, createMethod, taskProperty, originalMethod) = GetMethods(method); - var stateMachine = GetStateMachineType(methodName); + var stateMachine = GetStateMachineType(method); var typeWithRewrittenMoveNext = RewriteMoveNext(stateMachine); var moveNextMethodInfo = typeWithRewrittenMoveNext.GetMethod(nameof(IAsyncStateMachine.MoveNext)); @@ -128,7 +127,7 @@ private static MethodBase RewriteAsync(string methodName) name: StubHelper.CreateStubNameFromMethod("impl", originalMethod), returnType: originalMethod.ReturnType, parameterTypes: originalMethod.GetParameters().Select(p => p.ParameterType).ToArray(), - m: typeof(Program).Module, + m: originalMethod.Module, skipVisibility: true ); @@ -141,6 +140,7 @@ private static MethodBase RewriteAsync(string methodName) { if (locals[0].LocalType == stateMachine) { + // References to the original state machine must be re-targeted to the rewritten state machine ilGenerator.DeclareLocal(typeWithRewrittenMoveNext, local.IsPinned); } else @@ -178,6 +178,7 @@ private static MethodBase RewriteAsync(string methodName) ilGenerator.Emit(OpCodes.Ret); +#if TRACE var ilBytes = ilGenerator.GetILBytes(); var browsableDynamicMethod = new BrowsableDynamicMethod(rewrittenOriginalMethod, new DynamicMethodBody(ilBytes, locals)); Console.WriteLine("\n" + rewrittenOriginalMethod); @@ -186,27 +187,12 @@ private static MethodBase RewriteAsync(string methodName) { Console.WriteLine(instruction); } +#endif return rewrittenOriginalMethod; - - // - // var instance = Activator.CreateInstance(copyType); - // builderField.SetValue(instance, AsyncTaskMethodBuilder.Create()); - // stateField.SetValue(instance, -1); - // var startMethod = typeof(AsyncTaskMethodBuilder).GetMethod(nameof(AsyncTaskMethodBuilder.Start)) ?? throw new Exception("Cannot get start method"); - // var genericMethod = startMethod.MakeGenericMethod(copyType); - // genericMethod.Invoke(builderField.GetValue(instance), new object[] { instance }); - - // var builder = builderField.GetValue(instance); - // var taskProperty = typeof(AsyncTaskMethodBuilder).GetProperty("Task") ?? throw new Exception("Cannot get task property"); - // var task = taskProperty.GetValue(builder) as Task ?? throw new Exception("Cannot get task"); - // var result = task.Result; - // - // Console.WriteLine(result); } throw new Exception("Failed to rewrite async method"); - // Console.WriteLine("SUCCESS!"); } public static async Task Main(string[] args) @@ -228,11 +214,20 @@ public static async Task Main(string[] args) try { - RunAsync>(nameof(DoWork2Async)); - // RunAsync(nameof(DoWork3Async)); - var task = (MethodInfo) RewriteAsync(nameof(DoWork2Async)); - var @delegate = task.CreateDelegate(typeof(Func>)); + var asyncMethod = typeof(Program).GetMethod(nameof(DoWork2Async)); + var methodRewriter = MethodRewriter.CreateRewriter(asyncMethod, false); + var methodBase = (MethodInfo)methodRewriter.RewriteAsync(); + var @delegate = methodBase.CreateDelegate(typeof(Func>)); var result = @delegate.DynamicInvoke(new object[0]) as Task; + + // RunAsync>(typeof(Program), typeof(Program).GetMethod(nameof(DoWork2Async))); + // Console.WriteLine("---"); + // RunAsync(typeof(Program), typeof(Program).GetMethod(nameof(DoWork3Async))); + // Console.WriteLine("---"); + // var task = (MethodInfo) RewriteAsync(typeof(Program), typeof(Program).GetMethod(nameof(DoWork2Async))); + // var @delegate = task.CreateDelegate(typeof(Func>)); + // var result = @delegate.DynamicInvoke(new object[0]) as Task; + // Console.WriteLine("---"); // @delegate.DynamicInvoke(new object[0]); // var result = task.Invoke(null, new object[] { }); Console.WriteLine(result.Result); @@ -293,9 +288,6 @@ public static Type RewriteMoveNext(Type stateMachine) var ilGenerator = meth.GetILGenerator(); var instructions = m.GetInstructions(); - ilGenerator.Emit(OpCodes.Ldstr, "Hello World"); - ilGenerator.Emit(OpCodes.Call, typeof(Console).GetMethod("WriteLine", new Type[] { typeof(string) })); - foreach (var clause in methodBody.ExceptionHandlingClauses) { var handler = new ExceptionHandler @@ -616,9 +608,6 @@ public static Type RewriteMoveNext(Type stateMachine) ilGenerator.Emit(OpCodes.Ret); - - Console.WriteLine(); - Console.WriteLine(); }); return tb.CreateType(); diff --git a/test/Pose.Tests/IL/AsyncMethodRewriterTests.cs b/test/Pose.Tests/IL/AsyncMethodRewriterTests.cs new file mode 100644 index 0000000..1b8b9d4 --- /dev/null +++ b/test/Pose.Tests/IL/AsyncMethodRewriterTests.cs @@ -0,0 +1,121 @@ +using System; +using System.Reflection; +using System.Threading.Tasks; +using FluentAssertions; +using Pose.IL; +using Xunit; + +namespace Pose.Tests +{ + public class AsyncMethodRewriterTests + { + private const int AsyncMethodReturnValue = 1; + + private static async Task AsyncMethodWithReturnValue() + { + await Task.Delay(1000); + return AsyncMethodReturnValue; + } + + private static readonly MethodInfo AsyncMethodWithReturnValueInfo = typeof(AsyncMethodRewriterTests).GetMethod(nameof(AsyncMethodWithReturnValue), BindingFlags.Static | BindingFlags.NonPublic); + + private static async Task AsyncMethodWithoutReturnValue() + { + await Task.Delay(0); + } + + private static readonly MethodInfo AsyncMethodWithoutReturnValueInfo = typeof(AsyncMethodRewriterTests).GetMethod(nameof(AsyncMethodWithoutReturnValue), BindingFlags.Static | BindingFlags.NonPublic); + + private static async void AsyncVoidMethod() + { + await Task.Delay(0); + } + + private static readonly MethodInfo AsyncVoidMethodInfo = typeof(AsyncMethodRewriterTests).GetMethod(nameof(AsyncVoidMethod), BindingFlags.Static | BindingFlags.NonPublic); + + [Fact] + public void Can_rewrite_async_method_with_return_value() + { + // Arrange + var methodRewriter = MethodRewriter.CreateRewriter(AsyncMethodWithReturnValueInfo, false); + + // Act + Action act = () => methodRewriter.RewriteAsync(); + + // Assert + act.Should().NotThrow(); + } + + [Fact] + public void Can_run_async_method_with_return_value() + { + // Arrange + var methodRewriter = MethodRewriter.CreateRewriter(AsyncMethodWithReturnValueInfo, false); + var rewrittenMethod = (MethodInfo) methodRewriter.RewriteAsync(); + var sut = rewrittenMethod.CreateDelegate(typeof(Func>)); + + // Act + Func> runner = () => sut.DynamicInvoke(Array.Empty()) as Task; + + // Assert + runner.Should().NotThrowAsync().Result.Which.Should().Be(AsyncMethodReturnValue, because: "that is the return value of the async method"); + } + + [Fact] + public void Can_rewrite_async_method_without_return_value() + { + // Arrange + var methodRewriter = MethodRewriter.CreateRewriter(AsyncMethodWithoutReturnValueInfo, false); + + // Act + Action act = () => methodRewriter.RewriteAsync(); + + // Assert + act.Should().NotThrow(); + } + + [Fact] + public void Can_run_async_method_without_return_value() + { + // Arrange + var methodRewriter = MethodRewriter.CreateRewriter(AsyncMethodWithoutReturnValueInfo, false); + var rewrittenMethod = (MethodInfo) methodRewriter.RewriteAsync(); + var sut = rewrittenMethod.CreateDelegate(typeof(Func)); + + // Act + Func runner = () => sut.DynamicInvoke(Array.Empty()) as Task; + + // Assert + runner.Should().NotThrowAsync(); + } + + [Fact] + public void Can_rewrite_async_void_method() + { + // Arrange + var methodRewriter = MethodRewriter.CreateRewriter(AsyncVoidMethodInfo, false); + + // Act + Action act = () => methodRewriter.RewriteAsync(); + + // Assert + act.Should().NotThrow(); + } + + [Fact] + public void Can_run_async_void_method() + { + // Arrange + var methodRewriter = MethodRewriter.CreateRewriter(AsyncVoidMethodInfo, false); + var rewrittenMethod = (MethodInfo) methodRewriter.RewriteAsync(); + var sut = rewrittenMethod.CreateDelegate(typeof(Action)); + + // Act + Func runner = () => sut.DynamicInvoke(Array.Empty()) as Task; + + // Assert + runner.Should().NotThrowAsync(); + } + + } +} \ No newline at end of file diff --git a/test/Pose.Tests/Pose.Tests.csproj b/test/Pose.Tests/Pose.Tests.csproj index 4addad8..25ce702 100644 --- a/test/Pose.Tests/Pose.Tests.csproj +++ b/test/Pose.Tests/Pose.Tests.csproj @@ -2,6 +2,7 @@ netcoreapp2.0;netcoreapp3.0;netcoreapp3.1;net47;net48;net6.0;net7.0;net8.0 + false