Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve UseRequestTimeouts validation #2501

Merged
merged 2 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 77 additions & 15 deletions src/ReverseProxy/Model/ProxyPipelineInitializerMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
#if NET8_0_OR_GREATER
using System.Threading;
using Microsoft.AspNetCore.Http.Timeouts;
using Microsoft.Extensions.Options;
#endif
using Microsoft.Extensions.Logging;
#if NET8_0_OR_GREATER
using Yarp.ReverseProxy.Configuration;
#endif
using Yarp.ReverseProxy.Utilities;

namespace Yarp.ReverseProxy.Model;
Expand All @@ -23,12 +22,23 @@ internal sealed class ProxyPipelineInitializerMiddleware
{
private readonly ILogger _logger;
private readonly RequestDelegate _next;
#if NET8_0_OR_GREATER
private readonly IOptionsMonitor<RequestTimeoutOptions> _timeoutOptions;
#endif

public ProxyPipelineInitializerMiddleware(RequestDelegate next,
ILogger<ProxyPipelineInitializerMiddleware> logger)
ILogger<ProxyPipelineInitializerMiddleware> logger
#if NET8_0_OR_GREATER
, IOptionsMonitor<RequestTimeoutOptions> timeoutOptions
#endif
)
{
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
_next = next ?? throw new ArgumentNullException(nameof(next));

#if NET8_0_OR_GREATER
_timeoutOptions = timeoutOptions ?? throw new ArgumentNullException(nameof(timeoutOptions));
#endif
}

public Task Invoke(HttpContext context)
Expand All @@ -47,19 +57,11 @@ public Task Invoke(HttpContext context)
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
return Task.CompletedTask;
}

#if NET8_0_OR_GREATER
// There's no way to detect the presence of the timeout middleware before this, only the options.
if (endpoint.Metadata.GetMetadata<RequestTimeoutAttribute>() != null
&& context.Features.Get<IHttpRequestTimeoutFeature>() == null
// The feature is skipped if the request is already canceled. We'll handle canceled requests later for consistency.
&& !context.RequestAborted.IsCancellationRequested)
{
Log.TimeoutNotApplied(_logger, route.Config.RouteId);
// Out of an abundance of caution, refuse the request rather than allowing it to proceed without the configured timeout.
throw new InvalidOperationException($"The timeout was not applied for route '{route.Config.RouteId}', ensure `IApplicationBuilder.UseRequestTimeouts()`"
+ " is called between `IApplicationBuilder.UseRouting()` and `IApplicationBuilder.UseEndpoints()`.");
}
EnsureRequestTimeoutPolicyIsAppliedCorrectly(context, endpoint, route);
#endif

var destinationsState = cluster.DestinationsState;
context.Features.Set<IReverseProxyFeature>(new ReverseProxyFeature
{
Expand Down Expand Up @@ -91,6 +93,66 @@ private async Task AwaitWithActivity(HttpContext context, Activity activity)
}
}

#if NET8_0_OR_GREATER
private void EnsureRequestTimeoutPolicyIsAppliedCorrectly(HttpContext context, Endpoint endpoint, RouteModel route)
{
// There's no way to detect the presence of the timeout middleware before this, only the options.
if (endpoint.Metadata.GetMetadata<RequestTimeoutAttribute>() is { } requestTimeout &&
context.Features.Get<IHttpRequestTimeoutFeature>() is null &&
// The feature is skipped if the request is already canceled. We'll handle canceled requests later for consistency.
!context.RequestAborted.IsCancellationRequested &&
// The policy may set the timeout to null / infinite.
TimeoutPolicyRequestedATimeoutBeSet(requestTimeout))
{
// A timeout should have been set.
// Out of an abundance of caution, refuse the request rather than allowing it to proceed without the configured timeout.
Throw(route);
}

void Throw(RouteModel route)
{
// The feature is skipped if the debugger is attached.
if (!Debugger.IsAttached)
{
Log.TimeoutNotApplied(_logger, route.Config.RouteId);

throw new InvalidOperationException(
$"The timeout was not applied for route '{route.Config.RouteId}', " +
"ensure `IApplicationBuilder.UseRequestTimeouts()` is called between " +
"`IApplicationBuilder.UseRouting()` and `IApplicationBuilder.UseEndpoints()`.");
}
}
}

private bool TimeoutPolicyRequestedATimeoutBeSet(RequestTimeoutAttribute requestTimeout)
{
if (requestTimeout.Timeout is not TimeSpan timeout)
{
if (requestTimeout.PolicyName is not string policyName)
{
Debug.Fail("Either Timeout or PolicyName should have been set.");
return false;
}

if (!_timeoutOptions.CurrentValue.Policies.TryGetValue(policyName, out var policy))
{
// This should only happen if the policy existed at some point, but the options were updated to remove it.
return false;
}

if (policy.Timeout is null)
{
// The policy requested no timeout.
return false;
}

timeout = policy.Timeout.Value;
}

return timeout != Timeout.InfiniteTimeSpan;
}
#endif

private static class Log
{
private static readonly Action<ILogger, string, Exception?> _noClusterFound = LoggerMessage.Define<string>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Yarp.Tests.Common;
using Yarp.ReverseProxy.Configuration;
using Yarp.ReverseProxy.Forwarder;
using System.Diagnostics;

namespace Yarp.ReverseProxy.Model.Tests;

Expand Down Expand Up @@ -122,9 +123,12 @@ public async Task Invoke_NoHealthyEndpoints_CallsNext()

Assert.Equal(StatusCodes.Status418ImATeapot, httpContext.Response.StatusCode);
}

#if NET8_0_OR_GREATER
[Fact]
public async Task Invoke_MissingTimeoutMiddleware_RefuseRequest()
[Theory]
[InlineData(1)]
[InlineData(Timeout.Infinite)]
public async Task Invoke_MissingTimeoutMiddleware_RefuseRequest(int timeoutMs)
{
var httpClient = new HttpMessageInvoker(new Mock<HttpMessageHandler>().Object);
var cluster1 = new ClusterState(clusterId: "cluster1")
Expand All @@ -140,15 +144,23 @@ public async Task Invoke_MissingTimeoutMiddleware_RefuseRequest()
var aspNetCoreEndpoint = CreateAspNetCoreEndpoint(routeConfig,
builder =>
{
builder.Metadata.Add(new RequestTimeoutAttribute(1));
builder.Metadata.Add(new RequestTimeoutAttribute(timeoutMs));
});
aspNetCoreEndpoints.Add(aspNetCoreEndpoint);
var httpContext = new DefaultHttpContext();
httpContext.SetEndpoint(aspNetCoreEndpoint);

var sut = Create<ProxyPipelineInitializerMiddleware>();

await Assert.ThrowsAsync<InvalidOperationException>(() => sut.Invoke(httpContext));
if (timeoutMs == Timeout.Infinite || Debugger.IsAttached)
{
// If the timeout was infinite or the debugger is attached, we shouldn't refuse the request.
await sut.Invoke(httpContext);
}
else
{
await Assert.ThrowsAsync<InvalidOperationException>(() => sut.Invoke(httpContext));
}
}
#endif

Expand Down
Loading