From d72c4ce2eec13aacf1358c6a5299095605c6589f Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Wed, 15 Jan 2025 16:18:10 +0900 Subject: [PATCH] Introduce `Preprocessor` (#6057) Motivation: This PR introduces the notion of `Preprocessor`s and allows users to configure these to clients as options. The second part of this PR will introduce a way for users to solely create a client based on `Preprocessor`s. The eventual POC can be found here: https://github.com/jrhee17/armeria/pull/36/files Eventually this extension point will also make it easier/clearer for users to use xDS with existing Armeria APIs. The full capability/limitations/design of `Preprocessors` are better described in the following PR: https://github.com/line/armeria/pull/6051 Modifications: - Introduced `Preprocessor` and `PreClient` APIs - Added `ClientPreprocessors` and `ClientPreprocessorsBuilder` to allow users to easily add `Preprocessor`s to clients as options - Modified `DefaultWebClient`, `DefaultTHttpClient`, and `ArmeriaClientCall` to use `Preprocessor`s - In order to allow users a way to overwrite the chosen `EndpointGroup`, the `EndpointGroup` is now specified when creating a `ClientRequestContext` instead of at initialization time. - Modified `ClientUtil` methods to pass an additional `req` field which signifies the original request for type-safety. Result: - Users can specify `Preprocessor`s when creating a client. --- .../client/AbstractClientOptionsBuilder.java | 26 +++ .../client/AbstractWebClientBuilder.java | 12 ++ .../armeria/client/ClientBuilder.java | 10 ++ .../armeria/client/ClientOptions.java | 21 +++ .../armeria/client/ClientOptionsBuilder.java | 10 ++ .../armeria/client/ClientPreprocessors.java | 145 +++++++++++++++ .../client/ClientPreprocessorsBuilder.java | 73 ++++++++ .../armeria/client/ClientRequestContext.java | 3 +- .../client/ClientRequestContextBuilder.java | 21 +-- .../client/ClientRequestContextWrapper.java | 1 - .../armeria/client/DefaultWebClient.java | 17 +- .../HttpClientPipelineConfigurator.java | 3 +- .../armeria/client/HttpPreClient.java | 29 +++ .../armeria/client/HttpPreprocessor.java | 77 ++++++++ .../linecorp/armeria/client/PreClient.java | 47 +++++ .../client/PreClientRequestContext.java | 53 ++++++ .../linecorp/armeria/client/Preprocessor.java | 36 ++++ .../armeria/client/RedirectingClient.java | 5 +- .../armeria/client/RestClientBuilder.java | 11 ++ .../linecorp/armeria/client/RpcPreClient.java | 29 +++ .../armeria/client/RpcPreprocessor.java | 77 ++++++++ .../linecorp/armeria/client/UserClient.java | 64 ++++--- .../armeria/client/WebClientBuilder.java | 11 ++ .../armeria/client/retry/RetryingClient.java | 13 +- .../client/retry/RetryingRpcClient.java | 6 +- .../websocket/WebSocketClientBuilder.java | 21 +++ .../client/ClientRequestContextExtension.java | 19 +- .../armeria/internal/client/ClientUtil.java | 46 +++-- .../client/DefaultClientRequestContext.java | 169 +++++++++++++----- .../internal/client/TailPreClient.java | 77 ++++++++ .../common/CancellationScheduler.java | 5 + .../common/DefaultCancellationScheduler.java | 3 +- .../common/NonWrappingRequestContext.java | 13 +- .../common/NoopCancellationScheduler.java | 7 +- .../server/DefaultServiceRequestContext.java | 9 +- .../client/ClientContextCustomizerTest.java | 56 ++++-- .../client/ClientOptionsBuilderTest.java | 78 ++++++++ .../armeria/client/ClientOptionsTest.java | 6 +- .../armeria/client/HttpPreprocessorTest.java | 120 +++++++++++++ .../DefaultClientRequestContextTest.java | 25 ++- ...DerivedClientRequestContextClientTest.java | 15 +- .../eureka/EurekaEndpointGroupBuilder.java | 13 ++ .../eureka/EurekaUpdatingListenerBuilder.java | 13 ++ .../client/grpc/GrpcClientBuilder.java | 14 ++ .../internal/client/grpc/ArmeriaChannel.java | 27 ++- .../client/grpc/ArmeriaClientCall.java | 37 ++-- .../client/auth/oauth2/OAuth2Client.java | 2 +- .../retrofit2/ArmeriaRetrofitBuilder.java | 12 ++ .../client/thrift/ThriftClientBuilder.java | 14 ++ .../client/thrift/DefaultTHttpClient.java | 15 +- .../client/thrift/RpcPreprocessorTest.java | 103 +++++++++++ 51 files changed, 1517 insertions(+), 202 deletions(-) create mode 100644 core/src/main/java/com/linecorp/armeria/client/ClientPreprocessors.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/ClientPreprocessorsBuilder.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/HttpPreClient.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/HttpPreprocessor.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/PreClient.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/PreClientRequestContext.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/Preprocessor.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/RpcPreClient.java create mode 100644 core/src/main/java/com/linecorp/armeria/client/RpcPreprocessor.java create mode 100644 core/src/main/java/com/linecorp/armeria/internal/client/TailPreClient.java create mode 100644 core/src/test/java/com/linecorp/armeria/client/HttpPreprocessorTest.java create mode 100644 thrift/thrift0.13/src/test/java/com/linecorp/armeria/client/thrift/RpcPreprocessorTest.java diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java b/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java index 50cd5b966f1..6c42b6452cb 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java @@ -57,6 +57,7 @@ public class AbstractClientOptionsBuilder { private final Map, ClientOptionValue> options = new LinkedHashMap<>(); private final ClientDecorationBuilder decoration = ClientDecoration.builder(); + private final ClientPreprocessorsBuilder clientPreprocessorsBuilder = new ClientPreprocessorsBuilder(); private final HttpHeadersBuilder headers = HttpHeaders.builder(); @Nullable @@ -127,6 +128,8 @@ public AbstractClientOptionsBuilder option(ClientOptionValue optionValue) } else if (opt == ClientOptions.HEADERS) { final HttpHeaders h = (HttpHeaders) optionValue.value(); setHeaders(h); + } else if (opt == ClientOptions.PREPROCESSORS) { + clientPreprocessorsBuilder.add((ClientPreprocessors) optionValue.value()); } else { options.put(opt, optionValue); } @@ -520,6 +523,28 @@ public AbstractClientOptionsBuilder responseTimeoutMode(ResponseTimeoutMode resp requireNonNull(responseTimeoutMode, "responseTimeoutMode")); } + /** + * Adds the specified HTTP-level {@code preprocessor}. + * + * @param preprocessor the {@link HttpPreprocessor} that preprocesses an invocation + */ + @UnstableApi + public AbstractClientOptionsBuilder preprocessor(HttpPreprocessor preprocessor) { + clientPreprocessorsBuilder.add(preprocessor); + return this; + } + + /** + * Adds the specified RPC-level {@code rpcPreprocessor}. + * + * @param rpcPreprocessor the {@link RpcPreprocessor} that preprocesses an invocation + */ + @UnstableApi + public AbstractClientOptionsBuilder rpcPreprocessor(RpcPreprocessor rpcPreprocessor) { + clientPreprocessorsBuilder.addRpc(rpcPreprocessor); + return this; + } + /** * Builds {@link ClientOptions} with the given options and the * {@linkplain ClientOptions#of() default options}. @@ -538,6 +563,7 @@ protected final ClientOptions buildOptions(@Nullable ClientOptions baseOptions) ImmutableList.builder(); additionalValues.addAll(optVals); additionalValues.add(ClientOptions.DECORATION.newValue(decoration.build())); + additionalValues.add(ClientOptions.PREPROCESSORS.newValue(clientPreprocessorsBuilder.build())); additionalValues.add(ClientOptions.HEADERS.newValue(headers.build())); additionalValues.add(ClientOptions.CONTEXT_HOOK.newValue(contextHook)); if (contextCustomizer != null) { diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java index 04ca8523611..f75ff6ec13b 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java @@ -184,4 +184,16 @@ public AbstractWebClientBuilder rpcDecorator(Function contextHook) public ClientBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (ClientBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + @Override + public ClientBuilder preprocessor(HttpPreprocessor decorator) { + return (ClientBuilder) super.preprocessor(decorator); + } + + @Override + public ClientBuilder rpcPreprocessor(RpcPreprocessor decorator) { + return (ClientBuilder) super.rpcPreprocessor(decorator); + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java b/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java index 2975cff4eb4..c08bb7fc382 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java @@ -160,6 +160,19 @@ public final class ClientOptions public static final ClientOption RESPONSE_TIMEOUT_MODE = ClientOption.define("RESPONSE_TIMEOUT_MODE", Flags.responseTimeoutMode()); + @UnstableApi + public static final ClientOption PREPROCESSORS = + ClientOption.define("PREPROCESSORS", ClientPreprocessors.of(), Function.identity(), + (oldValue, newValue) -> { + final ClientPreprocessors newPreprocessors = newValue.value(); + final ClientPreprocessors oldPreprocessors = oldValue.value(); + return newValue.option().newValue( + ClientPreprocessors.builder() + .add(oldPreprocessors) + .add(newPreprocessors) + .build()); + }); + private static final List PROHIBITED_HEADER_NAMES = ImmutableList.of( HttpHeaderNames.HTTP2_SETTINGS, HttpHeaderNames.METHOD, @@ -410,6 +423,14 @@ public ResponseTimeoutMode responseTimeoutMode() { return get(RESPONSE_TIMEOUT_MODE); } + /** + * Returns the {@link Preprocessor}s that preprocesses the components of a client. + */ + @UnstableApi + public ClientPreprocessors clientPreprocessors() { + return get(PREPROCESSORS); + } + /** * Returns a new {@link ClientOptionsBuilder} created from this {@link ClientOptions}. */ diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientOptionsBuilder.java b/core/src/main/java/com/linecorp/armeria/client/ClientOptionsBuilder.java index 70fa0cc947a..5c43659fe03 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientOptionsBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientOptionsBuilder.java @@ -227,4 +227,14 @@ public ClientOptionsBuilder contextHook(Supplier contex public ClientOptionsBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (ClientOptionsBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + @Override + public ClientOptionsBuilder preprocessor(HttpPreprocessor decorator) { + return (ClientOptionsBuilder) super.preprocessor(decorator); + } + + @Override + public ClientOptionsBuilder rpcPreprocessor(RpcPreprocessor decorator) { + return (ClientOptionsBuilder) super.rpcPreprocessor(decorator); + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientPreprocessors.java b/core/src/main/java/com/linecorp/armeria/client/ClientPreprocessors.java new file mode 100644 index 00000000000..ea2f8faa44f --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/ClientPreprocessors.java @@ -0,0 +1,145 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import java.util.List; +import java.util.Objects; +import java.util.function.Function; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * A set of {@link Function}s that transforms a {@link HttpPreprocessor} or + * {@link RpcPreprocessor} into another. + */ +@UnstableApi +public final class ClientPreprocessors { + + private static final ClientPreprocessors NONE = + new ClientPreprocessors(ImmutableList.of(), ImmutableList.of()); + + /** + * Returns an empty {@link ClientDecoration} which does not decorate a {@link Client}. + */ + public static ClientPreprocessors of() { + return NONE; + } + + /** + * Creates a new instance from a single {@link HttpPreprocessor}. + * + * @param preprocessor the {@link HttpPreprocessor} that transforms an + * {@link HttpPreClient} to another + */ + public static ClientPreprocessors of(HttpPreprocessor preprocessor) { + return builder().add(preprocessor).build(); + } + + /** + * Creates a new instance from a single {@link RpcPreprocessor}. + * + * @param preprocessor the {@link RpcPreprocessor} that transforms an {@link RpcPreClient} + * to another + */ + public static ClientPreprocessors ofRpc(RpcPreprocessor preprocessor) { + return builder().addRpc(preprocessor).build(); + } + + /** + * Returns a newly created {@link ClientPreprocessorsBuilder}. + */ + public static ClientPreprocessorsBuilder builder() { + return new ClientPreprocessorsBuilder(); + } + + private final List preprocessors; + private final List rpcPreprocessors; + + ClientPreprocessors(List preprocessors, List rpcPreprocessors) { + this.preprocessors = ImmutableList.copyOf(preprocessors); + this.rpcPreprocessors = ImmutableList.copyOf(rpcPreprocessors); + } + + /** + * Returns the HTTP-level preprocessors. + */ + public List preprocessors() { + return preprocessors; + } + + /** + * Returns the RPC-level preprocessors. + */ + public List rpcPreprocessors() { + return rpcPreprocessors; + } + + /** + * Decorates the specified {@link HttpPreClient} using preprocessors. + * + * @param execution the {@link HttpPreClient} being decorated + */ + public HttpPreClient decorate(HttpPreClient execution) { + for (HttpPreprocessor preprocessor : preprocessors) { + final HttpPreClient execution0 = execution; + execution = (ctx, req) -> preprocessor.execute(execution0, ctx, req); + } + return execution; + } + + /** + * Decorates the specified {@link RpcPreClient} using preprocessors. + * + * @param execution the {@link RpcPreClient} being decorated + */ + public RpcPreClient rpcDecorate(RpcPreClient execution) { + for (RpcPreprocessor rpcPreprocessor : rpcPreprocessors) { + final RpcPreClient execution0 = execution; + execution = (ctx, req) -> rpcPreprocessor.execute(execution0, ctx, req); + } + return execution; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + if (object == null || getClass() != object.getClass()) { + return false; + } + final ClientPreprocessors that = (ClientPreprocessors) object; + return Objects.equals(preprocessors, that.preprocessors) && + Objects.equals(rpcPreprocessors, that.rpcPreprocessors); + } + + @Override + public int hashCode() { + return Objects.hash(preprocessors, rpcPreprocessors); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("preprocessors", preprocessors) + .add("rpcPreprocessors", rpcPreprocessors) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientPreprocessorsBuilder.java b/core/src/main/java/com/linecorp/armeria/client/ClientPreprocessorsBuilder.java new file mode 100644 index 00000000000..8586fa5d8b0 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/ClientPreprocessorsBuilder.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static java.util.Objects.requireNonNull; + +import java.util.ArrayList; +import java.util.List; + +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Creates a new {@link ClientPreprocessors} using the builder pattern. + */ +@UnstableApi +public final class ClientPreprocessorsBuilder { + + private final List preprocessors = new ArrayList<>(); + private final List rpcPreprocessors = new ArrayList<>(); + + ClientPreprocessorsBuilder() {} + + /** + * Adds the specified {@link ClientPreprocessors}. + */ + public ClientPreprocessorsBuilder add(ClientPreprocessors preprocessors) { + requireNonNull(preprocessors, "preprocessors"); + preprocessors.preprocessors().forEach(this::add); + preprocessors.rpcPreprocessors().forEach(this::addRpc); + return this; + } + + /** + * Adds the specified HTTP-level {@code preprocessor}. + * + * @param preprocessor the {@link HttpPreprocessor} that preprocesses an invocation + */ + public ClientPreprocessorsBuilder add(HttpPreprocessor preprocessor) { + preprocessors.add(requireNonNull(preprocessor, "preprocessor")); + return this; + } + + /** + * Adds the specified RPC-level {@code preprocessor}. + * + * @param rpcPreprocessor the {@link HttpPreprocessor} that preprocesses an invocation + */ + public ClientPreprocessorsBuilder addRpc(RpcPreprocessor rpcPreprocessor) { + rpcPreprocessors.add(requireNonNull(rpcPreprocessor, "rpcPreprocessor")); + return this; + } + + /** + * Returns a newly-created {@link ClientPreprocessors} based on the decorators added to this builder. + */ + public ClientPreprocessors build() { + return new ClientPreprocessors(preprocessors, rpcPreprocessors); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java index 017197f2bde..1572175d268 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java @@ -251,9 +251,8 @@ ClientRequestContext newDerivedContext(RequestId id, @Nullable HttpRequest req, * Returns the {@link EndpointGroup} used for the current {@link Request}. * * @return the {@link EndpointGroup} if a user specified an {@link EndpointGroup} when initiating - * a {@link Request}. {@code null} if a user specified an {@link Endpoint}. + * a {@link Request}. */ - @Nullable EndpointGroup endpointGroup(); /** diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java index 1543cc10e34..65ce8bef323 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextBuilder.java @@ -15,7 +15,6 @@ */ package com.linecorp.armeria.client; -import static com.linecorp.armeria.internal.common.CancellationScheduler.noopCancellationTask; import static java.util.Objects.requireNonNull; import java.net.InetSocketAddress; @@ -123,23 +122,17 @@ public ClientRequestContext build() { endpointGroup = Endpoint.parse(authority()); } - final CancellationScheduler responseCancellationScheduler; - if (timedOut()) { - responseCancellationScheduler = CancellationScheduler.finished(false); - } else { - responseCancellationScheduler = CancellationScheduler.ofClient(0); - } final DefaultClientRequestContext ctx = new DefaultClientRequestContext( - eventLoop(), meterRegistry(), sessionProtocol(), id(), method(), requestTarget(), options, - request(), rpcRequest(), requestOptions, responseCancellationScheduler, + eventLoop(), meterRegistry(), sessionProtocol(), id(), method(), requestTarget(), + endpointGroup, options, + request(), rpcRequest(), requestOptions, CancellationScheduler.ofClient(0), isRequestStartTimeSet() ? requestStartTimeNanos() : System.nanoTime(), isRequestStartTimeSet() ? requestStartTimeMicros() : SystemInfo.currentTimeMicros()); - - ctx.init(endpointGroup).handle((unused, cause) -> { + if (timedOut()) { + ctx.timeoutNow(); + } + ctx.init().handle((unused, cause) -> { ctx.finishInitialization(cause == null); - if (!timedOut()) { - ctx.responseCancellationScheduler().initAndStart(ctx.eventLoop(), noopCancellationTask); - } return null; }); ctx.logBuilder().session(fakeChannel(ctx.eventLoop()), sessionProtocol(), sslSession(), diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java index f03e0a15de3..ddd30792a4f 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java @@ -52,7 +52,6 @@ public ClientRequestContext newDerivedContext(RequestId id, @Nullable HttpReques return unwrap().newDerivedContext(id, req, rpcReq, endpoint); } - @Nullable @Override public EndpointGroup endpointGroup() { return unwrap().endpointGroup(); diff --git a/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java b/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java index e812281120e..ad847f1c528 100644 --- a/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java @@ -29,6 +29,9 @@ import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.client.ClientUtil; +import com.linecorp.armeria.internal.client.DefaultClientRequestContext; +import com.linecorp.armeria.internal.client.TailPreClient; import io.micrometer.core.instrument.MeterRegistry; @@ -45,10 +48,14 @@ final class DefaultWebClient extends UserClient imple private BlockingWebClient blockingWebClient; @Nullable private RestClient restClient; + private final HttpPreClient preClient; DefaultWebClient(ClientBuilderParams params, HttpClient delegate, MeterRegistry meterRegistry) { super(params, delegate, meterRegistry, HttpResponse::of, (ctx, cause) -> HttpResponse.ofFailure(cause)); + final HttpPreClient tailPreClient = + TailPreClient.of(unwrap(), futureConverter(), errorResponseFactory()); + preClient = options().clientPreprocessors().decorate(tailPreClient); } @Override @@ -113,12 +120,10 @@ public HttpResponse execute(HttpRequest req, RequestOptions requestOptions) { newReq = req.withHeaders(req.headers().toBuilder().path(newPath)); } - return execute(protocol, - endpointGroup, - newReq.method(), - reqTarget, - newReq, - requestOptions); + final DefaultClientRequestContext ctx = new DefaultClientRequestContext( + protocol, newReq, newReq.method(), null, reqTarget, endpointGroup, requestOptions, options(), + meterRegistry()); + return ClientUtil.executeWithFallback(preClient, ctx, newReq, errorResponseFactory()); } private static HttpResponse abortRequestAndReturnFailureResponse( diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java index 7f119e28da4..bb94e9ff6e8 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java @@ -43,6 +43,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.Flags; import com.linecorp.armeria.common.HttpObject; import com.linecorp.armeria.common.HttpRequest; @@ -572,7 +573,7 @@ public void onComplete() {} final DefaultClientRequestContext reqCtx = new DefaultClientRequestContext( ctx.channel().eventLoop(), Flags.meterRegistry(), H1C, RequestId.random(), com.linecorp.armeria.common.HttpMethod.OPTIONS, - REQ_TARGET_ASTERISK, ClientOptions.of(), + REQ_TARGET_ASTERISK, EndpointGroup.of(), ClientOptions.of(), HttpRequest.of(com.linecorp.armeria.common.HttpMethod.OPTIONS, "*"), null, REQUEST_OPTIONS_FOR_UPGRADE_REQUEST, CancellationScheduler.noop(), System.nanoTime(), SystemInfo.currentTimeMicros()); diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpPreClient.java b/core/src/main/java/com/linecorp/armeria/client/HttpPreClient.java new file mode 100644 index 00000000000..b59ca8c4a70 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/HttpPreClient.java @@ -0,0 +1,29 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Prepares a {@link HttpRequest} before sending it to a remote {@link Endpoint}. + */ +@UnstableApi +@FunctionalInterface +public interface HttpPreClient extends PreClient { +} diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpPreprocessor.java b/core/src/main/java/com/linecorp/armeria/client/HttpPreprocessor.java new file mode 100644 index 00000000000..67121666afa --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/HttpPreprocessor.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static java.util.Objects.requireNonNull; + +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.UnstableApi; + +import io.netty.channel.EventLoop; + +/** + * An HTTP-based preprocessor that intercepts an outgoing request and allows users to + * customize certain properties before entering the decorating chain. The following + * illustrates a sample use-case: + *
{@code
+ * HttpPreprocessor preprocessor = (delegate, ctx, req) -> {
+ *     ctx.setEndpointGroup(Endpoint.of("overriding-host"));
+ *     return delegate.execute(ctx, req);
+ * };
+ * WebClient client = WebClient.builder()
+ *                             .preprocessor(preprocessor)
+ *                             .build();
+ * }
+ */ +@UnstableApi +@FunctionalInterface +public interface HttpPreprocessor extends Preprocessor { + + /** + * A simple {@link HttpPreprocessor} which overwrites the {@link SessionProtocol}, + * {@link EndpointGroup}, and {@link EventLoop} for a request. + */ + static HttpPreprocessor of(SessionProtocol sessionProtocol, EndpointGroup endpointGroup, + EventLoop eventLoop) { + requireNonNull(sessionProtocol, "sessionProtocol"); + requireNonNull(endpointGroup, "endpointGroup"); + requireNonNull(eventLoop, "eventLoop"); + return (delegate, ctx, req) -> { + ctx.setSessionProtocol(sessionProtocol); + ctx.setEndpointGroup(endpointGroup); + ctx.setEventLoop(eventLoop); + return delegate.execute(ctx, req); + }; + } + + /** + * A simple {@link HttpPreprocessor} which overwrites the {@link SessionProtocol} and + * {@link EndpointGroup} for a request. + */ + static HttpPreprocessor of(SessionProtocol sessionProtocol, EndpointGroup endpointGroup) { + requireNonNull(sessionProtocol, "sessionProtocol"); + requireNonNull(endpointGroup, "endpointGroup"); + return (delegate, ctx, req) -> { + ctx.setSessionProtocol(sessionProtocol); + ctx.setEndpointGroup(endpointGroup); + return delegate.execute(ctx, req); + }; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/PreClient.java b/core/src/main/java/com/linecorp/armeria/client/PreClient.java new file mode 100644 index 00000000000..9ad38ec6fc7 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/PreClient.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.Request; +import com.linecorp.armeria.common.Response; +import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.RpcResponse; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Prepares a {@link Request} before sending it to a remote {@link Endpoint}. + * + *

Note that this interface is not a user's entry point for sending a {@link Request}. It is rather + * a generic request processor interface which intercepts a {@link Request}. + * A user should implement {@link Preprocessor} and add it to the client instead. + * + * @param the type of outgoing {@link Request}. Must be {@link HttpRequest} or {@link RpcRequest}. + * @param the type of incoming {@link Response}. Must be {@link HttpResponse} or {@link RpcResponse}. + */ +@UnstableApi +@FunctionalInterface +public interface PreClient { + + /** + * Prepares a {@link Request} before sending it to a remote {@link Endpoint}. + * + * @return the {@link Response} to the specified {@link Request} + */ + O execute(PreClientRequestContext ctx, I req) throws Exception; +} diff --git a/core/src/main/java/com/linecorp/armeria/client/PreClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/client/PreClientRequestContext.java new file mode 100644 index 00000000000..112b2ddd5b5 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/PreClientRequestContext.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.Request; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.UnstableApi; + +import io.netty.channel.EventLoop; + +/** + * A {@link ClientRequestContext} which allows certain properties to be mutable before + * initialization is finalized. + */ +@UnstableApi +public interface PreClientRequestContext extends ClientRequestContext { + + /** + * Sets the {@link EndpointGroup} used for the current {@link Request}. + */ + void setEndpointGroup(EndpointGroup endpointGroup); + + /** + * Sets the {@link SessionProtocol} of the current {@link Request}. + */ + void setSessionProtocol(SessionProtocol sessionProtocol); + + /** + * Sets the {@link EventLoop} which will handle this request. Because changing + * the assigned {@link EventLoop} can lead to unexpected behavior, this property + * can be set only once. Because the assigned {@link EventLoop} can influence the number of + * connections made to an {@link Endpoint}, it is recommended to understand {@link EventLoopScheduler} + * before manually setting this value. + * + * @see EventLoopScheduler + */ + void setEventLoop(EventLoop eventLoop); +} diff --git a/core/src/main/java/com/linecorp/armeria/client/Preprocessor.java b/core/src/main/java/com/linecorp/armeria/client/Preprocessor.java new file mode 100644 index 00000000000..3ba8d46725a --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/Preprocessor.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.Request; +import com.linecorp.armeria.common.Response; + +/** + * Decorates a {@link PreClient}. Use either {@link HttpPreClient} or {@link RpcPreClient} + * depending on whether the client is HTTP-based or RPC-based. + * + * @param the {@link Request} type of the {@link Client} being decorated + * @param the {@link Response} type of the {@link Client} being decorated + */ +@FunctionalInterface +public interface Preprocessor { + + /** + * Creates a new instance that decorates the specified {@link PreClient}. + */ + O execute(PreClient delegate, PreClientRequestContext ctx, I req) throws Exception; +} diff --git a/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java b/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java index d2481315513..073389bde43 100644 --- a/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java @@ -192,8 +192,11 @@ private void execute0(ClientRequestContext ctx, RedirectContext redirectCtx, return; } + final HttpRequest req = derivedCtx.request(); + assert req != null; final HttpResponse response = executeWithFallback(unwrap(), derivedCtx, - (context, cause) -> HttpResponse.ofFailure(cause)); + (context, cause) -> HttpResponse.ofFailure(cause), + req); derivedCtx.log().whenAvailable(RequestLogProperty.RESPONSE_HEADERS).thenAccept(log -> { if (log.isAvailable(RequestLogProperty.RESPONSE_CAUSE)) { final Throwable cause = log.responseCause(); diff --git a/core/src/main/java/com/linecorp/armeria/client/RestClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/RestClientBuilder.java index c577104c969..d13f0441f7a 100644 --- a/core/src/main/java/com/linecorp/armeria/client/RestClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/RestClientBuilder.java @@ -259,4 +259,15 @@ public RestClientBuilder contextHook(Supplier contextHo public RestClientBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (RestClientBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + @Override + public RestClientBuilder preprocessor(HttpPreprocessor decorator) { + return (RestClientBuilder) super.preprocessor(decorator); + } + + @Override + @Deprecated + public RestClientBuilder rpcPreprocessor(RpcPreprocessor rpcPreprocessor) { + return (RestClientBuilder) super.rpcPreprocessor(rpcPreprocessor); + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/RpcPreClient.java b/core/src/main/java/com/linecorp/armeria/client/RpcPreClient.java new file mode 100644 index 00000000000..85dbb07dfa6 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/RpcPreClient.java @@ -0,0 +1,29 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.RpcResponse; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Prepares a {@link RpcRequest} before sending it to a remote {@link Endpoint}. + */ +@UnstableApi +@FunctionalInterface +public interface RpcPreClient extends PreClient { +} diff --git a/core/src/main/java/com/linecorp/armeria/client/RpcPreprocessor.java b/core/src/main/java/com/linecorp/armeria/client/RpcPreprocessor.java new file mode 100644 index 00000000000..3f0d560d0e2 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/RpcPreprocessor.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static java.util.Objects.requireNonNull; + +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.RpcResponse; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.UnstableApi; + +import io.netty.channel.EventLoop; + +/** + * An RPC-based preprocessor that intercepts an outgoing request and allows users to + * customize certain properties before entering the decorating chain. The following + * illustrates a sample use-case: + *

{@code
+ * RpcPreprocessor preprocessor = (delegate, ctx, req) -> {
+ *     ctx.setEndpointGroup(Endpoint.of("overriding-host"));
+ *     return delegate.execute(ctx, req);
+ * };
+ * Iface iface = ThriftClients.builder(Endpoint.of("overridden-host"))
+ *                            .rpcPreprocessor(rpcPreprocessor)
+ *                            .build(Iface.class);
+ * }
+ */ +@UnstableApi +@FunctionalInterface +public interface RpcPreprocessor extends Preprocessor { + + /** + * A simple {@link RpcPreprocessor} which overwrites the {@link SessionProtocol}, + * {@link EndpointGroup}, and {@link EventLoop} for a request. + */ + static RpcPreprocessor of(SessionProtocol sessionProtocol, EndpointGroup endpointGroup, + EventLoop eventLoop) { + requireNonNull(sessionProtocol, "sessionProtocol"); + requireNonNull(endpointGroup, "endpointGroup"); + requireNonNull(eventLoop, "eventLoop"); + return (delegate, ctx, req) -> { + ctx.setSessionProtocol(sessionProtocol); + ctx.setEndpointGroup(endpointGroup); + ctx.setEventLoop(eventLoop); + return delegate.execute(ctx, req); + }; + } + + /** + * A simple {@link RpcPreprocessor} which overwrites the {@link SessionProtocol}, + * {@link EndpointGroup}, and {@link EventLoop} for a request. + */ + static RpcPreprocessor of(SessionProtocol sessionProtocol, EndpointGroup endpointGroup) { + requireNonNull(sessionProtocol, "sessionProtocol"); + requireNonNull(endpointGroup, "endpointGroup"); + return (delegate, ctx, req) -> { + ctx.setSessionProtocol(sessionProtocol); + ctx.setEndpointGroup(endpointGroup); + return delegate.execute(ctx, req); + }; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/UserClient.java b/core/src/main/java/com/linecorp/armeria/client/UserClient.java index 28b38fa24e3..311b2c1a0f9 100644 --- a/core/src/main/java/com/linecorp/armeria/client/UserClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/UserClient.java @@ -23,15 +23,11 @@ import java.util.function.BiFunction; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.Request; -import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.Response; import com.linecorp.armeria.common.RpcRequest; @@ -39,7 +35,6 @@ import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.util.AbstractUnwrappable; -import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.internal.client.DefaultClientRequestContext; import io.micrometer.core.instrument.MeterRegistry; @@ -58,9 +53,6 @@ public abstract class UserClient extends AbstractUnwrappable> implements ClientBuilderParams { - private static final Logger logger = LoggerFactory.getLogger(UserClient.class); - private static boolean warnedNullRequestId; - private final ClientBuilderParams params; private final MeterRegistry meterRegistry; private final Function, O> futureConverter; @@ -118,6 +110,29 @@ public final ClientOptions options() { return params.options(); } + /** + * The {@link Function} that converts a {@link CompletableFuture} of response + * into a response, e.g. {@link HttpResponse#of(CompletionStage)} + * and {@link RpcResponse#from(CompletionStage)}. + */ + protected Function, O> futureConverter() { + return futureConverter; + } + + /** + * The {@link BiFunction} that returns a new response failed with the given exception. + */ + protected BiFunction errorResponseFactory() { + return errorResponseFactory; + } + + /** + * The {@link MeterRegistry} used for requests produced by this client. + */ + protected MeterRegistry meterRegistry() { + return meterRegistry; + } + /** * Executes the specified {@link Request} via the delegate. * @@ -125,7 +140,10 @@ public final ClientOptions options() { * @param method the method of the {@link Request} * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} + * + * @deprecated prefer {@link ClientOptions#clientPreprocessors()} to execute requests */ + @Deprecated protected final O execute(SessionProtocol protocol, HttpMethod method, RequestTarget reqTarget, I req) { return execute(protocol, method, reqTarget, req, RequestOptions.of()); } @@ -138,7 +156,10 @@ protected final O execute(SessionProtocol protocol, HttpMethod method, RequestTa * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} * @param requestOptions the {@link RequestOptions} of the {@link Request} + * + * @deprecated prefer {@link ClientOptions#clientPreprocessors()} to execute requests */ + @Deprecated protected final O execute(SessionProtocol protocol, HttpMethod method, RequestTarget reqTarget, I req, RequestOptions requestOptions) { return execute(protocol, endpointGroup(), method, reqTarget, req, requestOptions); @@ -152,7 +173,10 @@ protected final O execute(SessionProtocol protocol, HttpMethod method, RequestTa * @param method the method of the {@link Request} * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} + * + * @deprecated prefer {@link ClientOptions#clientPreprocessors()} to execute requests */ + @Deprecated protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, HttpMethod method, RequestTarget reqTarget, I req) { return execute(protocol, endpointGroup, method, reqTarget, req, RequestOptions.of()); @@ -167,13 +191,15 @@ protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, * @param reqTarget the {@link RequestTarget} of the {@link Request} * @param req the {@link Request} * @param requestOptions the {@link RequestOptions} of the {@link Request} + * + * @deprecated prefer {@link ClientOptions#clientPreprocessors()} to execute requests */ + @Deprecated protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, HttpMethod method, RequestTarget reqTarget, I req, RequestOptions requestOptions) { final HttpRequest httpReq; final RpcRequest rpcReq; - final RequestId id = nextRequestId(); if (req instanceof HttpRequest) { httpReq = (HttpRequest) req; @@ -184,23 +210,9 @@ protected final O execute(SessionProtocol protocol, EndpointGroup endpointGroup, } final DefaultClientRequestContext ctx = new DefaultClientRequestContext( - meterRegistry, protocol, id, method, reqTarget, options(), httpReq, rpcReq, - requestOptions, System.nanoTime(), SystemInfo.currentTimeMicros()); - - return initContextAndExecuteWithFallback(unwrap(), ctx, endpointGroup, - futureConverter, errorResponseFactory); - } + protocol, httpReq, method, rpcReq, reqTarget, endpointGroup, + requestOptions, options(), meterRegistry); - private RequestId nextRequestId() { - final RequestId id = options().requestIdGenerator().get(); - if (id == null) { - if (!warnedNullRequestId) { - warnedNullRequestId = true; - logger.warn("requestIdGenerator.get() returned null; using RequestId.random()"); - } - return RequestId.random(); - } else { - return id; - } + return initContextAndExecuteWithFallback(unwrap(), ctx, futureConverter, errorResponseFactory, req); } } diff --git a/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java index f6d1d3f656f..1e9ec4ff4a8 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java @@ -255,4 +255,15 @@ public WebClientBuilder contextCustomizer( public WebClientBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (WebClientBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + @Override + public WebClientBuilder preprocessor(HttpPreprocessor decorator) { + return (WebClientBuilder) super.preprocessor(decorator); + } + + @Override + @Deprecated + public WebClientBuilder rpcPreprocessor(RpcPreprocessor rpcPreprocessor) { + return (WebClientBuilder) super.rpcPreprocessor(rpcPreprocessor); + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/retry/RetryingClient.java b/core/src/main/java/com/linecorp/armeria/client/retry/RetryingClient.java index ec0454934e0..d0e28488bcd 100644 --- a/core/src/main/java/com/linecorp/armeria/client/retry/RetryingClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/retry/RetryingClient.java @@ -32,7 +32,6 @@ import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.HttpClient; import com.linecorp.armeria.client.ResponseTimeoutException; -import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.AggregationOptions; import com.linecorp.armeria.common.HttpHeaderNames; @@ -310,20 +309,20 @@ private void doExecute0(ClientRequestContext ctx, HttpRequestDuplicator rootReqD return; } + final HttpRequest ctxReq = derivedCtx.request(); + assert ctxReq != null; final HttpResponse response; - final EndpointGroup endpointGroup = derivedCtx.endpointGroup(); final ClientRequestContextExtension ctxExtension = derivedCtx.as(ClientRequestContextExtension.class); - if (!initialAttempt && ctxExtension != null && - endpointGroup != null && derivedCtx.endpoint() == null) { + if (!initialAttempt && ctxExtension != null && derivedCtx.endpoint() == null) { // clear the pending throwable to retry endpoint selection ClientPendingThrowableUtil.removePendingThrowable(derivedCtx); // if the endpoint hasn't been selected, try to initialize the ctx with a new endpoint/event loop response = initContextAndExecuteWithFallback( - unwrap(), ctxExtension, endpointGroup, HttpResponse::of, - (context, cause) -> HttpResponse.ofFailure(cause)); + unwrap(), ctxExtension, HttpResponse::of, + (context, cause) -> HttpResponse.ofFailure(cause), ctxReq); } else { response = executeWithFallback(unwrap(), derivedCtx, - (context, cause) -> HttpResponse.ofFailure(cause)); + (context, cause) -> HttpResponse.ofFailure(cause), ctxReq); } final RetryConfig config = mappedRetryConfig(ctx); if (!ctx.exchangeType().isResponseStreaming() || config.requiresResponseTrailers()) { diff --git a/core/src/main/java/com/linecorp/armeria/client/retry/RetryingRpcClient.java b/core/src/main/java/com/linecorp/armeria/client/retry/RetryingRpcClient.java index 9968568a458..996e3550a9a 100644 --- a/core/src/main/java/com/linecorp/armeria/client/retry/RetryingRpcClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/retry/RetryingRpcClient.java @@ -179,11 +179,11 @@ private void doExecute0(ClientRequestContext ctx, RpcRequest req, // clear the pending throwable to retry endpoint selection ClientPendingThrowableUtil.removePendingThrowable(derivedCtx); // if the endpoint hasn't been selected, try to initialize the ctx with a new endpoint/event loop - res = initContextAndExecuteWithFallback(unwrap(), ctxExtension, endpointGroup, RpcResponse::from, - (context, cause) -> RpcResponse.ofFailure(cause)); + res = initContextAndExecuteWithFallback(unwrap(), ctxExtension, RpcResponse::from, + (context, cause) -> RpcResponse.ofFailure(cause), req); } else { res = executeWithFallback(unwrap(), derivedCtx, - (context, cause) -> RpcResponse.ofFailure(cause)); + (context, cause) -> RpcResponse.ofFailure(cause), req); } final RetryConfig retryConfig = mappedRetryConfig(ctx); diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java index 1109efe8477..840169c0439 100644 --- a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java @@ -45,8 +45,10 @@ import com.linecorp.armeria.client.DecoratingRpcClientFunction; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreprocessor; import com.linecorp.armeria.client.ResponseTimeoutMode; import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreprocessor; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.redirect.RedirectConfig; @@ -396,4 +398,23 @@ public WebSocketClientBuilder contextHook(Supplier cont public WebSocketClientBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (WebSocketClientBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + /** + * Raises an {@link UnsupportedOperationException} because {@link WebSocketClient} does + * not support {@link HttpPreprocessor}. + * + * @deprecated HTTP preprocessor cannot be added to the {@link WebSocketClient}. + */ + @Override + @Deprecated + public WebSocketClientBuilder preprocessor(HttpPreprocessor preprocessor) { + throw new UnsupportedOperationException( + "WebSocketClientBuilder does not support preprocessor."); + } + + @Override + @Deprecated + public WebSocketClientBuilder rpcPreprocessor(RpcPreprocessor rpcPreprocessor) { + return (WebSocketClientBuilder) super.rpcPreprocessor(rpcPreprocessor); + } } diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/ClientRequestContextExtension.java b/core/src/main/java/com/linecorp/armeria/internal/client/ClientRequestContextExtension.java index 26a9081031e..b70d542080e 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/ClientRequestContextExtension.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/ClientRequestContextExtension.java @@ -20,6 +20,7 @@ import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.client.endpoint.EndpointSelector; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.logging.RequestLog; @@ -41,7 +42,7 @@ public interface ClientRequestContextExtension extends ClientRequestContext, Req * Returns a {@link CompletableFuture} that will be completed * if this {@link ClientRequestContext} is initialized with an {@link EndpointGroup}. * - * @see #init(EndpointGroup) + * @see #init() */ CompletableFuture whenInitialized(); @@ -53,7 +54,7 @@ public interface ClientRequestContextExtension extends ClientRequestContext, Req * {@code false} if the initialization has failed and this context's {@link RequestLog} has been * completed with the cause of the failure. */ - CompletableFuture init(EndpointGroup endpointGroup); + CompletableFuture init(); /** * Completes the {@link #whenInitialized()} with the specified value. @@ -75,4 +76,18 @@ public interface ClientRequestContextExtension extends ClientRequestContext, Req HttpHeaders internalRequestHeaders(); long remainingTimeoutNanos(); + + /** + * The context customizer must be run before the following conditions. + *
  • + *
      + * {@link EndpointSelector#selectNow(ClientRequestContext)} so that the customizer + * can inject the attributes which may be required by the EndpointSelector.
    + *
      + * mapEndpoint() to give an opportunity to override an Endpoint when using + * an additional authority. + *
    + *
  • + */ + void runContextCustomizer(); } diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java b/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java index 1c4ff86b472..1317b3f0a71 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java @@ -15,7 +15,6 @@ */ package com.linecorp.armeria.internal.client; -import static com.google.common.base.MoreObjects.firstNonNull; import static java.util.Objects.requireNonNull; import java.net.URI; @@ -26,6 +25,8 @@ import com.linecorp.armeria.client.Client; import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.PreClient; +import com.linecorp.armeria.client.PreClientRequestContext; import com.linecorp.armeria.client.UnprocessedRequestException; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.endpoint.EndpointGroup; @@ -54,20 +55,19 @@ public final class ClientUtil { O initContextAndExecuteWithFallback( U delegate, ClientRequestContextExtension ctx, - EndpointGroup endpointGroup, Function, O> futureConverter, - BiFunction errorResponseFactory) { + BiFunction errorResponseFactory, + I req) { requireNonNull(delegate, "delegate"); requireNonNull(ctx, "ctx"); - requireNonNull(endpointGroup, "endpointGroup"); requireNonNull(futureConverter, "futureConverter"); requireNonNull(errorResponseFactory, "errorResponseFactory"); boolean initialized = false; boolean success = false; try { - final CompletableFuture initFuture = ctx.init(endpointGroup); + final CompletableFuture initFuture = ctx.init(); initialized = initFuture.isDone(); if (initialized) { // Initialization has been done immediately. @@ -77,7 +77,7 @@ O initContextAndExecuteWithFallback( throw UnprocessedRequestException.of(Exceptions.peel(e)); } - return initContextAndExecuteWithFallback(delegate, ctx, errorResponseFactory, success); + return initContextAndExecuteWithFallback(delegate, ctx, errorResponseFactory, success, req); } else { return futureConverter.apply(initFuture.handle((success0, cause) -> { try { @@ -85,7 +85,8 @@ O initContextAndExecuteWithFallback( throw UnprocessedRequestException.of(Exceptions.peel(cause)); } - return initContextAndExecuteWithFallback(delegate, ctx, errorResponseFactory, success0); + return initContextAndExecuteWithFallback( + delegate, ctx, errorResponseFactory, success0, req); } catch (Throwable t) { fail(ctx, t); return errorResponseFactory.apply(ctx, t); @@ -107,11 +108,11 @@ O initContextAndExecuteWithFallback( private static > O initContextAndExecuteWithFallback( U delegate, ClientRequestContextExtension ctx, - BiFunction errorResponseFactory, boolean succeeded) + BiFunction errorResponseFactory, boolean succeeded, I req) throws Exception { if (succeeded) { - return pushAndExecute(delegate, ctx); + return pushAndExecute(delegate, ctx, req); } else { final Throwable cause = ctx.log().partial().requestCause(); assert cause != null; @@ -123,7 +124,7 @@ O initContextAndExecuteWithFallback( // See `init()` and `failEarly()` in `DefaultClientRequestContext`. // Call the decorator chain anyway so that the request is seen by the decorators. - final O res = pushAndExecute(delegate, ctx); + final O res = pushAndExecute(delegate, ctx, req); // We will use the fallback response which is created from the exception // raised in ctx.init(), so the response returned can be aborted. @@ -138,24 +139,39 @@ O initContextAndExecuteWithFallback( public static > O executeWithFallback(U delegate, ClientRequestContext ctx, - BiFunction errorResponseFactory) { + BiFunction errorResponseFactory, I req) { requireNonNull(delegate, "delegate"); requireNonNull(ctx, "ctx"); requireNonNull(errorResponseFactory, "errorResponseFactory"); try { - return pushAndExecute(delegate, ctx); + return pushAndExecute(delegate, ctx, req); } catch (Throwable cause) { fail(ctx, cause); return errorResponseFactory.apply(ctx, cause); } } + public static > + O executeWithFallback(U execution, + PreClientRequestContext ctx, I req, + BiFunction errorResponseFactory) { + final ClientRequestContextExtension ctxExt = ctx.as(ClientRequestContextExtension.class); + if (ctxExt != null) { + ctxExt.runContextCustomizer(); + } + try { + return execution.execute(ctx, req); + } catch (Exception e) { + final UnprocessedRequestException upe = UnprocessedRequestException.of(e); + fail(ctx, upe); + return errorResponseFactory.apply(ctx, upe); + } + } + private static > - O pushAndExecute(U delegate, ClientRequestContext ctx) throws Exception { - @SuppressWarnings("unchecked") - final I req = (I) firstNonNull(ctx.request(), ctx.rpcRequest()); + O pushAndExecute(U delegate, ClientRequestContext ctx, I req) throws Exception { try (SafeCloseable ignored = ctx.push()) { return delegate.execute(ctx, req); } diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java index 4ad2e2bdeee..3bb0da0b334 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java @@ -36,9 +36,13 @@ import javax.net.ssl.SSLSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.linecorp.armeria.client.ClientOptions; import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.PreClientRequestContext; import com.linecorp.armeria.client.RequestOptions; import com.linecorp.armeria.client.ResponseTimeoutMode; import com.linecorp.armeria.client.UnprocessedRequestException; @@ -68,6 +72,7 @@ import com.linecorp.armeria.common.logging.RequestLogProperty; import com.linecorp.armeria.common.util.ReleasableHolder; import com.linecorp.armeria.common.util.SafeCloseable; +import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.common.util.TextFormatter; import com.linecorp.armeria.common.util.TimeoutMode; import com.linecorp.armeria.common.util.UnmodifiableFuture; @@ -97,7 +102,9 @@ */ public final class DefaultClientRequestContext extends NonWrappingRequestContext - implements ClientRequestContextExtension { + implements ClientRequestContextExtension, PreClientRequestContext { + + private static final Logger logger = LoggerFactory.getLogger(DefaultClientRequestContext.class); private static final AtomicReferenceFieldUpdater additionalRequestHeadersUpdater = AtomicReferenceFieldUpdater.newUpdater( @@ -124,12 +131,13 @@ private static SessionProtocol desiredSessionProtocol(SessionProtocol protocol, private static final short STR_CHANNEL_AVAILABILITY = 1; private static final short STR_PARENT_LOG_AVAILABILITY = 1 << 1; + private static boolean warnedNullRequestId; private boolean initialized; @Nullable private EventLoop eventLoop; - @Nullable private EndpointGroup endpointGroup; + private SessionProtocol sessionProtocol; @Nullable private Endpoint endpoint; @Nullable @@ -161,18 +169,44 @@ private static SessionProtocol desiredSessionProtocol(SessionProtocol protocol, private String strVal; private short strValAvailabilities; - // We use null checks which are faster than checking if a list is empty, - // because it is more common to have no customizers than to have any. - @Nullable - private volatile Consumer customizer; - @Nullable private volatile CompletableFuture whenInitialized; private final ResponseTimeoutMode responseTimeoutMode; + public DefaultClientRequestContext(SessionProtocol sessionProtocol, HttpRequest httpRequest, + @Nullable RpcRequest rpcRequest, RequestTarget requestTarget, + EndpointGroup endpointGroup, RequestOptions requestOptions, + ClientOptions clientOptions) { + this(null, clientOptions.factory().meterRegistry(), + sessionProtocol, nextRequestId(clientOptions), httpRequest.method(), requestTarget, + endpointGroup, clientOptions, httpRequest, rpcRequest, requestOptions, serviceRequestContext(), + null, System.nanoTime(), SystemInfo.currentTimeMicros()); + } + + public DefaultClientRequestContext(SessionProtocol sessionProtocol, @Nullable HttpRequest httpRequest, + HttpMethod method, @Nullable RpcRequest rpcRequest, + RequestTarget requestTarget, EndpointGroup endpointGroup, + RequestOptions requestOptions, ClientOptions clientOptions, + MeterRegistry meterRegistry) { + this(null, meterRegistry, + sessionProtocol, nextRequestId(clientOptions), method, requestTarget, + endpointGroup, clientOptions, httpRequest, rpcRequest, requestOptions, serviceRequestContext(), + null, System.nanoTime(), SystemInfo.currentTimeMicros()); + } + + public DefaultClientRequestContext(SessionProtocol sessionProtocol, @Nullable HttpRequest httpRequest, + HttpMethod method, @Nullable RpcRequest rpcRequest, + RequestTarget requestTarget, EndpointGroup endpointGroup, + RequestOptions requestOptions, ClientOptions clientOptions) { + this(null, clientOptions.factory().meterRegistry(), + sessionProtocol, nextRequestId(clientOptions), method, requestTarget, + endpointGroup, clientOptions, httpRequest, rpcRequest, requestOptions, serviceRequestContext(), + null, System.nanoTime(), SystemInfo.currentTimeMicros()); + } + /** - * Creates a new instance. Note that {@link #init(EndpointGroup)} method must be invoked to finish + * Creates a new instance. Note that {@link #init()} method must be invoked to finish * the construction of this context. * * @param eventLoop the {@link EventLoop} associated with this context @@ -187,18 +221,18 @@ private static SessionProtocol desiredSessionProtocol(SessionProtocol protocol, */ public DefaultClientRequestContext( @Nullable EventLoop eventLoop, MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, HttpMethod method, RequestTarget reqTarget, + RequestId id, HttpMethod method, RequestTarget reqTarget, EndpointGroup endpointGroup, ClientOptions options, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, RequestOptions requestOptions, CancellationScheduler responseCancellationScheduler, long requestStartTimeNanos, long requestStartTimeMicros) { - this(eventLoop, meterRegistry, sessionProtocol, - id, method, reqTarget, options, req, rpcReq, requestOptions, serviceRequestContext(), + this(eventLoop, meterRegistry, sessionProtocol, id, method, reqTarget, endpointGroup, + options, req, rpcReq, requestOptions, serviceRequestContext(), requireNonNull(responseCancellationScheduler, "responseCancellationScheduler"), requestStartTimeNanos, requestStartTimeMicros); } /** - * Creates a new instance. Note that {@link #init(EndpointGroup)} method must be invoked to finish + * Creates a new instance. Note that {@link #init()} method must be invoked to finish * the construction of this context. * * @param sessionProtocol the {@link SessionProtocol} of the invocation @@ -212,12 +246,12 @@ id, method, reqTarget, options, req, rpcReq, requestOptions, serviceRequestConte */ public DefaultClientRequestContext( MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, HttpMethod method, RequestTarget reqTarget, - ClientOptions options, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, + RequestId id, HttpMethod method, RequestTarget reqTarget, EndpointGroup endpointGroup, + ClientOptions options, HttpRequest req, @Nullable RpcRequest rpcReq, RequestOptions requestOptions, long requestStartTimeNanos, long requestStartTimeMicros) { this(null, meterRegistry, sessionProtocol, - id, method, reqTarget, options, req, rpcReq, requestOptions, + id, method, reqTarget, endpointGroup, options, req, rpcReq, requestOptions, serviceRequestContext(), /* responseCancellationScheduler */ null, requestStartTimeNanos, requestStartTimeMicros); } @@ -225,18 +259,20 @@ public DefaultClientRequestContext( private DefaultClientRequestContext( @Nullable EventLoop eventLoop, MeterRegistry meterRegistry, SessionProtocol sessionProtocol, RequestId id, HttpMethod method, - RequestTarget reqTarget, ClientOptions options, + RequestTarget reqTarget, EndpointGroup endpointGroup, ClientOptions options, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, RequestOptions requestOptions, @Nullable ServiceRequestContext root, @Nullable CancellationScheduler responseCancellationScheduler, long requestStartTimeNanos, long requestStartTimeMicros) { - super(meterRegistry, desiredSessionProtocol(sessionProtocol, options), id, method, reqTarget, + super(meterRegistry, id, method, reqTarget, guessExchangeType(requestOptions, req), requestAutoAbortDelayMillis(options, requestOptions), req, rpcReq, getAttributes(root), options.contextHook()); + this.sessionProtocol = desiredSessionProtocol(sessionProtocol, options); this.eventLoop = eventLoop; this.options = requireNonNull(options, "options"); this.root = root; + this.endpointGroup = endpointGroup; log = RequestLog.builder(this); log.startRequest(requestStartTimeNanos, requestStartTimeMicros); @@ -272,16 +308,6 @@ private DefaultClientRequestContext( defaultRequestHeaders = options.get(ClientOptions.HEADERS); additionalRequestHeaders = HttpHeaders.of(); - - final Consumer customizer = options.contextCustomizer(); - final Consumer threadLocalCustomizer = copyThreadLocalCustomizer(); - if (customizer == ClientOptions.CONTEXT_CUSTOMIZER.defaultValue()) { - this.customizer = threadLocalCustomizer; - } else if (threadLocalCustomizer == null) { - this.customizer = customizer; - } else { - this.customizer = customizer.andThen(threadLocalCustomizer); - } responseTimeoutMode = responseTimeoutMode(options, requestOptions); } @@ -328,19 +354,19 @@ private static ServiceRequestContext serviceRequestContext() { } @Override - public CompletableFuture init(EndpointGroup endpointGroup) { + public CompletableFuture init() { assert endpoint == null : endpoint; assert !initialized; initialized = true; - try { - // Note: context customizer must be run before: - // - EndpointSelector.select() so that the customizer can inject the attributes which may be - // required by the EndpointSelector. - // - mapEndpoint() to give an opportunity to override an Endpoint when using - // an additional authority. - runContextCustomizer(); + final Throwable cancellationCause = cancellationCause(); + if (cancellationCause != null) { + acquireEventLoop(endpointGroup); + failEarly(cancellationCause); + return initFuture(false, null); + } + try { endpointGroup = mapEndpoint(endpointGroup); if (endpointGroup instanceof Endpoint) { return initEndpoint((Endpoint) endpointGroup); @@ -364,7 +390,6 @@ private EndpointGroup mapEndpoint(EndpointGroup endpointGroup) { } private CompletableFuture initEndpoint(Endpoint endpoint) { - endpointGroup = null; updateEndpoint(endpoint); acquireEventLoop(endpoint); return initFuture(true, null); @@ -460,11 +485,24 @@ private void acquireEventLoop(EndpointGroup endpointGroup) { } } - private void runContextCustomizer() { - final Consumer customizer = this.customizer; + @Override + public void runContextCustomizer() { + final Consumer customizer; + final Consumer optionsCustomizer = options.contextCustomizer(); + final Consumer threadLocalCustomizer = copyThreadLocalCustomizer(); + if (optionsCustomizer == ClientOptions.CONTEXT_CUSTOMIZER.defaultValue()) { + customizer = threadLocalCustomizer; + } else if (threadLocalCustomizer == null) { + customizer = optionsCustomizer; + } else { + customizer = optionsCustomizer.andThen(threadLocalCustomizer); + } if (customizer != null) { - this.customizer = null; - customizer.accept(this); + try { + customizer.accept(this); + } catch (Throwable t) { + cancel(UnprocessedRequestException.of(t)); + } } } @@ -518,10 +556,10 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx, RequestId id, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, - @Nullable Endpoint endpoint, @Nullable EndpointGroup endpointGroup, + @Nullable Endpoint endpoint, EndpointGroup endpointGroup, SessionProtocol sessionProtocol, HttpMethod method, RequestTarget reqTarget) { - super(ctx.meterRegistry(), sessionProtocol, id, method, reqTarget, ctx.exchangeType(), + super(ctx.meterRegistry(), id, method, reqTarget, ctx.exchangeType(), ctx.requestAutoAbortDelayMillis(), req, rpcReq, getAttributes(ctx.root()), ctx.hook()); // The new requests cannot be null if it was previously non-null. @@ -533,6 +571,7 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx, // So we don't check the nullness of rpcRequest unlike request. // See https://github.com/line/armeria/pull/3251 and https://github.com/line/armeria/issues/3248. + this.sessionProtocol = requireNonNull(sessionProtocol, "sessionProtocol"); options = ctx.options(); root = ctx.root(); @@ -614,6 +653,18 @@ public ServiceRequestContext root() { return root; } + @Override + public SessionProtocol sessionProtocol() { + return sessionProtocol; + } + + @Override + public void setSessionProtocol(SessionProtocol sessionProtocol) { + checkState(!initialized, "Cannot update sessionProtocol after initialization"); + this.sessionProtocol = desiredSessionProtocol(requireNonNull(sessionProtocol, "sessionProtocol"), + options); + } + @Override public ClientRequestContext newDerivedContext(RequestId id, @Nullable HttpRequest req, @@ -630,7 +681,7 @@ public ClientRequestContext newDerivedContext(RequestId id, if (reqTarget.form() != RequestTargetForm.ABSOLUTE) { // Not an absolute URI. - return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, null, + return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, endpointGroup, sessionProtocol(), newHeaders.method(), reqTarget); } @@ -645,11 +696,11 @@ public ClientRequestContext newDerivedContext(RequestId id, final HttpRequest newReq = req.withHeaders(req.headers() .toBuilder() .path(reqTarget.pathAndQuery())); - return new DefaultClientRequestContext(this, id, newReq, rpcReq, newEndpoint, null, + return new DefaultClientRequestContext(this, id, newReq, rpcReq, newEndpoint, newEndpoint, protocol, newHeaders.method(), reqTarget); } } - return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, endpointGroup(), + return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, endpointGroup, sessionProtocol(), method(), requestTarget()); } @@ -713,6 +764,14 @@ public ContextAwareEventLoop eventLoop() { return contextAwareEventLoop = ContextAwareEventLoop.of(this, eventLoop); } + @Override + public void setEventLoop(EventLoop eventLoop) { + checkState(!initialized, "Cannot update eventLoop after initialization"); + checkState(this.eventLoop == null, "eventLoop can be updated only once"); + this.eventLoop = requireNonNull(eventLoop, "eventLoop"); + initializeResponseCancellationScheduler(); + } + @Override public ByteBufAllocator alloc() { final Channel channel = channel(); @@ -731,12 +790,17 @@ public ClientOptions options() { return options; } - @Nullable @Override public EndpointGroup endpointGroup() { return endpointGroup; } + @Override + public void setEndpointGroup(EndpointGroup endpointGroup) { + checkState(!initialized, "Cannot update endpointGroup after initialization"); + this.endpointGroup = requireNonNull(endpointGroup, "endpointGroup"); + } + @Nullable @Override public Endpoint endpoint() { @@ -1082,4 +1146,17 @@ private static ResponseTimeoutMode responseTimeoutMode(ClientOptions options, } return options.responseTimeoutMode(); } + + private static RequestId nextRequestId(ClientOptions options) { + final RequestId id = options.requestIdGenerator().get(); + if (id == null) { + if (!warnedNullRequestId) { + warnedNullRequestId = true; + logger.warn("requestIdGenerator.get() returned null; using RequestId.random()"); + } + return RequestId.random(); + } else { + return id; + } + } } diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/TailPreClient.java b/core/src/main/java/com/linecorp/armeria/internal/client/TailPreClient.java new file mode 100644 index 00000000000..d4572f5b322 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/client/TailPreClient.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.client; + +import java.util.concurrent.CompletableFuture; +import java.util.function.BiFunction; +import java.util.function.Function; + +import com.linecorp.armeria.client.Client; +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreClient; +import com.linecorp.armeria.client.PreClient; +import com.linecorp.armeria.client.PreClientRequestContext; +import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreClient; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.Request; +import com.linecorp.armeria.common.Response; +import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.RpcResponse; + +public final class TailPreClient implements PreClient { + + private final Client delegate; + private final Function, O> futureConverter; + private final BiFunction errorResponseFactory; + + private TailPreClient(Client delegate, + Function, O> futureConverter, + BiFunction errorResponseFactory) { + this.delegate = delegate; + this.futureConverter = futureConverter; + this.errorResponseFactory = errorResponseFactory; + } + + public static HttpPreClient of( + HttpClient httpClient, + Function, HttpResponse> futureConverter, + BiFunction errorResponseFactory) { + final TailPreClient tail = + new TailPreClient<>(httpClient, futureConverter, errorResponseFactory); + return tail::execute; + } + + public static RpcPreClient ofRpc( + RpcClient rpcClient, + Function, RpcResponse> futureConverter, + BiFunction errorResponseFactory) { + final TailPreClient tail = + new TailPreClient<>(rpcClient, futureConverter, errorResponseFactory); + return tail::execute; + } + + @Override + public O execute(PreClientRequestContext ctx, I req) { + final ClientRequestContextExtension ctxExt = ctx.as(ClientRequestContextExtension.class); + assert ctxExt != null; + return ClientUtil.initContextAndExecuteWithFallback(delegate, ctxExt, + futureConverter, errorResponseFactory, req); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java index 7f5ff402ee6..8486e5fc879 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/CancellationScheduler.java @@ -20,6 +20,8 @@ import java.util.concurrent.CompletableFuture; +import com.google.common.annotations.VisibleForTesting; + import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.TimeoutMode; @@ -128,6 +130,9 @@ default void finishNow() { */ void updateTask(CancellationTask cancellationTask); + @VisibleForTesting + State state(); + enum State { INIT, SCHEDULED, diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java index f0e7bddd380..ba8055fde64 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultCancellationScheduler.java @@ -389,8 +389,9 @@ private Throwable getFinalCause(@Nullable Throwable cause) { return cause; } + @Override @VisibleForTesting - State state() { + public State state() { return state; } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java index 08d38b39579..61ec14b5e1b 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/NonWrappingRequestContext.java @@ -39,7 +39,6 @@ import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.RequestTargetForm; import com.linecorp.armeria.common.RpcRequest; -import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -60,7 +59,6 @@ public abstract class NonWrappingRequestContext implements RequestContextExtensi private final MeterRegistry meterRegistry; private final ConcurrentAttributes attrs; - private SessionProtocol sessionProtocol; private final RequestId id; private final HttpMethod method; private RequestTarget reqTarget; @@ -82,9 +80,8 @@ public abstract class NonWrappingRequestContext implements RequestContextExtensi * Creates a new instance. */ protected NonWrappingRequestContext( - MeterRegistry meterRegistry, SessionProtocol sessionProtocol, - RequestId id, HttpMethod method, RequestTarget reqTarget, ExchangeType exchangeType, - long requestAutoAbortDelayMillis, + MeterRegistry meterRegistry, RequestId id, HttpMethod method, RequestTarget reqTarget, + ExchangeType exchangeType, long requestAutoAbortDelayMillis, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq, @Nullable AttributesGetters rootAttributeMap, Supplier contextHook) { assert req != null || rpcReq != null; @@ -96,7 +93,6 @@ protected NonWrappingRequestContext( attrs = ConcurrentAttributes.fromParent(rootAttributeMap); } - this.sessionProtocol = requireNonNull(sessionProtocol, "sessionProtocol"); this.id = requireNonNull(id, "id"); this.method = requireNonNull(method, "method"); this.reqTarget = requireNonNull(reqTarget, "reqTarget"); @@ -153,11 +149,6 @@ public final void updateRpcRequest(RpcRequest rpcReq) { @Nullable protected abstract RequestTarget validateHeaders(RequestHeaders headers); - @Override - public final SessionProtocol sessionProtocol() { - return sessionProtocol; - } - /** * Returns the {@link Channel} that is handling this request, or {@code null} if the connection is not * established yet. diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java b/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java index 4bd6e94ffc9..6eda742ed7f 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/NoopCancellationScheduler.java @@ -30,8 +30,6 @@ final class NoopCancellationScheduler implements CancellationScheduler { private static final CompletableFuture THROWABLE_FUTURE = UnmodifiableFuture.wrap(new CompletableFuture<>()); - private static final CompletableFuture VOID_FUTURE = - UnmodifiableFuture.wrap(new CompletableFuture<>()); private NoopCancellationScheduler() { } @@ -113,4 +111,9 @@ public CompletableFuture whenCancelled() { @Override public void updateTask(CancellationTask cancellationTask) { } + + @Override + public State state() { + return State.INIT; + } } diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java index 7da08ba953e..ef5999036d7 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/DefaultServiceRequestContext.java @@ -92,6 +92,7 @@ public final class DefaultServiceRequestContext additionalResponseTrailersUpdater = AtomicReferenceFieldUpdater.newUpdater( DefaultServiceRequestContext.class, HttpHeaders.class, "additionalResponseTrailers"); + private final SessionProtocol sessionProtocol; private final Channel ch; private final EventLoop eventLoop; private final ServiceConfig cfg; @@ -170,11 +171,12 @@ public DefaultServiceRequestContext( HttpHeaders additionalResponseHeaders, HttpHeaders additionalResponseTrailers, Supplier contextHook) { - super(meterRegistry, sessionProtocol, id, + super(meterRegistry, id, requireNonNull(routingContext, "routingContext").method(), routingContext.requestTarget(), exchangeType, cfg.requestAutoAbortDelayMillis(), requireNonNull(req, "req"), null, null, contextHook); + this.sessionProtocol = requireNonNull(sessionProtocol, "sessionProtocol"); this.ch = requireNonNull(ch, "ch"); this.eventLoop = requireNonNull(eventLoop, "eventLoop"); this.cfg = requireNonNull(cfg, "cfg"); @@ -231,6 +233,11 @@ public Iterator, Object>> attrs() { return ownAttrs(); } + @Override + public SessionProtocol sessionProtocol() { + return sessionProtocol; + } + @Nonnull @Override public InetSocketAddress remoteAddress() { diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientContextCustomizerTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientContextCustomizerTest.java index e37301676a7..98979ce9a4f 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ClientContextCustomizerTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ClientContextCustomizerTest.java @@ -20,10 +20,15 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; @@ -48,27 +53,44 @@ protected void configure(ServerBuilder sb) { } }; - @Test - void contextCustomizer_ClientBuilder() { + public static Stream contextCustomizer_ClientBuilder_args() { + final HttpPreprocessor asyncPreprocessor = + (delegate, ctx, req) -> HttpResponse.of( + CompletableFuture.supplyAsync(() -> { + try { + return delegate.execute(ctx, req); + } catch (Exception e) { + throw new RuntimeException(e); + } + })); + return Stream.of( + Arguments.of(WebClient.builder(server.httpUri())), + Arguments.of(WebClient.builder(server.httpUri()) + .preprocessor(asyncPreprocessor)) + ); + } + + @ParameterizedTest + @MethodSource("contextCustomizer_ClientBuilder_args") + void contextCustomizer_ClientBuilder(WebClientBuilder builder) { final String traceId = "12345"; final AtomicReference threadRef = new AtomicReference<>(); final BlockingWebClient client = - WebClient.builder(server.httpUri()) - .contextCustomizer(ctx -> { - threadRef.set(Thread.currentThread()); - ctx.setAttr(TRACE_ID, traceId); - }) - .decorator((delegate, ctx, req) -> { - final HttpRequest newReq = req.mapHeaders(headers -> { - return headers.toBuilder() - .add("X-Trace-ID", ctx.attr(TRACE_ID)) - .build(); - }); - ctx.updateRequest(newReq); - return delegate.execute(ctx, newReq); - }).build() - .blocking(); + builder.contextCustomizer(ctx -> { + threadRef.set(Thread.currentThread()); + ctx.setAttr(TRACE_ID, traceId); + }) + .decorator((delegate, ctx, req) -> { + final HttpRequest newReq = req.mapHeaders(headers -> { + return headers.toBuilder() + .add("X-Trace-ID", ctx.attr(TRACE_ID)) + .build(); + }); + ctx.updateRequest(newReq); + return delegate.execute(ctx, newReq); + }).build() + .blocking(); assertThat(client.get("/foo").contentUtf8()).isEqualTo("12345:null"); assertThat(threadRef).hasValue(Thread.currentThread()); diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientOptionsBuilderTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientOptionsBuilderTest.java index 979e26aa164..bbc72409bcd 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ClientOptionsBuilderTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ClientOptionsBuilderTest.java @@ -19,6 +19,8 @@ import static org.assertj.core.api.Assertions.assertThat; import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.function.Function; import java.util.function.Supplier; @@ -29,9 +31,13 @@ import com.linecorp.armeria.client.logging.LoggingRpcClient; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; +import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.RpcResponse; +import com.linecorp.armeria.internal.client.DefaultClientRequestContext; class ClientOptionsBuilderTest { @Test @@ -195,6 +201,78 @@ void testDecoratorDowncast() { assertThat(outer.as(LoggingClient.class)).isNull(); } + @Test + void testPreprocessors() throws Exception { + final ClientOptionsBuilder b = ClientOptions.builder(); + final List processorsList = new ArrayList<>(); + final HttpPreprocessor http1 = new RunnableHttpPreprocessor(() -> processorsList.add("http1")); + final HttpPreprocessor http2 = new RunnableHttpPreprocessor(() -> processorsList.add("http2")); + final HttpPreprocessor http3 = new RunnableHttpPreprocessor(() -> processorsList.add("http3")); + + b.option(ClientOptions.PREPROCESSORS.newValue(ClientPreprocessors.builder() + .add(http1).add(http2).build())); + assertThat(b.build().clientPreprocessors().preprocessors()).containsExactly(http1, http2); + b.option(ClientOptions.PREPROCESSORS.newValue(ClientPreprocessors.builder().add(http3).build())); + assertThat(b.build().clientPreprocessors().preprocessors()).containsExactly(http1, http2, http3); + + final HttpRequest req = HttpRequest.of(HttpMethod.GET, "/"); + final DefaultClientRequestContext ctx = (DefaultClientRequestContext) ClientRequestContext.of(req); + b.build().clientPreprocessors().decorate((ctx0, req0) -> HttpResponse.of(200)) + .execute(ctx, req); + assertThat(processorsList).containsExactly("http3", "http2", "http1"); + + // Add an RPC decorator. + processorsList.clear(); + final RpcPreprocessor rpc1 = new RunnableRpcPreprocessor(() -> processorsList.add("rpc1")); + final RpcPreprocessor rpc2 = new RunnableRpcPreprocessor(() -> processorsList.add("rpc2")); + final RpcPreprocessor rpc3 = new RunnableRpcPreprocessor(() -> processorsList.add("rpc3")); + + b.option(ClientOptions.PREPROCESSORS.newValue( + ClientPreprocessors.builder().addRpc(rpc1).addRpc(rpc2).build())); + assertThat(b.build().clientPreprocessors().rpcPreprocessors()).containsExactly(rpc1, rpc2); + b.rpcPreprocessor(rpc3); + assertThat(b.build().clientPreprocessors().rpcPreprocessors()).containsSequence(rpc1, rpc2, rpc3); + + final RpcRequest rpcRequest = RpcRequest.of(Object.class, "method"); + final DefaultClientRequestContext rpcCtx = + (DefaultClientRequestContext) ClientRequestContext.of(rpcRequest, "http://127.0.0.1"); + b.build().clientPreprocessors().rpcDecorate((ctx0, req0) -> RpcResponse.of(200)) + .execute(rpcCtx, rpcRequest); + assertThat(processorsList).containsExactly("rpc3", "rpc2", "rpc1"); + } + + private static class RunnableHttpPreprocessor implements HttpPreprocessor { + + private final Runnable runnable; + + RunnableHttpPreprocessor(Runnable runnable) { + this.runnable = runnable; + } + + @Override + public HttpResponse execute(PreClient delegate, + PreClientRequestContext ctx, HttpRequest req) throws Exception { + runnable.run(); + return delegate.execute(ctx, req); + } + } + + private static class RunnableRpcPreprocessor implements RpcPreprocessor { + + private final Runnable runnable; + + RunnableRpcPreprocessor(Runnable runnable) { + this.runnable = runnable; + } + + @Override + public RpcResponse execute(PreClient delegate, + PreClientRequestContext ctx, RpcRequest req) throws Exception { + runnable.run(); + return delegate.execute(ctx, req); + } + } + private static final class FooClient implements HttpClient { FooClient() { } diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientOptionsTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientOptionsTest.java index 2ccc417f453..0742fce6853 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ClientOptionsTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ClientOptionsTest.java @@ -19,6 +19,7 @@ import static com.linecorp.armeria.client.ClientOptions.ENDPOINT_REMAPPER; import static com.linecorp.armeria.client.ClientOptions.HEADERS; import static com.linecorp.armeria.client.ClientOptions.MAX_RESPONSE_LENGTH; +import static com.linecorp.armeria.client.ClientOptions.PREPROCESSORS; import static com.linecorp.armeria.client.ClientOptions.REQUEST_ID_GENERATOR; import static com.linecorp.armeria.client.ClientOptions.RESPONSE_TIMEOUT_MILLIS; import static com.linecorp.armeria.client.ClientOptions.WRITE_TIMEOUT_MILLIS; @@ -42,6 +43,7 @@ import com.linecorp.armeria.client.logging.LoggingClient; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; +import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.RequestId; class ClientOptionsTest { @@ -124,7 +126,9 @@ public Stream provideArguments(ExtensionContext context) th arguments(HEADERS, HttpHeaders.of(HttpHeaderNames.USER_AGENT, "armeria")), arguments(DECORATION, ClientDecoration.of(LoggingClient.newDecorator())), arguments(REQUEST_ID_GENERATOR, requestIdGenerator), - arguments(ENDPOINT_REMAPPER, Function.identity())); + arguments(ENDPOINT_REMAPPER, Function.identity()), + arguments(PREPROCESSORS, ClientPreprocessors.of( + (delegate, ctx, req) -> HttpResponse.of(200)))); } } } diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpPreprocessorTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpPreprocessorTest.java new file mode 100644 index 00000000000..a05ca061607 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/HttpPreprocessorTest.java @@ -0,0 +1,120 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.internal.client.ClientRequestContextExtension; +import com.linecorp.armeria.internal.common.CancellationScheduler.State; +import com.linecorp.armeria.testing.junit5.common.EventLoopExtension; + +class HttpPreprocessorTest { + + @RegisterExtension + static final EventLoopExtension eventLoop = new EventLoopExtension(); + + @Test + void overwriteByCustomPreprocessor() { + final HttpPreprocessor preprocessor = + HttpPreprocessor.of(SessionProtocol.HTTP, Endpoint.of("127.0.0.1"), + eventLoop.get()); + final WebClient client = WebClient.builder() + .preprocessor(preprocessor) + .decorator((delegate, ctx, req) -> HttpResponse.of(200)) + .build(); + final ClientRequestContext ctx; + try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + final AggregatedHttpResponse res = client.get("https://127.0.0.2").aggregate().join(); + assertThat(res.status().code()).isEqualTo(200); + ctx = captor.get(); + } + assertThat(ctx.sessionProtocol()).isEqualTo(SessionProtocol.HTTP); + assertThat(ctx.authority()).isEqualTo("127.0.0.1"); + assertThat(ctx.eventLoop().withoutContext()).isSameAs(eventLoop.get()); + } + + @Test + void preprocessorOrder() { + final List list = new ArrayList<>(); + final HttpPreprocessor p1 = RunnablePreprocessor.of(() -> list.add("1")); + final HttpPreprocessor p2 = RunnablePreprocessor.of(() -> list.add("2")); + final HttpPreprocessor p3 = RunnablePreprocessor.of(() -> list.add("3")); + + final WebClient client = WebClient.builder() + .preprocessor(p1) + .preprocessor(p2) + .preprocessor(p3) + .decorator((delegate, ctx, req) -> HttpResponse.of(200)) + .build(); + final AggregatedHttpResponse res = client.get("http://127.0.0.1").aggregate().join(); + assertThat(res.status().code()).isEqualTo(200); + assertThat(list).containsExactly("3", "2", "1"); + } + + @Test + void cancellationSchedulerIsInitializedCorrectly() { + final HttpPreprocessor preprocessor = (delegate, ctx, req) -> { + ctx.setEventLoop(eventLoop.get()); + return delegate.execute(ctx, req); + }; + final BlockingWebClient client = + WebClient.builder("http://1.2.3.4") + .preprocessor(preprocessor) + .responseTimeoutMode(ResponseTimeoutMode.FROM_START) + .responseTimeoutMillis(10_000) + .decorator((delegate, ctx, req) -> { + assertThat(ctx.as(ClientRequestContextExtension.class) + .responseCancellationScheduler() + .state()) + .isEqualTo(State.SCHEDULED); + return HttpResponse.of(200); + }) + .build() + .blocking(); + assertThat(client.get("/").status().code()).isEqualTo(200); + } + + private static final class RunnablePreprocessor implements HttpPreprocessor { + + private static HttpPreprocessor of(Runnable runnable) { + return new RunnablePreprocessor(runnable); + } + + private final Runnable runnable; + + private RunnablePreprocessor(Runnable runnable) { + this.runnable = runnable; + } + + @Override + public HttpResponse execute(PreClient delegate, + PreClientRequestContext ctx, HttpRequest req) throws Exception { + runnable.run(); + return delegate.execute(ctx, req); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java index a89871937ba..cbc6736cf9b 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/client/DefaultClientRequestContextTest.java @@ -37,6 +37,7 @@ import com.linecorp.armeria.client.Clients; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.RequestOptions; +import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpMethod; @@ -220,12 +221,13 @@ void testAuthorityOverridden() { final HttpRequest request1 = HttpRequest.of(RequestHeaders.of( HttpMethod.POST, "/foo", HttpHeaderNames.SCHEME, "http")); - final DefaultClientRequestContext ctx = newContext(ClientOptions.of(), request1); + final DefaultClientRequestContext ctx = newContext(ClientOptions.of(), request1, + Endpoint.of("endpoint.com", 8080)); assertThat(ctx.authority()).isNull(); assertThat(ctx.uri().toString()).isEqualTo("http:/foo"); assertThat(ctx.uri()).hasScheme("http").hasAuthority(null).hasPath("/foo"); - ctx.init(Endpoint.of("endpoint.com", 8080)); + ctx.init(); assertThat(ctx.authority()).isEqualTo("endpoint.com:8080"); assertThat(ctx.uri().toString()).isEqualTo("http://endpoint.com:8080/foo"); @@ -250,8 +252,9 @@ void testDefaultAuthorityOverridesInternal() { final ClientOptions clientOptions = ClientOptions.builder() .addHeader(HttpHeaderNames.AUTHORITY, "default.com") .build(); - final DefaultClientRequestContext ctx = newContext(clientOptions, request1); - ctx.init(Endpoint.of("example.com", 8080)); + final DefaultClientRequestContext ctx = newContext(clientOptions, request1, + Endpoint.of("example.com", 8080)); + ctx.init(); assertThat(ctx.authority()).isEqualTo("default.com"); assertThat(ctx.uri().toString()).isEqualTo("http://default.com/foo"); } @@ -277,7 +280,8 @@ void uriWithOnlySchemePath() { final HttpRequest request = HttpRequest.of(RequestHeaders.of( HttpMethod.POST, "/", HttpHeaderNames.SCHEME, "http")); - final DefaultClientRequestContext ctx = newContext(ClientOptions.of(), request); + final DefaultClientRequestContext ctx = newContext(ClientOptions.of(), request, + EndpointGroup.of()); ctx.updateRequest(request); assertThat(ctx.uri().toString()).isEqualTo("http:/"); } @@ -287,19 +291,22 @@ private static DefaultClientRequestContext newContext() { HttpMethod.POST, "/foo", HttpHeaderNames.SCHEME, "http", HttpHeaderNames.AUTHORITY, "example.com:8080")); - final DefaultClientRequestContext ctx = newContext(ClientOptions.of(), request); - ctx.init(Endpoint.of("example.com", 8080)); + final DefaultClientRequestContext ctx = newContext(ClientOptions.of(), request, + Endpoint.of("example.com", 8080)); + ctx.runContextCustomizer(); + ctx.init(); return ctx; } private static DefaultClientRequestContext newContext(ClientOptions clientOptions, - HttpRequest httpRequest) { + HttpRequest httpRequest, + EndpointGroup endpointGroup) { final RequestTarget reqTarget = RequestTarget.forClient(httpRequest.path()); assertThat(reqTarget).isNotNull(); return new DefaultClientRequestContext( mock(EventLoop.class), NoopMeterRegistry.get(), SessionProtocol.H2C, - RequestId.random(), HttpMethod.POST, reqTarget, clientOptions, httpRequest, + RequestId.random(), HttpMethod.POST, reqTarget, endpointGroup, clientOptions, httpRequest, null, RequestOptions.of(), CancellationScheduler.ofClient(0), System.nanoTime(), SystemInfo.currentTimeMicros()); } diff --git a/core/src/test/java/com/linecorp/armeria/internal/client/DerivedClientRequestContextClientTest.java b/core/src/test/java/com/linecorp/armeria/internal/client/DerivedClientRequestContextClientTest.java index 0b6ab9a2ccd..f6d71ab8bb5 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/client/DerivedClientRequestContextClientTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/client/DerivedClientRequestContextClientTest.java @@ -55,8 +55,9 @@ void shouldAcquireNewEventLoopForNewEndpoint() { final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/"); final DefaultClientRequestContext parent = new DefaultClientRequestContext( new SimpleMeterRegistry(), SessionProtocol.H2C, RequestId.random(), HttpMethod.GET, - RequestTarget.forClient("/"), ClientOptions.of(), request, null, RequestOptions.of(), 0, 0); - parent.init(group); + RequestTarget.forClient("/"), group, ClientOptions.of(), request, null, RequestOptions.of(), + 0, 0); + parent.init(); assertThat(parent.endpoint()).isEqualTo(endpointA); final ClientRequestContext child = ClientUtil.newDerivedContext(parent, request, null, false); @@ -70,8 +71,9 @@ void shouldAcquireSameEventLoopForSameEndpoint() { final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/"); final DefaultClientRequestContext parent = new DefaultClientRequestContext( new SimpleMeterRegistry(), SessionProtocol.H2C, RequestId.random(), HttpMethod.GET, - RequestTarget.forClient("/"), ClientOptions.of(), request, null, RequestOptions.of(), 0, 0); - parent.init(group); + RequestTarget.forClient("/"), group, + ClientOptions.of(), request, null, RequestOptions.of(), 0, 0); + parent.init(); assertThat(parent.endpoint()).isEqualTo(endpointA); final ClientRequestContext childA0 = ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, true); @@ -104,8 +106,9 @@ void shouldNotAcquireNewEventLoopForInitialAttempt() { final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/"); final DefaultClientRequestContext parent = new DefaultClientRequestContext( new SimpleMeterRegistry(), SessionProtocol.H2C, RequestId.random(), HttpMethod.GET, - RequestTarget.forClient("/"), ClientOptions.of(), request, null, RequestOptions.of(), 0, 0); - parent.init(group); + RequestTarget.forClient("/"), group, ClientOptions.of(), + request, null, RequestOptions.of(), 0, 0); + parent.init(); assertThat(parent.endpoint()).isEqualTo(endpointA); final ClientRequestContext child = ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, true); diff --git a/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroupBuilder.java b/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroupBuilder.java index 26857db0c01..2ecaa12e71d 100644 --- a/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroupBuilder.java +++ b/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroupBuilder.java @@ -42,8 +42,10 @@ import com.linecorp.armeria.client.DecoratingRpcClientFunction; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreprocessor; import com.linecorp.armeria.client.ResponseTimeoutMode; import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreprocessor; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.endpoint.AbstractDynamicEndpointGroupBuilder; import com.linecorp.armeria.client.endpoint.DynamicEndpointGroupSetters; @@ -438,6 +440,17 @@ public EurekaEndpointGroupBuilder responseTimeoutMode(ResponseTimeoutMode respon return (EurekaEndpointGroupBuilder) super.responseTimeoutMode(responseTimeoutMode); } + @Override + public EurekaEndpointGroupBuilder preprocessor(HttpPreprocessor decorator) { + return (EurekaEndpointGroupBuilder) super.preprocessor(decorator); + } + + @Override + @Deprecated + public EurekaEndpointGroupBuilder rpcPreprocessor(RpcPreprocessor rpcPreprocessor) { + return (EurekaEndpointGroupBuilder) super.rpcPreprocessor(rpcPreprocessor); + } + @Override public EurekaEndpointGroupBuilder allowEmptyEndpoints(boolean allowEmptyEndpoints) { dynamicEndpointGroupBuilder.allowEmptyEndpoints(allowEmptyEndpoints); diff --git a/eureka/src/main/java/com/linecorp/armeria/server/eureka/EurekaUpdatingListenerBuilder.java b/eureka/src/main/java/com/linecorp/armeria/server/eureka/EurekaUpdatingListenerBuilder.java index 9b48c079eb7..ecd9855636a 100644 --- a/eureka/src/main/java/com/linecorp/armeria/server/eureka/EurekaUpdatingListenerBuilder.java +++ b/eureka/src/main/java/com/linecorp/armeria/server/eureka/EurekaUpdatingListenerBuilder.java @@ -41,8 +41,10 @@ import com.linecorp.armeria.client.DecoratingRpcClientFunction; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreprocessor; import com.linecorp.armeria.client.ResponseTimeoutMode; import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreprocessor; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.retry.RetryRule; @@ -557,4 +559,15 @@ public EurekaUpdatingListenerBuilder contextCustomizer( public EurekaUpdatingListenerBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (EurekaUpdatingListenerBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + @Override + public EurekaUpdatingListenerBuilder preprocessor(HttpPreprocessor decorator) { + return (EurekaUpdatingListenerBuilder) super.preprocessor(decorator); + } + + @Override + @Deprecated + public EurekaUpdatingListenerBuilder rpcPreprocessor(RpcPreprocessor rpcPreprocessor) { + return (EurekaUpdatingListenerBuilder) super.rpcPreprocessor(rpcPreprocessor); + } } diff --git a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java index 428dcb16be4..dedcd2f49fc 100644 --- a/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java +++ b/grpc/src/main/java/com/linecorp/armeria/client/grpc/GrpcClientBuilder.java @@ -55,8 +55,10 @@ import com.linecorp.armeria.client.DecoratingRpcClientFunction; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreprocessor; import com.linecorp.armeria.client.ResponseTimeoutMode; import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreprocessor; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.redirect.RedirectConfig; import com.linecorp.armeria.common.RequestContext; @@ -601,6 +603,18 @@ public GrpcClientBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeout return (GrpcClientBuilder) super.responseTimeoutMode(responseTimeoutMode); } + @Override + public GrpcClientBuilder preprocessor(HttpPreprocessor decorator) { + return (GrpcClientBuilder) super.preprocessor(decorator); + } + + @Override + @Deprecated + public GrpcClientBuilder rpcPreprocessor(RpcPreprocessor decorator) { + throw new UnsupportedOperationException("rpcPreprocessor() does not support gRPC. " + + "Use preprocessor() instead."); + } + /** * Sets the specified {@link GrpcExceptionHandlerFunction} that maps a {@link Throwable} * to a gRPC {@link Status}. diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java index a75dea16adf..9150cca90f9 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaChannel.java @@ -21,13 +21,16 @@ import java.net.URI; import java.util.EnumMap; import java.util.Map; +import java.util.function.BiFunction; import com.google.common.base.Strings; import com.google.common.collect.Maps; import com.linecorp.armeria.client.ClientBuilderParams; import com.linecorp.armeria.client.ClientOptions; +import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreClient; import com.linecorp.armeria.client.RequestOptions; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.grpc.GrpcClientOptions; @@ -36,6 +39,7 @@ import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpRequestWriter; +import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.RequestHeadersBuilder; import com.linecorp.armeria.common.RequestTarget; @@ -49,8 +53,10 @@ import com.linecorp.armeria.common.util.SystemInfo; import com.linecorp.armeria.common.util.Unwrappable; import com.linecorp.armeria.internal.client.DefaultClientRequestContext; +import com.linecorp.armeria.internal.client.TailPreClient; import com.linecorp.armeria.internal.common.RequestTargetCache; import com.linecorp.armeria.internal.common.grpc.InternalGrpcExceptionHandler; +import com.linecorp.armeria.internal.common.grpc.StatusAndMetadata; import io.grpc.CallCredentials; import io.grpc.CallOptions; @@ -61,6 +67,7 @@ import io.grpc.DecompressorRegistry; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Status; import io.micrometer.core.instrument.MeterRegistry; import io.netty.handler.codec.http.HttpHeaderValues; @@ -166,10 +173,21 @@ public ClientCall newCall(MethodDescriptor method, CallOption client = httpClient; } + final BiFunction errorResponseFactory = + (unused, cause) -> { + final StatusAndMetadata statusAndMetadata = exceptionHandler.handle(ctx, cause); + Status status = statusAndMetadata.status(); + if (status.getDescription() == null) { + status = status.withDescription(cause.getMessage()); + } + return HttpResponse.ofFailure(status.asRuntimeException()); + }; + final HttpPreClient preClient = + options().clientPreprocessors() + .decorate(TailPreClient.of(client, HttpResponse::of, errorResponseFactory)); + return new ArmeriaClientCall<>( ctx, - params.endpointGroup(), - client, req, method, simpleMethodNames, @@ -183,7 +201,9 @@ public ClientCall newCall(MethodDescriptor method, CallOption jsonMarshaller, unsafeWrapResponseBuffers, exceptionHandler, - useMethodMarshaller); + useMethodMarshaller, + preClient, + errorResponseFactory); } @Override @@ -248,6 +268,7 @@ private DefaultClientRequestContext newContext(HttpMethod method, HttpReq options().requestIdGenerator().get(), method, reqTarget, + endpointGroup(), options(), req, null, diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java index d00c4d572cc..46b138d1018 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/client/grpc/ArmeriaClientCall.java @@ -15,7 +15,6 @@ */ package com.linecorp.armeria.internal.client.grpc; -import static com.linecorp.armeria.internal.client.ClientUtil.initContextAndExecuteWithFallback; import static com.linecorp.armeria.internal.client.grpc.protocol.InternalGrpcWebUtil.messageBuf; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -39,8 +38,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.linecorp.armeria.client.ClientRequestContext; -import com.linecorp.armeria.client.HttpClient; -import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.client.HttpPreClient; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpRequestWriter; @@ -61,6 +59,7 @@ import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.common.util.TimeoutMode; +import com.linecorp.armeria.internal.client.ClientUtil; import com.linecorp.armeria.internal.client.DefaultClientRequestContext; import com.linecorp.armeria.internal.client.grpc.protocol.InternalGrpcWebUtil; import com.linecorp.armeria.internal.common.grpc.ForwardingCompressor; @@ -109,8 +108,7 @@ final class ArmeriaClientCall extends ClientCall ArmeriaClientCall.class, Runnable.class, "pendingTask"); private final DefaultClientRequestContext ctx; - private final EndpointGroup endpointGroup; - private final HttpClient httpClient; + private final BiFunction errorResponseFactory; private final HttpRequestWriter req; private final MethodDescriptor method; private final Map, String> simpleMethodNames; @@ -126,6 +124,7 @@ final class ArmeriaClientCall extends ClientCall private final boolean grpcWebText; private final Compressor compressor; private final InternalGrpcExceptionHandler exceptionHandler; + private final HttpPreClient preClient; private boolean endpointInitialized; @Nullable @@ -146,8 +145,6 @@ final class ArmeriaClientCall extends ClientCall ArmeriaClientCall( DefaultClientRequestContext ctx, - EndpointGroup endpointGroup, - HttpClient httpClient, HttpRequestWriter req, MethodDescriptor method, Map, String> simpleMethodNames, @@ -161,10 +158,10 @@ final class ArmeriaClientCall extends ClientCall @Nullable GrpcJsonMarshaller jsonMarshaller, boolean unsafeWrapResponseBuffers, InternalGrpcExceptionHandler exceptionHandler, - boolean useMethodMarshaller) { + boolean useMethodMarshaller, + HttpPreClient preClient, + BiFunction errorResponseFactory) { this.ctx = ctx; - this.endpointGroup = endpointGroup; - this.httpClient = httpClient; this.req = req; this.method = method; this.simpleMethodNames = simpleMethodNames; @@ -177,6 +174,8 @@ final class ArmeriaClientCall extends ClientCall grpcWebText = GrpcSerializationFormats.isGrpcWebText(serializationFormat); this.maxInboundMessageSizeBytes = maxInboundMessageSizeBytes; this.exceptionHandler = exceptionHandler; + this.preClient = preClient; + this.errorResponseFactory = errorResponseFactory; ctx.whenInitialized().handle((unused1, unused2) -> { runPendingTask(); @@ -245,19 +244,8 @@ public void start(Listener responseListener, Metadata metadata) { } // Must come after handling deadline. - prepareHeaders(compressor, metadata, remainingNanos); - - final BiFunction errorResponseFactory = - (unused, cause) -> { - final StatusAndMetadata statusAndMetadata = exceptionHandler.handle(ctx, cause); - Status status = statusAndMetadata.status(); - if (status.getDescription() == null) { - status = status.withDescription(cause.getMessage()); - } - return HttpResponse.ofFailure(status.asRuntimeException()); - }; - final HttpResponse res = initContextAndExecuteWithFallback( - httpClient, ctx, endpointGroup, HttpResponse::of, errorResponseFactory); + final HttpRequest newReq = prepareHeaders(compressor, metadata, remainingNanos); + final HttpResponse res = ClientUtil.executeWithFallback(preClient, ctx, newReq, errorResponseFactory); final HttpStreamDeframer deframer = new HttpStreamDeframer( decompressorRegistry, ctx, this, exceptionHandler, @@ -493,7 +481,7 @@ public void transportReportHeaders(Metadata metadata) { }); } - private void prepareHeaders(Compressor compressor, Metadata metadata, long remainingNanos) { + private HttpRequest prepareHeaders(Compressor compressor, Metadata metadata, long remainingNanos) { final RequestHeadersBuilder newHeaders = req.headers().toBuilder(); if (compressor != Identity.NONE) { newHeaders.set(GrpcHeaderNames.GRPC_ENCODING, compressor.getMessageEncoding()); @@ -512,6 +500,7 @@ private void prepareHeaders(Compressor compressor, Metadata metadata, long remai final HttpRequest newReq = req.withHeaders(newHeaders); ctx.updateRequest(newReq); + return newReq; } private void closeWhenListenerThrows(Throwable t) { diff --git a/oauth2/src/main/java/com/linecorp/armeria/client/auth/oauth2/OAuth2Client.java b/oauth2/src/main/java/com/linecorp/armeria/client/auth/oauth2/OAuth2Client.java index 34a521f1793..5d70210d20f 100644 --- a/oauth2/src/main/java/com/linecorp/armeria/client/auth/oauth2/OAuth2Client.java +++ b/oauth2/src/main/java/com/linecorp/armeria/client/auth/oauth2/OAuth2Client.java @@ -72,7 +72,7 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex HttpHeaderNames.AUTHORIZATION, token.authorization()).build()); ctx.updateRequest(newReq); return executeWithFallback(unwrap(), ctx, - (context, cause0) -> HttpResponse.ofFailure(cause0)); + (context, cause0) -> HttpResponse.ofFailure(cause0), newReq); }); return HttpResponse.of(future); } diff --git a/retrofit2/src/main/java/com/linecorp/armeria/client/retrofit2/ArmeriaRetrofitBuilder.java b/retrofit2/src/main/java/com/linecorp/armeria/client/retrofit2/ArmeriaRetrofitBuilder.java index 66d46a84b55..2cff44b6f65 100644 --- a/retrofit2/src/main/java/com/linecorp/armeria/client/retrofit2/ArmeriaRetrofitBuilder.java +++ b/retrofit2/src/main/java/com/linecorp/armeria/client/retrofit2/ArmeriaRetrofitBuilder.java @@ -43,8 +43,10 @@ import com.linecorp.armeria.client.DecoratingRpcClientFunction; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreprocessor; import com.linecorp.armeria.client.ResponseTimeoutMode; import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreprocessor; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.redirect.RedirectConfig; @@ -456,4 +458,14 @@ public ArmeriaRetrofitBuilder contextCustomizer( public ArmeriaRetrofitBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (ArmeriaRetrofitBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + @Override + public ArmeriaRetrofitBuilder preprocessor(HttpPreprocessor decorator) { + return (ArmeriaRetrofitBuilder) super.preprocessor(decorator); + } + + @Override + public ArmeriaRetrofitBuilder rpcPreprocessor(RpcPreprocessor decorator) { + return (ArmeriaRetrofitBuilder) super.rpcPreprocessor(decorator); + } } diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java index 7591c0d71aa..17f31b6c8d8 100644 --- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java @@ -42,8 +42,10 @@ import com.linecorp.armeria.client.DecoratingRpcClientFunction; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.HttpPreprocessor; import com.linecorp.armeria.client.ResponseTimeoutMode; import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreprocessor; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.redirect.RedirectConfig; import com.linecorp.armeria.common.RequestId; @@ -379,4 +381,16 @@ public ThriftClientBuilder contextCustomizer( public ThriftClientBuilder responseTimeoutMode(ResponseTimeoutMode responseTimeoutMode) { return (ThriftClientBuilder) super.responseTimeoutMode(responseTimeoutMode); } + + @Override + @Deprecated + public ThriftClientBuilder preprocessor(HttpPreprocessor decorator) { + throw new UnsupportedOperationException("preprocessor() does not support Thrift. " + + "Use rpcPreprocessor() instead."); + } + + @Override + public ThriftClientBuilder rpcPreprocessor(RpcPreprocessor decorator) { + return (ThriftClientBuilder) super.rpcPreprocessor(decorator); + } } diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java index 7149ce3c327..2445fa0ef0a 100644 --- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java @@ -24,6 +24,7 @@ import com.linecorp.armeria.client.ClientBuilderParams; import com.linecorp.armeria.client.RequestOptions; import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.RpcPreClient; import com.linecorp.armeria.client.UserClient; import com.linecorp.armeria.client.thrift.THttpClient; import com.linecorp.armeria.common.ExchangeType; @@ -32,6 +33,9 @@ import com.linecorp.armeria.common.RpcRequest; import com.linecorp.armeria.common.RpcResponse; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.client.ClientUtil; +import com.linecorp.armeria.internal.client.DefaultClientRequestContext; +import com.linecorp.armeria.internal.client.TailPreClient; import com.linecorp.armeria.internal.common.RequestTargetCache; import io.micrometer.core.instrument.MeterRegistry; @@ -43,9 +47,14 @@ final class DefaultTHttpClient extends UserClient imple .exchangeType(ExchangeType.UNARY) .build(); + private final RpcPreClient preClient; + DefaultTHttpClient(ClientBuilderParams params, RpcClient delegate, MeterRegistry meterRegistry) { super(params, delegate, meterRegistry, RpcResponse::from, (ctx, cause) -> RpcResponse.ofFailure(decodeException(cause, null))); + final RpcPreClient tailPreClient = + TailPreClient.ofRpc(unwrap(), futureConverter(), errorResponseFactory()); + preClient = options().clientPreprocessors().rpcDecorate(tailPreClient); } @Override @@ -77,8 +86,10 @@ private RpcResponse execute0( RequestTargetCache.putForClient(path, reqTarget); final RpcRequest call = RpcRequest.of(serviceType, method, args); - return execute(scheme().sessionProtocol(), HttpMethod.POST, - reqTarget, call, UNARY_REQUEST_OPTIONS); + final DefaultClientRequestContext ctx = new DefaultClientRequestContext( + scheme().sessionProtocol(), null, HttpMethod.POST, call, reqTarget, endpointGroup(), + UNARY_REQUEST_OPTIONS, options(), meterRegistry()); + return ClientUtil.executeWithFallback(preClient, ctx, call, errorResponseFactory()); } @Override diff --git a/thrift/thrift0.13/src/test/java/com/linecorp/armeria/client/thrift/RpcPreprocessorTest.java b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/client/thrift/RpcPreprocessorTest.java new file mode 100644 index 00000000000..3bdb2f503a9 --- /dev/null +++ b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/client/thrift/RpcPreprocessorTest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.thrift; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.ClientRequestContextCaptor; +import com.linecorp.armeria.client.Clients; +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.PreClient; +import com.linecorp.armeria.client.PreClientRequestContext; +import com.linecorp.armeria.client.RpcPreprocessor; +import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.RpcResponse; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.testing.junit5.common.EventLoopExtension; + +import testing.thrift.main.HelloService; + +class RpcPreprocessorTest { + + @RegisterExtension + static final EventLoopExtension eventLoop = new EventLoopExtension(); + + @Test + void overwriteByCustomPreprocessor() throws Exception { + final RpcPreprocessor rpcPreprocessor = + RpcPreprocessor.of(SessionProtocol.HTTP, Endpoint.of("127.0.0.1"), + eventLoop.get()); + final HelloService.Iface iface = + ThriftClients.builder("http://127.0.0.2") + .rpcPreprocessor(rpcPreprocessor) + .rpcDecorator((delegate, ctx, req) -> RpcResponse.of("world")) + .build(HelloService.Iface.class); + final ClientRequestContext ctx; + try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + assertThat(iface.hello("world")).isEqualTo("world"); + ctx = captor.get(); + } + assertThat(ctx.sessionProtocol()).isEqualTo(SessionProtocol.HTTP); + assertThat(ctx.authority()).isEqualTo("127.0.0.1"); + assertThat(ctx.eventLoop().withoutContext()).isSameAs(eventLoop.get()); + } + + @Test + void preprocessorOrder() throws Exception { + final List list = new ArrayList<>(); + final RpcPreprocessor p1 = RunnablePreprocessor.of(() -> list.add("1")); + final RpcPreprocessor p2 = RunnablePreprocessor.of(() -> list.add("2")); + final RpcPreprocessor p3 = RunnablePreprocessor.of(() -> list.add("3")); + + final HelloService.Iface iface = + ThriftClients.builder("http://127.0.0.2") + .rpcPreprocessor(p1) + .rpcPreprocessor(p2) + .rpcPreprocessor(p3) + .rpcDecorator((delegate, ctx, req) -> RpcResponse.of("world")) + .build(HelloService.Iface.class); + assertThat(iface.hello("world")).isEqualTo("world"); + assertThat(list).containsExactly("3", "2", "1"); + } + + private static final class RunnablePreprocessor implements RpcPreprocessor { + + private static RpcPreprocessor of(Runnable runnable) { + return new RunnablePreprocessor(runnable); + } + + private final Runnable runnable; + + private RunnablePreprocessor(Runnable runnable) { + this.runnable = runnable; + } + + @Override + public RpcResponse execute(PreClient delegate, + PreClientRequestContext ctx, RpcRequest req) throws Exception { + runnable.run(); + return delegate.execute(ctx, req); + } + } +}