From 3963f2a042803ec9b33939e89a4cfd3d9b1eebcc Mon Sep 17 00:00:00 2001 From: Ikhun Um Date: Mon, 22 Jan 2024 13:47:38 +0900 Subject: [PATCH] Make `WebSocketService` interface (#5368) Motivation: `WebsocketService` is a final class, so it is not easy to make another service run WebSocket and HTTP requests on the same path. `WebSocketService` is changed as a marker interface in #5280. However, the interface does not have API, the service cannot take advantage of the protocol detection of the original `WebSocketService`. Graphql WebSocket implementation copied and pasted similar code for WebSocket detection. The design of the code here was determined after discussions with other maintainers based on the idea from https://github.com/line/armeria/pull/5280#discussion_r1440384188 Modifications: - Add `WebSocketService` interface and rename the original one to `DefaultWebSocketService` - Add `WebSocketProtocolHandler` that is in charge of: - Upgrade HTTP request to WebSocket - Decode HTTP request to WebSocket for inbound - Encode WebSocket to HTTP response for outbound - Refactor `DefaultWebSocketService` to implement new methods `WebSocketService` and `WebSocketProtocolHandler` - Modify Webservicebuilder so that a fallback service can be configured. - The fallback service would be useful when WebSocket and HTTP request is served at the same path. Result: - You can now set a fallback service to `WebSocketService` via `WebSocketServiceBuilder.fallbackService()` - Refactor `WebsocketService` to enable composion and delegatation. --- .../WebSocketClientFrameDecoder.java | 2 +- .../armeria/common/websocket/WebSocket.java | 9 + .../websocket/WebSocketFrameDecoder.java | 5 +- .../websocket/DefaultWebSocketService.java | 385 ++++++++++++++++++ .../WebSocketServiceFrameDecoder.java | 4 +- .../server/websocket/package-info.java | 23 ++ .../armeria/server/ServiceConfigBuilder.java | 4 +- .../websocket/WebSocketProtocolHandler.java | 48 +++ .../server/websocket/WebSocketService.java | 298 +------------- .../websocket/WebSocketServiceBuilder.java | 17 +- .../websocket/WebSocketUpgradeResult.java | 84 ++++ .../WebSocketFrameEncoderAndDecoderTest.java | 11 +- .../DelegatingWebSocketServiceTest.java | 149 +++++++ 13 files changed, 742 insertions(+), 297 deletions(-) create mode 100644 core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java rename core/src/main/java/com/linecorp/armeria/{ => internal}/server/websocket/WebSocketServiceFrameDecoder.java (93%) create mode 100644 core/src/main/java/com/linecorp/armeria/internal/server/websocket/package-info.java create mode 100644 core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketProtocolHandler.java create mode 100644 core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketUpgradeResult.java create mode 100644 core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java index 7390953ac5f..e35061099de 100644 --- a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java @@ -25,7 +25,7 @@ final class WebSocketClientFrameDecoder extends WebSocketFrameDecoder { WebSocketClientFrameDecoder(ClientRequestContext ctx, int maxFramePayloadLength, boolean allowMaskMismatch) { - super(ctx, maxFramePayloadLength, allowMaskMismatch); + super(maxFramePayloadLength, allowMaskMismatch); this.ctx = ctx; } diff --git a/core/src/main/java/com/linecorp/armeria/common/websocket/WebSocket.java b/core/src/main/java/com/linecorp/armeria/common/websocket/WebSocket.java index 30d31f33793..10ae7240377 100644 --- a/core/src/main/java/com/linecorp/armeria/common/websocket/WebSocket.java +++ b/core/src/main/java/com/linecorp/armeria/common/websocket/WebSocket.java @@ -17,6 +17,7 @@ import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper; /** * A {@link StreamMessage} that publishes {@link WebSocketFrame}s. @@ -30,4 +31,12 @@ public interface WebSocket extends StreamMessage { static WebSocketWriter streaming() { return new DefaultWebSocket(); } + + /** + * Returns a new {@link WebSocket} whose stream is produced from the specified {@link StreamMessage}. + */ + @UnstableApi + static WebSocket of(StreamMessage delegate) { + return new WebSocketWrapper(delegate); + } } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java index ad31efcb993..3e262bfe3c0 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java @@ -34,7 +34,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.stream.HttpDecoder; import com.linecorp.armeria.common.stream.StreamDecoderInput; @@ -64,7 +63,6 @@ enum State { CORRUPT } - private final RequestContext ctx; private final int maxFramePayloadLength; private final boolean allowMaskMismatch; @Nullable @@ -81,8 +79,7 @@ enum State { private boolean receivedClosingHandshake; private State state = State.READING_FIRST; - protected WebSocketFrameDecoder(RequestContext ctx, int maxFramePayloadLength, boolean allowMaskMismatch) { - this.ctx = ctx; + protected WebSocketFrameDecoder(int maxFramePayloadLength, boolean allowMaskMismatch) { this.maxFramePayloadLength = maxFramePayloadLength; this.allowMaskMismatch = allowMaskMismatch; } diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java new file mode 100644 index 00000000000..a0e733351e7 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/DefaultWebSocketService.java @@ -0,0 +1,385 @@ +/* + * Copyright 2022 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.internal.server.websocket; + +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.generateSecWebSocketAccept; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.isHttp1WebSocketUpgradeRequest; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.isHttp2WebSocketUpgradeRequest; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.newCloseWebSocketFrame; + +import java.util.Set; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Splitter; +import com.google.common.net.HostAndPort; + +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.ResponseHeadersBuilder; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.stream.ClosedStreamException; +import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder; +import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper; +import com.linecorp.armeria.server.HttpService; +import com.linecorp.armeria.server.ServiceConfig; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.websocket.WebSocketProtocolHandler; +import com.linecorp.armeria.server.websocket.WebSocketService; +import com.linecorp.armeria.server.websocket.WebSocketServiceBuilder; +import com.linecorp.armeria.server.websocket.WebSocketServiceHandler; +import com.linecorp.armeria.server.websocket.WebSocketUpgradeResult; + +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import io.netty.util.AttributeKey; + +/** + * An {@link HttpService} that supports + * The WebSocket Protocol. + * This service has a few different default values for {@link ServiceConfig} from a normal {@link HttpService} + * because of the nature of WebSocket. See {@link WebSocketServiceBuilder} for more information. + */ +public final class DefaultWebSocketService implements WebSocketService, WebSocketProtocolHandler { + + private static final Logger logger = LoggerFactory.getLogger(DefaultWebSocketService.class); + + private static final AttributeKey DECODER = + AttributeKey.valueOf(DefaultWebSocketService.class, "DECODER"); + + private static final String SUB_PROTOCOL_WILDCARD = "*"; + + private static final ResponseHeaders UNSUPPORTED_WEB_SOCKET_VERSION = + ResponseHeaders.builder(HttpStatus.BAD_REQUEST) + .add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue()) + .contentType(MediaType.PLAIN_TEXT_UTF_8) + .build(); + + private static final Splitter commaSplitter = Splitter.on(',').trimResults().omitEmptyStrings(); + + // Server-side encoder do not mask the payloads. + private static final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(false); + + private final WebSocketServiceHandler handler; + @Nullable + private final HttpService fallbackService; + private final int maxFramePayloadLength; + private final boolean allowMaskMismatch; + private final Set subprotocols; + private final Set allowedOrigins; + private final boolean allowAnyOrigin; + + public DefaultWebSocketService(WebSocketServiceHandler handler, @Nullable HttpService fallbackService, + int maxFramePayloadLength, boolean allowMaskMismatch, + Set subprotocols, Set allowedOrigins, + boolean allowAnyOrigin) { + this.handler = handler; + this.fallbackService = fallbackService; + this.maxFramePayloadLength = maxFramePayloadLength; + this.allowMaskMismatch = allowMaskMismatch; + this.subprotocols = subprotocols; + this.allowedOrigins = allowedOrigins; + this.allowAnyOrigin = allowAnyOrigin; + } + + @Override + public WebSocket serve(ServiceRequestContext ctx, WebSocket in) throws Exception { + return handler.handle(ctx, in); + } + + @Override + public WebSocketUpgradeResult upgrade(ServiceRequestContext ctx, HttpRequest req) throws Exception { + final HttpMethod method = ctx.method(); + switch (method) { + case GET: + return upgradeHttp1(ctx, req); + case CONNECT: + return upgradeHttp2(ctx, req); + default: + final HttpResponse httpResponse = + failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)); + return WebSocketUpgradeResult.ofFailure(httpResponse); + } + } + + /** + * Handles the HTTP/1.1 web socket handshake described in + * The WebSocket Protocol. + * These are examples of a request and its corresponding response: + * + *

Request: + *

+     * GET /chat HTTP/1.1
+     * Host: server.example.com
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
+     * Origin: http://example.com
+     * Sec-WebSocket-Protocol: chat, superchat
+     * Sec-WebSocket-Version: 13
+     * 
+ * + *

Response: + *

+     * HTTP/1.1 101 Switching Protocols
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
+     * Sec-WebSocket-Protocol: chat
+     * 
+ */ + private WebSocketUpgradeResult upgradeHttp1(ServiceRequestContext ctx, HttpRequest req) throws Exception { + if (!ctx.sessionProtocol().isExplicitHttp1()) { + final HttpResponse httpResponse = + failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)); + return WebSocketUpgradeResult.ofFailure(httpResponse); + } + final RequestHeaders headers = req.headers(); + if (!isHttp1WebSocketUpgradeRequest(headers)) { + final HttpResponse httpResponse = + failOrFallback(ctx, req, () -> HttpResponse.of( + HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, + "The upgrade header must contain:\n" + + " Upgrade: websocket\n" + + " Connection: Upgrade")); + return WebSocketUpgradeResult.ofFailure(httpResponse); + } + + HttpResponse invalidResponse = checkOrigin(ctx, headers); + if (invalidResponse != null) { + return WebSocketUpgradeResult.ofFailure(invalidResponse); + } + + invalidResponse = checkVersion(headers); + if (invalidResponse != null) { + return WebSocketUpgradeResult.ofFailure(invalidResponse); + } + + final String webSocketKey = headers.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, ""); + if (webSocketKey.isEmpty()) { + return WebSocketUpgradeResult.ofFailure( + HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, + "missing Sec-WebSocket-Key header")); + } + return WebSocketUpgradeResult.ofSuccess(); + } + + private HttpResponse failOrFallback(ServiceRequestContext ctx, HttpRequest req, + Supplier invalidResponse) throws Exception { + if (fallbackService != null) { + return fallbackService.serve(ctx, req); + } else { + return invalidResponse.get(); + } + } + + private void maybeAddSubprotocol(RequestHeaders headers, + ResponseHeadersBuilder responseHeadersBuilder) { + final String subprotocols = headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, ""); + if (subprotocols.isEmpty()) { + return; + } + commaSplitter.splitToStream(subprotocols) + .filter(sub -> SUB_PROTOCOL_WILDCARD.equals(sub) || + this.subprotocols.contains(sub)) + .findFirst().ifPresent(selectedSubprotocol -> responseHeadersBuilder.add( + HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol)); + } + + /** + * Handles the HTTP/2 web socket handshake described in + * Bootstrapping WebSockets with HTTP/2. + * These are examples of a request and the corresponding response: + * + *

Request: + *

+     * HEADERS + END_HEADERS
+     * :method = CONNECT
+     * :protocol = websocket
+     * :scheme = https
+     * :path = /chat
+     * :authority = server.example.com
+     * sec-websocket-protocol = chat, superchat
+     * sec-websocket-extensions = permessage-deflate
+     * sec-websocket-version = 13
+     * origin = http://www.example.com
+     * 
+ * + *

Response: + *

+     * HEADERS + END_HEADERS
+     * :status = 200
+     * sec-websocket-protocol = chat
+     * 
+ */ + private WebSocketUpgradeResult upgradeHttp2(ServiceRequestContext ctx, HttpRequest req) throws Exception { + if (!ctx.sessionProtocol().isExplicitHttp2()) { + final HttpResponse fallbackResponse = + failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)); + return WebSocketUpgradeResult.ofFailure(fallbackResponse); + } + final RequestHeaders headers = req.headers(); + if (!isHttp2WebSocketUpgradeRequest(headers)) { + logger.trace("RequestHeaders does not contain headers for WebSocket upgrade. headers: {}", headers); + final HttpResponse fallbackResponse = failOrFallback(ctx, req, () -> HttpResponse.of( + HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, + "The upgrade header must contain:\n" + + " :protocol = websocket")); + return WebSocketUpgradeResult.ofFailure(fallbackResponse); + } + + HttpResponse invalidResponse = checkOrigin(ctx, headers); + if (invalidResponse != null) { + return WebSocketUpgradeResult.ofFailure(invalidResponse); + } + + invalidResponse = checkVersion(headers); + if (invalidResponse != null) { + return WebSocketUpgradeResult.ofFailure(invalidResponse); + } + + return WebSocketUpgradeResult.ofSuccess(); + } + + @Nullable + private HttpResponse checkOrigin(ServiceRequestContext ctx, RequestHeaders headers) { + if (allowAnyOrigin) { + return null; + } + final String origin = headers.get(HttpHeaderNames.ORIGIN, ""); + if (origin.isEmpty()) { + return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, + "missing the origin header"); + } + + if (allowedOrigins.isEmpty()) { + // Only the same-origin is allowed. + if (!isSameOrigin(ctx, headers, origin)) { + return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, + "not allowed origin: " + origin); + } + return null; + } + if (!allowedOrigins.contains(origin)) { + return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, + "not allowed origin: " + origin + ", allowed: " + allowedOrigins); + } + return null; + } + + private static boolean isSameOrigin(ServiceRequestContext ctx, RequestHeaders headers, String origin) { + final int schemeDelimiter = origin.indexOf("://"); + if (schemeDelimiter < 0) { + return false; + } + + final String scheme = origin.substring(0, schemeDelimiter); + final SessionProtocol originSessionProtocol = SessionProtocol.find(scheme); + if (originSessionProtocol == null) { + return false; + } + + if ((ctx.sessionProtocol().isHttp() && originSessionProtocol.isHttp()) || + (ctx.sessionProtocol().isHttps() && originSessionProtocol.isHttps())) { + // The same scheme. + } else { + return false; + } + + final String authority = headers.authority(); + assert authority != null; + final HostAndPort authorityHostAndPort = HostAndPort.fromString(authority); + final String authorityHost = authorityHostAndPort.getHost(); + final int authorityPort = authorityHostAndPort.getPortOrDefault( + ctx.sessionProtocol().defaultPort()); + + final HostAndPort originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3)); + final String originHost = originHostAndPort.getHost(); + final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort()); + + return authorityPort == originPort && authorityHost.equals(originHost); + } + + @Nullable + private static HttpResponse checkVersion(RequestHeaders headers) { + // Currently we only support v13. + final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION); + if (!WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) { + return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION, + HttpData.ofUtf8("Only 13 version is supported.")); + } + return null; + } + + @Override + public WebSocket decode(ServiceRequestContext ctx, HttpRequest req) { + final WebSocketServiceFrameDecoder decoder = + new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); + ctx.setAttr(DECODER, decoder); + return new WebSocketWrapper(req.decode(decoder, ctx.alloc())); + } + + @Override + public HttpResponse encode(ServiceRequestContext ctx, WebSocket out) { + final RequestHeaders requestHeaders = ctx.request().headers(); + final ResponseHeadersBuilder responseHeadersBuilder; + if (ctx.sessionProtocol().isExplicitHttp1()) { + final String webSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, ""); + final String accept = generateSecWebSocketAccept(webSocketKey); + responseHeadersBuilder = + ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS) + .add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString()) + .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString()) + .add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); + } else { + // As described in https://datatracker.ietf.org/doc/html/rfc8441#section-5, + // HTTP/2 does not use Sec-WebSocket-Key and Sec-WebSocket-Accept headers. + responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK); + } + maybeAddSubprotocol(requestHeaders, responseHeadersBuilder); + + final WebSocketServiceFrameDecoder decoder = ctx.attr(DECODER); + assert decoder != null; + decoder.setOutboundWebSocket(out); + final StreamMessage data = + out.recoverAndResume(cause -> { + if (cause instanceof ClosedStreamException) { + return StreamMessage.aborted(cause); + } + ctx.logBuilder().responseCause(cause); + return StreamMessage.of(newCloseWebSocketFrame(cause)); + }) + .map(frame -> HttpData.wrap(encoder.encode(ctx, frame))); + return HttpResponse.of(responseHeadersBuilder.build(), data); + } + + @Override + public WebSocketProtocolHandler protocolHandler() { + return this; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/WebSocketServiceFrameDecoder.java similarity index 93% rename from core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java rename to core/src/main/java/com/linecorp/armeria/internal/server/websocket/WebSocketServiceFrameDecoder.java index 28bbd71aa2c..5794961591a 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/WebSocketServiceFrameDecoder.java @@ -13,7 +13,7 @@ * License for the specific language governing permissions and limitations * under the License. */ -package com.linecorp.armeria.server.websocket; +package com.linecorp.armeria.internal.server.websocket; import com.linecorp.armeria.common.HttpRequestWriter; import com.linecorp.armeria.common.Request; @@ -27,7 +27,7 @@ final class WebSocketServiceFrameDecoder extends WebSocketFrameDecoder { WebSocketServiceFrameDecoder(ServiceRequestContext ctx, int maxFramePayloadLength, boolean allowMaskMismatch) { - super(ctx, maxFramePayloadLength, allowMaskMismatch); + super(maxFramePayloadLength, allowMaskMismatch); this.ctx = ctx; } diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/websocket/package-info.java b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/package-info.java new file mode 100644 index 00000000000..b09d3070fbe --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/server/websocket/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2016 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Various classes used internally. Anything in this package can be changed or removed at any time. + */ +@NonNullByDefault +package com.linecorp.armeria.internal.server.websocket; + +import com.linecorp.armeria.common.annotation.NonNullByDefault; diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java index 7937be9926d..4a78a286860 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java @@ -43,8 +43,8 @@ import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.internal.common.websocket.WebSocketUtil; +import com.linecorp.armeria.internal.server.websocket.DefaultWebSocketService; import com.linecorp.armeria.server.logging.AccessLogWriter; -import com.linecorp.armeria.server.websocket.WebSocketService; final class ServiceConfigBuilder implements ServiceConfigSetters { @@ -321,7 +321,7 @@ ServiceConfig build(ServiceNaming defaultServiceNaming, unhandledExceptionsReporter); } - final boolean webSocket = service.as(WebSocketService.class) != null; + final boolean webSocket = service.as(DefaultWebSocketService.class) != null; final long requestTimeoutMillis; if (this.requestTimeoutMillis != null) { requestTimeoutMillis = this.requestTimeoutMillis; diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketProtocolHandler.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketProtocolHandler.java new file mode 100644 index 00000000000..28755571bc2 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketProtocolHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.websocket; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.server.ServiceRequestContext; + +/** + * A WebSocket protocol handler for {@link WebSocketService}. + */ +@UnstableApi +public interface WebSocketProtocolHandler { + + /** + * Upgrades the given {@link HttpRequest} to a {@link WebSocket}. + * + *

If the upgrade succeeds, {@link WebSocketUpgradeResult#ofSuccess()} is returned. + * If the upgrade fails, {@link WebSocketUpgradeResult#ofFailure(HttpResponse)} is returned. + */ + WebSocketUpgradeResult upgrade(ServiceRequestContext ctx, HttpRequest req) throws Exception; + + /** + * Decodes the specified {@link HttpRequest} to a {@link WebSocket}. + */ + WebSocket decode(ServiceRequestContext ctx, HttpRequest req); + + /** + * Encodes the specified {@link WebSocket} to an {@link HttpResponse}. + */ + HttpResponse encode(ServiceRequestContext ctx, WebSocket out); +} diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java index 9c53cf2e553..ea6b36c628e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 LINE Corporation + * Copyright 2024 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance @@ -13,47 +13,17 @@ * License for the specific language governing permissions and limitations * under the License. */ -package com.linecorp.armeria.server.websocket; - -import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.generateSecWebSocketAccept; -import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.isHttp1WebSocketUpgradeRequest; -import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.isHttp2WebSocketUpgradeRequest; -import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.newCloseWebSocketFrame; - -import java.util.Set; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import com.google.common.base.Splitter; -import com.google.common.net.HostAndPort; +package com.linecorp.armeria.server.websocket; -import com.linecorp.armeria.common.HttpData; -import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; -import com.linecorp.armeria.common.HttpStatus; -import com.linecorp.armeria.common.MediaType; -import com.linecorp.armeria.common.RequestHeaders; -import com.linecorp.armeria.common.ResponseHeaders; -import com.linecorp.armeria.common.ResponseHeadersBuilder; -import com.linecorp.armeria.common.SessionProtocol; -import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; -import com.linecorp.armeria.common.stream.ClosedStreamException; -import com.linecorp.armeria.common.stream.StreamMessage; import com.linecorp.armeria.common.websocket.WebSocket; -import com.linecorp.armeria.common.websocket.WebSocketFrame; -import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder; -import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper; -import com.linecorp.armeria.server.AbstractHttpService; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.ServiceConfig; import com.linecorp.armeria.server.ServiceRequestContext; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.websocketx.WebSocketVersion; - /** * An {@link HttpService} that supports * The WebSocket Protocol. @@ -61,273 +31,41 @@ * because of the nature of WebSocket. See {@link WebSocketServiceBuilder} for more information. */ @UnstableApi -public final class WebSocketService extends AbstractHttpService { - - private static final Logger logger = LoggerFactory.getLogger(WebSocketService.class); - - private static final String SUB_PROTOCOL_WILDCARD = "*"; - - private static final ResponseHeaders UNSUPPORTED_WEB_SOCKET_VERSION = - ResponseHeaders.builder(HttpStatus.BAD_REQUEST) - .add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue()) - .contentType(MediaType.PLAIN_TEXT_UTF_8) - .build(); - - private static final Splitter commaSplitter = Splitter.on(',').trimResults().omitEmptyStrings(); - - // Server-side encoder do not mask the payloads. - private static final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(false); +public interface WebSocketService extends HttpService { /** * Returns a new {@link WebSocketService} with the {@link WebSocketServiceHandler}. */ - public static WebSocketService of(WebSocketServiceHandler handler) { + static WebSocketService of(WebSocketServiceHandler handler) { return new WebSocketServiceBuilder(handler).build(); } /** * Returns a new {@link WebSocketServiceBuilder} with the {@link WebSocketServiceHandler}. */ - public static WebSocketServiceBuilder builder(WebSocketServiceHandler handler) { + static WebSocketServiceBuilder builder(WebSocketServiceHandler handler) { return new WebSocketServiceBuilder(handler); } - private final WebSocketServiceHandler handler; - private final int maxFramePayloadLength; - private final boolean allowMaskMismatch; - private final Set subprotocols; - private final Set allowedOrigins; - private final boolean allowAnyOrigin; - - WebSocketService(WebSocketServiceHandler handler, int maxFramePayloadLength, boolean allowMaskMismatch, - Set subprotocols, Set allowedOrigins, boolean allowAnyOrigin) { - this.handler = handler; - this.maxFramePayloadLength = maxFramePayloadLength; - this.allowMaskMismatch = allowMaskMismatch; - this.subprotocols = subprotocols; - this.allowedOrigins = allowedOrigins; - this.allowAnyOrigin = allowAnyOrigin; - } - - /** - * Handles the HTTP/1.1 web socket handshake described in - * The WebSocket Protocol. - * These are examples of a request and its corresponding response: - * - *

Request: - *

-     * GET /chat HTTP/1.1
-     * Host: server.example.com
-     * Upgrade: websocket
-     * Connection: Upgrade
-     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
-     * Origin: http://example.com
-     * Sec-WebSocket-Protocol: chat, superchat
-     * Sec-WebSocket-Version: 13
-     * 
- * - *

Response: - *

-     * HTTP/1.1 101 Switching Protocols
-     * Upgrade: websocket
-     * Connection: Upgrade
-     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
-     * Sec-WebSocket-Protocol: chat
-     * 
- */ @Override - protected HttpResponse doGet(ServiceRequestContext ctx, HttpRequest req) throws Exception { - if (!ctx.sessionProtocol().isExplicitHttp1()) { - return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED); - } - final RequestHeaders headers = req.headers(); - if (!isHttp1WebSocketUpgradeRequest(headers)) { - return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, - "The upgrade header must contain:\n" + - " Upgrade: websocket\n" + - " Connection: Upgrade"); - } - - HttpResponse invalidResponse = checkOrigin(ctx, headers); - if (invalidResponse != null) { - return invalidResponse; - } - - invalidResponse = checkVersion(headers); - if (invalidResponse != null) { - return invalidResponse; - } - - final String webSocketKey = headers.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, ""); - if (webSocketKey.isEmpty()) { - return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, - "missing Sec-WebSocket-Key header"); + default HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception { + final WebSocketUpgradeResult upgradeResult = protocolHandler().upgrade(ctx, req); + if (!upgradeResult.isSuccess()) { + return upgradeResult.fallbackResponse(); } - final String accept = generateSecWebSocketAccept(webSocketKey); - final ResponseHeadersBuilder responseHeadersBuilder = - ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS) - .add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString()) - .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString()) - .add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); - maybeAddSubprotocol(headers, responseHeadersBuilder); - return handleUpgradeRequest(ctx, req, responseHeadersBuilder.build()); - } - - private void maybeAddSubprotocol(RequestHeaders headers, - ResponseHeadersBuilder responseHeadersBuilder) { - final String subprotocols = headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, ""); - if (subprotocols.isEmpty()) { - return; - } - commaSplitter.splitToStream(subprotocols) - .filter(sub -> SUB_PROTOCOL_WILDCARD.equals(sub) || - this.subprotocols.contains(sub)) - .findFirst().ifPresent(selectedSubprotocol -> responseHeadersBuilder.add( - HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol)); - } - - private HttpResponse handleUpgradeRequest(ServiceRequestContext ctx, HttpRequest req, - ResponseHeaders responseHeaders) { - final WebSocketServiceFrameDecoder decoder = - new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); - final StreamMessage inboundFrames = req.decode(decoder, ctx.alloc()); - final WebSocket outboundFrames = handler.handle(ctx, new WebSocketWrapper(inboundFrames)); - decoder.setOutboundWebSocket(outboundFrames); - return HttpResponse.of( - responseHeaders, outboundFrames.recoverAndResume(cause -> { - if (cause instanceof ClosedStreamException) { - return StreamMessage.aborted(cause); - } - ctx.logBuilder().responseCause(cause); - return StreamMessage.of(newCloseWebSocketFrame(cause)); - }) - .map(frame -> HttpData.wrap(encoder.encode(ctx, frame)))); + final WebSocket in = protocolHandler().decode(ctx, req); + final WebSocket out = serve(ctx, in); + return protocolHandler().encode(ctx, out); } /** - * Handles the HTTP/2 web socket handshake described in - * Bootstrapping WebSockets with HTTP/2. - * These are examples of a request and the corresponding response: - * - *

Request: - *

-     * HEADERS + END_HEADERS
-     * :method = CONNECT
-     * :protocol = websocket
-     * :scheme = https
-     * :path = /chat
-     * :authority = server.example.com
-     * sec-websocket-protocol = chat, superchat
-     * sec-websocket-extensions = permessage-deflate
-     * sec-websocket-version = 13
-     * origin = http://www.example.com
-     * 
- * - *

Response: - *

-     * HEADERS + END_HEADERS
-     * :status = 200
-     * sec-websocket-protocol = chat
-     * 
+ * Serves the specified {@link WebSocket} and returns the {@link WebSocket} to send responses. */ - @Override - protected HttpResponse doConnect(ServiceRequestContext ctx, HttpRequest req) throws Exception { - if (!ctx.sessionProtocol().isExplicitHttp2()) { - return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED); - } - final RequestHeaders headers = req.headers(); - if (!isHttp2WebSocketUpgradeRequest(headers)) { - logger.trace("RequestHeaders does not contain headers for WebSocket upgrade. headers: {}", headers); - return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, - "The upgrade header must contain:\n" + - " :protocol = websocket"); - } - - HttpResponse invalidResponse = checkOrigin(ctx, headers); - if (invalidResponse != null) { - return invalidResponse; - } - - invalidResponse = checkVersion(headers); - if (invalidResponse != null) { - return invalidResponse; - } - - // As described in https://datatracker.ietf.org/doc/html/rfc8441#section-5, - // HTTP/2 does not use Sec-WebSocket-Key and Sec-WebSocket-Accept headers. - - final ResponseHeadersBuilder responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK); - maybeAddSubprotocol(headers, responseHeadersBuilder); - return handleUpgradeRequest(ctx, req, responseHeadersBuilder.build()); - } - - @Nullable - private HttpResponse checkOrigin(ServiceRequestContext ctx, RequestHeaders headers) { - if (allowAnyOrigin) { - return null; - } - final String origin = headers.get(HttpHeaderNames.ORIGIN, ""); - if (origin.isEmpty()) { - return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, - "missing the origin header"); - } - - if (allowedOrigins.isEmpty()) { - // Only the same-origin is allowed. - if (!isSameOrigin(ctx, headers, origin)) { - return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, - "not allowed origin: " + origin); - } - return null; - } - if (!allowedOrigins.contains(origin)) { - return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, - "not allowed origin: " + origin + ", allowed: " + allowedOrigins); - } - return null; - } + WebSocket serve(ServiceRequestContext ctx, WebSocket in) throws Exception; - private static boolean isSameOrigin(ServiceRequestContext ctx, RequestHeaders headers, String origin) { - final int schemeDelimiter = origin.indexOf("://"); - if (schemeDelimiter < 0) { - return false; - } - - final String scheme = origin.substring(0, schemeDelimiter); - final SessionProtocol originSessionProtocol = SessionProtocol.find(scheme); - if (originSessionProtocol == null) { - return false; - } - - if ((ctx.sessionProtocol().isHttp() && originSessionProtocol.isHttp()) || - (ctx.sessionProtocol().isHttps() && originSessionProtocol.isHttps())) { - // The same scheme. - } else { - return false; - } - - final String authority = headers.authority(); - assert authority != null; - final HostAndPort authorityHostAndPort = HostAndPort.fromString(authority); - final String authorityHost = authorityHostAndPort.getHost(); - final int authorityPort = authorityHostAndPort.getPortOrDefault(ctx.sessionProtocol().defaultPort()); - - final HostAndPort originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3)); - final String originHost = originHostAndPort.getHost(); - final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort()); - - return authorityPort == originPort && authorityHost.equals(originHost); - } - - @Nullable - private static HttpResponse checkVersion(RequestHeaders headers) { - // Currently we only support v13. - final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION); - if (!WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) { - return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION, - HttpData.ofUtf8("Only 13 version is supported.")); - } - return null; - } + /** + * Returns the {@link WebSocketProtocolHandler} of this service. + */ + WebSocketProtocolHandler protocolHandler(); } diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java index 4003a934257..8a22f06d2f8 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java @@ -26,9 +26,11 @@ import com.google.common.collect.ImmutableSet; +import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; import com.linecorp.armeria.internal.common.websocket.WebSocketUtil; +import com.linecorp.armeria.internal.server.websocket.DefaultWebSocketService; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.ServiceConfig; @@ -59,6 +61,8 @@ public final class WebSocketServiceBuilder { private boolean allowMaskMismatch; private Set subprotocols = ImmutableSet.of(); private Set allowedOrigins = ImmutableSet.of(); + @Nullable + private HttpService fallbackService; WebSocketServiceBuilder(WebSocketServiceHandler handler) { this.handler = requireNonNull(handler, "handler"); @@ -142,11 +146,20 @@ private static Set validateOrigins(Iterable allowedOrigins) { return copied; } + /** + * Sets the fallback {@link HttpService} to use when the request is not a valid WebSocket upgrade request. + * This is useful when you want to serve both WebSocket and HTTP requests at the same path. + */ + public WebSocketServiceBuilder fallbackService(HttpService fallbackService) { + this.fallbackService = requireNonNull(fallbackService, "fallbackService"); + return this; + } + /** * Returns a newly-created {@link WebSocketService} with the properties set so far. */ public WebSocketService build() { - return new WebSocketService(handler, maxFramePayloadLength, allowMaskMismatch, - subprotocols, allowedOrigins, allowedOrigins.contains(ANY_ORIGIN)); + return new DefaultWebSocketService(handler, fallbackService, maxFramePayloadLength, allowMaskMismatch, + subprotocols, allowedOrigins, allowedOrigins.contains(ANY_ORIGIN)); } } diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketUpgradeResult.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketUpgradeResult.java new file mode 100644 index 00000000000..9a7172ae7f8 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketUpgradeResult.java @@ -0,0 +1,84 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.websocket; + +import static java.util.Objects.requireNonNull; + +import com.google.common.base.MoreObjects; + +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.annotation.Nullable; + +/** + * The result of a WebSocket upgrade. + */ +public final class WebSocketUpgradeResult { + + private static final WebSocketUpgradeResult SUCCESS = new WebSocketUpgradeResult(null); + + /** + * Returns a successful {@link WebSocketUpgradeResult}. + */ + public static WebSocketUpgradeResult ofSuccess() { + return SUCCESS; + } + + /** + * Returns a failed {@link WebSocketUpgradeResult} with the fallback {@link HttpResponse}. + */ + public static WebSocketUpgradeResult ofFailure(HttpResponse fallbackResponse) { + requireNonNull(fallbackResponse, "failureResponse"); + return new WebSocketUpgradeResult(fallbackResponse); + } + + @Nullable + private final HttpResponse fallbackResponse; + + private WebSocketUpgradeResult(@Nullable HttpResponse fallbackResponse) { + this.fallbackResponse = fallbackResponse; + } + + /** + * Returns {@code true} if the upgrade was successful. + */ + public boolean isSuccess() { + return fallbackResponse == null; + } + + /** + * Returns the fallback {@link HttpResponse} if the upgrade failed. + * + * @throws IllegalStateException if the upgrade was successful. + */ + public HttpResponse fallbackResponse() { + if (fallbackResponse == null) { + throw new IllegalStateException("WebSocket was successfully upgraded."); + } + return fallbackResponse; + } + + @Override + public String toString() { + if (isSuccess()) { + return "WebSocketUpgradeResult(success)"; + } + + return MoreObjects.toStringHelper(this) + .add("fallback", fallbackResponse) + .toString(); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java index 409e0ebe583..b71154c8121 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java @@ -52,7 +52,6 @@ import com.linecorp.armeria.common.HttpRequestWriter; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpResponseWriter; -import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; import com.linecorp.armeria.common.websocket.WebSocketFrame; @@ -115,7 +114,7 @@ public void testWebSocketProtocolViolation() throws InterruptedException { final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(true); final HttpRequestWriter requestWriter = HttpRequest.streaming(RequestHeaders.of(HttpMethod.GET, "/")); final WebSocketFrameDecoder decoder = - new TestWebSocketFrameDecoder(ctx, maxPayloadLength, false, true); + new TestWebSocketFrameDecoder(maxPayloadLength, false, true); final CompletableFuture whenComplete = new CompletableFuture<>(); requestWriter.decode(decoder, ctx.alloc()).subscribe(subscriber(whenComplete)); @@ -142,7 +141,7 @@ public void testWebSocketEncodingAndDecoding(boolean maskPayload, boolean allowM final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(maskPayload); final HttpRequestWriter requestWriter = HttpRequest.streaming(RequestHeaders.of(HttpMethod.GET, "/")); final WebSocketFrameDecoder decoder = new TestWebSocketFrameDecoder( - ctx, 1024 * 1024, allowMaskMismatch, maskPayload); + 1024 * 1024, allowMaskMismatch, maskPayload); requestWriter.decode(decoder, ctx.alloc()).subscribe(subscriber(new CompletableFuture<>())); executeTests(encoder, requestWriter); httpResponseWriter.abort(); @@ -235,9 +234,9 @@ private static class TestWebSocketFrameDecoder extends WebSocketFrameDecoder { private final boolean expectMaskedFrames; - TestWebSocketFrameDecoder(RequestContext ctx, int maxFramePayloadLength, - boolean allowMaskMismatch, boolean expectMaskedFrames) { - super(ctx, maxFramePayloadLength, allowMaskMismatch); + TestWebSocketFrameDecoder(int maxFramePayloadLength, boolean allowMaskMismatch, + boolean expectMaskedFrames) { + super(maxFramePayloadLength, allowMaskMismatch); this.expectMaskedFrames = expectMaskedFrames; } diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java new file mode 100644 index 00000000000..0628138c968 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/DelegatingWebSocketServiceTest.java @@ -0,0 +1,149 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.websocket; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.client.BlockingWebClient; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketSession; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.Flags; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketFrameType; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceConfig; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +class DelegatingWebSocketServiceTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final WebSocketService delegate = + WebSocketService + .builder(new EchoWebSocketHandler()) + .fallbackService((ctx, req) -> HttpResponse.of("fallback")) + .build(); + + sb.service("/ws-or-http", new DelegatingWebSocketService(delegate)); + } + }; + + @Test + void shouldReturnMessageInUpperCase() { + final WebSocketClient client = WebSocketClient.of(server.httpUri()); + final WebSocketSession session = client.connect("/ws-or-http").join(); + final WebSocketWriter outbound = session.outbound(); + outbound.write("hello"); + outbound.write("world"); + outbound.close(); + final List responses = session.inbound().collect().join().stream().map(WebSocketFrame::text) + .collect(toImmutableList()); + assertThat(responses).contains("HELLO", "WORLD"); + } + + @Test + void shouldReturnFallbackResponse() { + final BlockingWebClient client = server.blockingWebClient(); + AggregatedHttpResponse response = client.get("/ws-or-http"); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("fallback"); + response = client.post("/ws-or-http", ""); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo("fallback"); + } + + @Test + void shouldNotSetDefaultSettings() { + final ServiceConfig serviceConfig = server.server().serviceConfigs().get(0); + assertThat(serviceConfig.service().as(DelegatingWebSocketService.class)).isNotNull(); + // The default settings for `WebSocketService` should be applied only to `DefaultWebSocketService`. + assertThat(serviceConfig.requestTimeoutMillis()).isEqualTo(Flags.defaultRequestTimeoutMillis()); + } + + private static class EchoWebSocketHandler implements WebSocketServiceHandler { + + @Override + public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { + final WebSocketWriter writer = WebSocket.streaming(); + in.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + writer.write(webSocketFrame); + } + + @Override + public void onError(Throwable t) { + writer.close(t); + } + + @Override + public void onComplete() { + writer.close(); + } + }); + return writer; + } + } + + private static class DelegatingWebSocketService implements WebSocketService { + + private final WebSocketService delegate; + + DelegatingWebSocketService(WebSocketService delegate) { + this.delegate = delegate; + } + + @Override + public WebSocket serve(ServiceRequestContext ctx, WebSocket in) throws Exception { + final StreamMessage transformed = in.map(frame -> { + if (frame.type() == WebSocketFrameType.TEXT) { + final String text = frame.text(); + return WebSocketFrame.ofText(text.toUpperCase()); + } + return frame; + }); + return delegate.serve(ctx, WebSocket.of(transformed)); + } + + @Override + public WebSocketProtocolHandler protocolHandler() { + return delegate.protocolHandler(); + } + } +}