Skip to content

Commit

Permalink
SUCCESS - middleware works!!
Browse files Browse the repository at this point in the history
  • Loading branch information
aritchie committed Jun 4, 2024
1 parent db0fbb4 commit 9fc1e6b
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 85 deletions.
2 changes: 1 addition & 1 deletion Sample/MyRequestMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Sample;
// [RegisterMiddleware]
public class MyRequestMiddleware(AppSqliteConnection conn) : IRequestMiddleware<MyMessageRequest, MyMessageResponse>
{
public async Task<MyMessageResponse> Process(MyMessageRequest request, Func<Task<MyMessageResponse>> next, CancellationToken cancellationToken)
public async Task<MyMessageResponse> Process(MyMessageRequest request, RequestHandlerDelegate<MyMessageResponse> next, CancellationToken cancellationToken)
{
var sw = new Stopwatch();
sw.Start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class CacheAttribute : Attribute

public class ConnectivityCacheRequestMiddleware<TRequest, TResult>(IConnectivity connectivity, IFileSystem fileSystem) : IRequestMiddleware<TRequest, TResult> where TRequest : IRequest<TResult>
{
public async Task<TResult> Process(TRequest request, Func<Task<TResult>> next, CancellationToken cancellationToken)
public async Task<TResult> Process(TRequest request, RequestHandlerDelegate<TResult> next, CancellationToken cancellationToken)
{
var config = typeof(TRequest).GetCustomAttribute<CacheAttribute>();
if (config == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class TimedLoggingMiddlewareConfig

public class TimedLoggingRequestMiddleware<TRequest, TResult>(ILogger<TRequest> logger, TimedLoggingMiddlewareConfig config) : IRequestMiddleware<TRequest, TResult> where TRequest : IRequest<TResult>
{
public async Task<TResult> Process(TRequest request, Func<Task<TResult>> next, CancellationToken cancellationToken)
public async Task<TResult> Process(TRequest request, RequestHandlerDelegate<TResult> next, CancellationToken cancellationToken)
{
var sw = new Stopwatch();
sw.Start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class UserExceptionRequestMiddlewareConfig

public class UserExceptionRequestMiddleware<TRequest, TResult>(ILogger<TRequest> logger, UserExceptionRequestMiddlewareConfig config) : IRequestMiddleware<TRequest, TResult> where TRequest : IRequest<TResult>
{
public async Task<TResult> Process(TRequest request, Func<Task<TResult>> next, CancellationToken cancellationToken)
public async Task<TResult> Process(TRequest request, RequestHandlerDelegate<TResult> next, CancellationToken cancellationToken)
{
var result = default(TResult);
try
Expand Down
3 changes: 2 additions & 1 deletion src/Shiny.Mediator/IRequestMiddleware.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using Shiny.Mediator;


public delegate Task<TResult> RequestHandlerDelegate<TResult>();
public interface IRequestMiddleware<in TRequest, TResult> where TRequest : IRequest<TResult>
{
Task<TResult> Process(TRequest request, Func<Task<TResult>> next, CancellationToken cancellationToken);
Task<TResult> Process(TRequest request, RequestHandlerDelegate<TResult> next, CancellationToken cancellationToken);
}
119 changes: 47 additions & 72 deletions src/Shiny.Mediator/Impl/DefaultRequestSender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,91 +7,66 @@ namespace Shiny.Mediator.Impl;

public class DefaultRequestSender(IServiceProvider services) : IRequestSender
{
public async Task Send<TRequest>(TRequest request, CancellationToken cancellationToken) where TRequest : IRequest
public async Task Send(IRequest request, CancellationToken cancellationToken)
{
using var scope = services.CreateScope();
var handlers = scope.ServiceProvider.GetServices<IRequestHandler<TRequest>>().ToList();
AssertRequestHandlers(handlers.Count, request);

await this.ExecuteMiddleware(
scope,
(IRequest<Unit>)request,
async () =>
{
await handlers
.First()
.Handle(request, cancellationToken)
.ConfigureAwait(false);
return Unit.Value;
},
cancellationToken
)
.ConfigureAwait(false);
// using var scope = services.CreateScope();
// var handlers = scope.ServiceProvider.GetServices<IRequestHandler<TRequest>>().ToList();
// AssertRequestHandlers(handlers.Count, request);
//
// await this.ExecuteMiddleware(
// scope,
// (IRequest<Unit>)request,
// async () =>
// {
// await handlers
// .First()
// .Handle(request, cancellationToken)
// .ConfigureAwait(false);
// return Unit.Value;
// },
// cancellationToken
// )
// .ConfigureAwait(false);
throw new BadImageFormatException();
}


public async Task<TResult> Request<TResult>(IRequest<TResult> request, CancellationToken cancellationToken = default)
{
var handlerType = typeof(IRequestHandler<,>).MakeGenericType(request.GetType(), typeof(TResult));
using var scope = services.CreateScope();
var handlers = scope.ServiceProvider.GetServices(handlerType).ToList();
AssertRequestHandlers(handlers.Count, request);

Func<Task<TResult>> execute = async () =>
{
var handler = handlers.First();
var handleMethod = handlerType.GetMethod("Handle", BindingFlags.Instance | BindingFlags.Public)!;
var resultTask = (Task<TResult>)handleMethod.Invoke(handler, [request, cancellationToken])!;
var result = await resultTask.ConfigureAwait(false);
return result;
};
var result = await this.ExecuteMiddleware(scope, request, execute, cancellationToken).ConfigureAwait(false);
var wrapperType = typeof(RequestWrapper<,>).MakeGenericType([request.GetType(), typeof(TResult)]);
var wrapperMethod = wrapperType.GetMethod("Handle", BindingFlags.Public | BindingFlags.Instance)!;
var wrapper = Activator.CreateInstance(wrapperType);
var task = (Task<TResult>)wrapperMethod.Invoke(wrapper, [services, request, cancellationToken])!;
var result = await task.ConfigureAwait(false);
return result;
}
}


async Task<TResult> ExecuteMiddleware<TRequest, TResult>(
IServiceScope scope,
TRequest request,
Func<Task<TResult>> initialExecute,
CancellationToken cancellationToken
) where TRequest : IRequest<TResult>
public class RequestWrapper<TRequest, TResult> where TRequest : IRequest<TResult>
{
public async Task<TResult> Handle(IServiceProvider services, TRequest request, CancellationToken cancellationToken)
{
var middlewareType = typeof(IRequestMiddleware<,>).MakeGenericType(request.GetType(), typeof(TResult));
var middlewareMethod = middlewareType.GetMethod("Process", BindingFlags.Instance | BindingFlags.Public)!;
var middlewares = scope.ServiceProvider.GetServices(middlewareType).ToList();
var pipeline = new List<Func<Task<TResult>>> { initialExecute };

// Unable to find seq points for method 'System.Runtime.CompilerServices.AsyncMethodBuilderCore:Start<Sample.MyRequestMiddleware/<Process>d__2> (Sample.MyRequestMiddleware/<Process>d__2&)', offset 0xfffffe88.
// we get the middleware reverse ordered from last to first so execution ordered properly
// middlewares.Reverse();
// foreach (var middleware in middlewares)
// {
// var pipelineAdd = () => (Task<TResult>)middlewareMethod.Invoke(middleware, [
// request,
// pipeline.Last(),
// cancellationToken
// ])!;
// pipeline.Add(pipelineAdd);
// }
//
// var result = await pipeline
// .Last()
// .Invoke()
// .ConfigureAwait(false);
// return result;
var handler = new RequestHandlerDelegate<TResult>(() => services
.GetRequiredService<IRequestHandler<TRequest, TResult>>()!
.Handle(request, cancellationToken)
);

var result = await services
.GetServices<IRequestMiddleware<TRequest, TResult>>()
.Reverse()
.Aggregate(
handler,
(next, middleware) => () => middleware.Process(
request,
next,
cancellationToken
)
)
.Invoke()
.ConfigureAwait(false);

var result = await initialExecute.Invoke().ConfigureAwait(false);
return result;
}


static void AssertRequestHandlers(int count, object request)
{
if (count == 0)
throw new InvalidOperationException("No request handler found for " + request.GetType().FullName);

if (count > 1)
throw new InvalidOperationException("More than 1 request handlers found for " + request.GetType().FullName);
}
}
2 changes: 1 addition & 1 deletion src/Shiny.Mediator/Impl/Mediator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ IEventPublisher eventPublisher
public Task<TResult> Request<TResult>(IRequest<TResult> request, CancellationToken cancellationToken = default)
=> requestSender.Request(request, cancellationToken);

public Task Send<TRequest>(TRequest request, CancellationToken cancellationToken = default) where TRequest : IRequest
public Task Send(IRequest request, CancellationToken cancellationToken = default)
=> requestSender.Send(request, cancellationToken);

public Task Publish<TEvent>(TEvent @event, CancellationToken cancellationToken = default) where TEvent : IEvent
Expand Down
6 changes: 3 additions & 3 deletions src/Shiny.Mediator/Infrastructure/IRequestSender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ Task<TResult> Request<TResult>(
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
Task Send<TRequest>(
TRequest request,
Task Send(
IRequest request,
CancellationToken cancellationToken = default
) where TRequest : IRequest;
);
}
8 changes: 4 additions & 4 deletions tests/Shiny.Mediator.Tests/MiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ public static class Executed
}
public class ConstrainedMiddleware : IRequestMiddleware<MiddlewareResultRequest, int>
{
public Task<int> Process(MiddlewareResultRequest request, Func<Task<int>> next, CancellationToken cancellationToken)
public Task<int> Process(MiddlewareResultRequest request, RequestHandlerDelegate<int> next, CancellationToken cancellationToken)
{
Executed.Constrained = true;
return next();
}
}

public class VariantRequestMiddleware<TRequest, TResponse> : IRequestMiddleware<TRequest, TResponse>
where TRequest : IRequest<TResponse>
public class VariantRequestMiddleware<TRequest, TResult> : IRequestMiddleware<TRequest, TResult>
where TRequest : IRequest<TResult>
{
public Task<TResponse> Process(TRequest request, Func<Task<TResponse>> next, CancellationToken cancellationToken)
public Task<TResult> Process(TRequest request, RequestHandlerDelegate<TResult> next, CancellationToken cancellationToken)
{
Executed.Variant = true;
return next();
Expand Down

0 comments on commit 9fc1e6b

Please sign in to comment.