diff --git a/deployment/src/test/java/io/quarkiverse/mcp/server/test/close/CloseTest.java b/deployment/src/test/java/io/quarkiverse/mcp/server/test/close/CloseTest.java new file mode 100644 index 0000000..30b2ba3 --- /dev/null +++ b/deployment/src/test/java/io/quarkiverse/mcp/server/test/close/CloseTest.java @@ -0,0 +1,45 @@ +package io.quarkiverse.mcp.server.test.close; + +import static io.restassured.RestAssured.given; + +import java.net.URI; +import java.net.URISyntaxException; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkiverse.mcp.server.test.McpClient; +import io.quarkiverse.mcp.server.test.McpServerTest; +import io.quarkus.test.QuarkusUnitTest; +import io.restassured.http.ContentType; +import io.vertx.core.json.JsonObject; + +public class CloseTest extends McpServerTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot(root -> root.addClasses(McpClient.class)); + + @Test + public void testPing() throws URISyntaxException { + URI endpoint = initClient(); + + JsonObject closeMessage = newMessage("q/close"); + + given() + .contentType(ContentType.JSON) + .when() + .body(closeMessage.encode()) + .post(endpoint) + .then() + .statusCode(200); + + given() + .contentType(ContentType.JSON) + .when() + .body(closeMessage.encode()) + .post(endpoint) + .then() + .statusCode(400); + } +} diff --git a/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessagesHandler.java b/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessagesHandler.java index 84e79c1..16192c2 100644 --- a/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessagesHandler.java +++ b/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessagesHandler.java @@ -123,6 +123,8 @@ private void initializing(JsonObject message, RoutingContext ctx, McpConnectionI private static final String TOOLS_LIST = "tools/list"; private static final String TOOLS_CALL = "tools/call"; private static final String PING = "ping"; + // non-standard messages + private static final String Q_CLOSE = "q/close"; private void operation(JsonObject message, RoutingContext ctx, McpConnectionImpl connection) { String method = message.getString("method"); @@ -132,6 +134,7 @@ private void operation(JsonObject message, RoutingContext ctx, McpConnectionImpl case TOOLS_LIST -> toolHandler.toolsList(message, ctx); case TOOLS_CALL -> toolHandler.toolsCall(message, ctx, connection); case PING -> ping(message, ctx); + case Q_CLOSE -> close(ctx, connection); default -> throw new IllegalArgumentException("Unsupported method: " + method); } } @@ -144,6 +147,16 @@ private void ping(JsonObject message, RoutingContext ctx) { ctx.end(newResult(id, new JsonObject()).encode()); } + private void close(RoutingContext ctx, McpConnectionImpl connection) { + if (connectionManager.remove(connection.id())) { + LOG.infof("Connection %s closed", connection.id()); + ctx.end(); + } else { + LOG.errorf("Unable to close connection %s", connection.id()); + ctx.fail(400); + } + } + private InitializeRequest decodeInitializeRequest(JsonObject params) { JsonObject clientInfo = params.getJsonObject("clientInfo"); Implementation implementation = new Implementation(clientInfo.getString("name"), clientInfo.getString("version")); diff --git a/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpServerRecorder.java b/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpServerRecorder.java index 134af4c..a6355ac 100644 --- a/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpServerRecorder.java +++ b/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpServerRecorder.java @@ -1,5 +1,7 @@ package io.quarkiverse.mcp.server.runtime; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; import java.util.Base64; import java.util.UUID; import java.util.function.Consumer; @@ -10,8 +12,11 @@ import io.quarkus.arc.ArcContainer; import io.quarkus.runtime.annotations.Recorder; import io.vertx.core.Handler; +import io.vertx.core.http.HttpConnection; import io.vertx.core.http.HttpHeaders; +import io.vertx.core.http.HttpServerRequest; import io.vertx.core.http.HttpServerResponse; +import io.vertx.core.net.impl.ConnectionBase; import io.vertx.ext.web.Route; import io.vertx.ext.web.RoutingContext; @@ -53,6 +58,8 @@ public void handle(RoutingContext ctx) { McpConnectionImpl connection = new McpConnectionImpl(id, response); connectionManager.add(connection); + // TODO we cannot override the close handler set/used by Quarkus HTTP + setCloseHandler(ctx.request(), id, connectionManager); // /mcp/messages?id=generatedId String endpointPath = mcpPath + "/messages/" + id; @@ -63,6 +70,33 @@ public void handle(RoutingContext ctx) { }; } + private void setCloseHandler(HttpServerRequest request, String connectionId, ConnectionManager connectionManager) { + HttpConnection connection = request.connection(); + if (connection instanceof ConnectionBase base) { + try { + MethodHandles.Lookup lookup = MethodHandles.privateLookupIn(ConnectionBase.class, MethodHandles.lookup()); + VarHandle varHandle = lookup.findVarHandle(ConnectionBase.class, "closeHandler", Handler.class); + Handler closeHandler = (Handler) varHandle.get(base); + base.closeHandler(new Handler() { + @Override + public void handle(Void event) { + if (closeHandler != null) { + closeHandler.handle(event); + } + if (connectionManager.remove(connectionId)) { + LOG.infof("Connection %s closed", connectionId); + } + // Connection may have been removed earlier... + } + }); + } catch (Exception e) { + LOG.warnf(e, "Unable to set close handler - client should close the connection [%s] explicitly", connectionId); + } + } else { + LOG.warnf("Unable to set close handler - client should close the connection [%s] explicitly", connectionId); + } + } + public Consumer addBodyHandler(Handler bodyHandler) { return new Consumer() {