Skip to content

Commit

Permalink
Successfully rewrite async method
Browse files Browse the repository at this point in the history
  • Loading branch information
nikopede committed Dec 16, 2024
1 parent 11ee97e commit f567493
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/Pose/Pose.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<PropertyGroup Condition=" '$(Configuration)' == 'Debug' ">
<DebugSymbols>false</DebugSymbols>
<DebugType>full</DebugType>
<DefineConstants>TRACE</DefineConstants>
<DefineConstants></DefineConstants>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Mono.Reflection.Core" Version="1.1.1" />
Expand Down
75 changes: 24 additions & 51 deletions src/Sandbox/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ private static Type GetStateMachineType<TOwningType>(string methodName)

return stateMachineType;
}
private static void RunAsync<TOwningType, TReturnType>(string methodName) where TReturnType : class

private static (MethodInfo StartMethod, MethodInfo CreateMethod, PropertyInfo TaskProperty, MethodInfo OriginalMethod) GetMethods<TOwningType>(string methodName)
{
var originalMethod = typeof(TOwningType).GetMethod(methodName) ?? throw new Exception("Cannot get original method");
var originalMethodReturnType =
Expand All @@ -83,14 +83,20 @@ private static void RunAsync<TOwningType, TReturnType>(string methodName) where
? typeof(AsyncTaskMethodBuilder).GetProperty(taskPropertyName)
: typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetProperty(taskPropertyName)) ?? throw new Exception($"Cannot get {taskPropertyName} property");

var stateMachineType = GetStateMachineType<TOwningType>(methodName);
var rewrittenStateMachine = RewriteMoveNext(stateMachineType);

const string createMethodName = nameof(AsyncTaskMethodBuilder<int>.Create);
var createMethod = (originalMethodReturnType == typeof(void)
? typeof(AsyncTaskMethodBuilder).GetMethod(createMethodName)
: typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetMethod(createMethodName)) ?? throw new Exception($"Cannot get {createMethodName} method");

return (startMethod, createMethod, taskProperty, originalMethod);
}

private static void RunAsync<TOwningType, TReturnType>(string methodName) where TReturnType : class
{
var (startMethod, createMethod, taskProperty, _) = GetMethods<TOwningType>(methodName);

var stateMachineType = GetStateMachineType<TOwningType>(methodName);
var rewrittenStateMachine = RewriteMoveNext(stateMachineType);
var stateMachineInstance = Activator.CreateInstance(rewrittenStateMachine);

var builderField = rewrittenStateMachine.GetField("<>t__builder") ?? throw new Exception("Cannot get builder field");
Expand All @@ -99,43 +105,18 @@ private static void RunAsync<TOwningType, TReturnType>(string methodName) where
var stateField = rewrittenStateMachine.GetField("<>1__state") ?? throw new Exception("Cannot get state field");
stateField.SetValue(stateMachineInstance, -1);

// var startMethod = typeof(AsyncTaskMethodBuilder<int>).GetMethod(nameof(AsyncTaskMethodBuilder<int>.Start)) ?? throw new Exception("Cannot get start method");
var genericMethod = startMethod.MakeGenericMethod(rewrittenStateMachine);
var builder = builderField.GetValue(stateMachineInstance);

genericMethod.Invoke(builder, new object[] { stateMachineInstance });

// var taskProperty = typeof(AsyncTaskMethodBuilder<int>).GetProperty("Task") ?? throw new Exception("Cannot get task property");
var task = taskProperty.GetValue(builder) as TReturnType ?? throw new Exception("Cannot get task");
}

private static MethodBase RewriteAsync<TOwningType>(string methodName)
{
var originalMethod = typeof(TOwningType).GetMethod(methodName) ?? throw new Exception("Cannot get original method");
var originalMethodReturnType =
originalMethod.ReturnType.IsGenericType
? originalMethod.ReturnType.GetGenericArguments()[0]
: typeof(void);

const string startMethodName = nameof(AsyncTaskMethodBuilder<int>.Start);
var startMethod = (originalMethodReturnType == typeof(void)
? typeof(AsyncTaskMethodBuilder).GetMethod(startMethodName)
: typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetMethod(startMethodName)) ?? throw new Exception($"Cannot get {startMethodName} method");

const string taskPropertyName = nameof(AsyncTaskMethodBuilder<int>.Task);
var taskProperty = (originalMethodReturnType == typeof(void)
? typeof(AsyncTaskMethodBuilder).GetProperty(taskPropertyName)
: typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetProperty(taskPropertyName)) ?? throw new Exception($"Cannot get {taskPropertyName} property");

var stateMachineType = GetStateMachineType<TOwningType>(methodName);
var rewrittenStateMachine = RewriteMoveNext(stateMachineType);

const string createMethodName = nameof(AsyncTaskMethodBuilder<int>.Create);
var createMethod = (originalMethodReturnType == typeof(void)
? typeof(AsyncTaskMethodBuilder).GetMethod(createMethodName)
: typeof(AsyncTaskMethodBuilder<>).MakeGenericType(originalMethodReturnType).GetMethod(createMethodName)) ?? throw new Exception($"Cannot get {createMethodName} method");
var (startMethod, createMethod, taskProperty, originalMethod) = GetMethods<TOwningType>(methodName);


var stateMachine = GetStateMachineType<TOwningType>(methodName);
var typeWithRewrittenMoveNext = RewriteMoveNext(stateMachine);

Expand All @@ -147,33 +128,29 @@ private static MethodBase RewriteAsync<TOwningType>(string methodName)
name: StubHelper.CreateStubNameFromMethod("impl", originalMethod),
returnType: originalMethod.ReturnType,
parameterTypes: originalMethod.GetParameters().Select(p => p.ParameterType).ToArray(),
m: StubHelper.GetOwningModule(),
m: typeof(Program).Module,
skipVisibility: true
);

var methodBody = moveNextMethodInfo.GetMethodBody() ?? throw new MethodRewriteException($"Method {moveNextMethodInfo.Name} does not have a body");
var methodBody = originalMethod.GetMethodBody() ?? throw new MethodRewriteException($"Method {moveNextMethodInfo.Name} does not have a body");
var locals = methodBody.LocalVariables;

var ilGenerator = rewrittenOriginalMethod.GetILGenerator();

var index = 0;
foreach (var local in locals)
{
if (index == 3)
if (locals[0].LocalType == stateMachine)
{
ilGenerator.DeclareLocal(stateMachine, local.IsPinned);
ilGenerator.DeclareLocal(typeWithRewrittenMoveNext, local.IsPinned);
}
else
{
ilGenerator.DeclareLocal(local.LocalType, local.IsPinned);
}

index++;
}

ilGenerator.Emit(OpCodes.Nop);

ilGenerator.Emit(OpCodes.Newobj, typeWithRewrittenMoveNext);
var constructorInfo = typeWithRewrittenMoveNext.GetConstructors()[0];
ilGenerator.Emit(OpCodes.Newobj, constructorInfo);
ilGenerator.Emit(OpCodes.Stloc_0);
ilGenerator.Emit(OpCodes.Ldloc_0);

Expand All @@ -191,7 +168,8 @@ private static MethodBase RewriteAsync<TOwningType>(string methodName)
ilGenerator.Emit(OpCodes.Ldflda, builderField);
ilGenerator.Emit(OpCodes.Ldloca_S, 0);

ilGenerator.Emit(OpCodes.Call, startMethod);
var genericMethod = startMethod.MakeGenericMethod(typeWithRewrittenMoveNext);
ilGenerator.Emit(OpCodes.Call, genericMethod);

ilGenerator.Emit(OpCodes.Ldloc_0);
ilGenerator.Emit(OpCodes.Ldflda, builderField);
Expand Down Expand Up @@ -250,14 +228,14 @@ public static async Task Main(string[] args)

try
{
// RunAsync<Program, Task<int>>(nameof(DoWork2Async));
RunAsync<Program, Task<int>>(nameof(DoWork2Async));
// RunAsync<Program, Task>(nameof(DoWork3Async));
var task = (MethodInfo) RewriteAsync<Program>(nameof(DoWork2Async));
var @delegate = task.CreateDelegate(typeof(Func<Task<int>>));
var result = @delegate.DynamicInvoke(new object[0]);
var result = @delegate.DynamicInvoke(new object[0]) as Task<int>;
// @delegate.DynamicInvoke(new object[0]);
// var result = task.Invoke(null, new object[] { });
Console.WriteLine(result);
Console.WriteLine(result.Result);
}
catch (Exception e)
{
Expand Down Expand Up @@ -286,7 +264,6 @@ public static Type RewriteMoveNext(Type stateMachine)
{
var ab = AssemblyBuilder.DefineDynamicAssembly(new AssemblyName("AsyncAssembly"), AssemblyBuilderAccess.RunAndCollect);
var mb = ab.DefineDynamicModule("AsyncModule");
// var containerBuilder = mb.DefineType("AsyncMethodContainer", TypeAttributes.Class | TypeAttributes.Public);
var tb = mb.DefineType($"{stateMachine.Name}__Rewrite", TypeAttributes.Class | TypeAttributes.Public | TypeAttributes.Sealed);
tb.AddInterfaceImplementation(typeof(IAsyncStateMachine));

Expand All @@ -301,17 +278,13 @@ public static Type RewriteMoveNext(Type stateMachine)
.ToList()
.ForEach(m =>
{
Console.WriteLine(m.Name);
// Console.WriteLine(m.Name);
var _exceptionBlockLevel = 0;
TypeInfo _constrainedType = null;

var parameters = m.GetParameters().Select(p => p.ParameterType).ToArray();
var meth = tb.DefineMethod(m.Name, MethodAttributes.Public | MethodAttributes.Virtual, m.ReturnType, parameters);

// var methodRewriter = MethodRewriter.CreateRewriter(m, false);
// var rewritten = methodRewriter.Rewrite();

// generator.Emit(OpCodes.Call, (MethodInfo) rewritten);
var methodBody = m.GetMethodBody() ?? throw new MethodRewriteException($"Method {m.Name} does not have a body");
var locals = methodBody.LocalVariables;
var targetInstructions = new Dictionary<int, Label>();
Expand Down
4 changes: 4 additions & 0 deletions src/Sandbox/Sandbox.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
<TargetFrameworks>netcoreapp2.0;netcoreapp3.0;net6.0;net7.0;net8.0</TargetFrameworks>
</PropertyGroup>

<PropertyGroup Condition=" '$(Configuration)' == 'Debug' ">
<DefineConstants />
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Pose\Pose.csproj" />
</ItemGroup>
Expand Down

0 comments on commit f567493

Please sign in to comment.