Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure MCP connection is closed when HTTP connection is closed #15

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.quarkiverse.mcp.server.test.close;

import static io.restassured.RestAssured.given;

import java.net.URI;
import java.net.URISyntaxException;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.test.McpClient;
import io.quarkiverse.mcp.server.test.McpServerTest;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.http.ContentType;
import io.vertx.core.json.JsonObject;

public class CloseTest extends McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class));

@Test
public void testPing() throws URISyntaxException {
URI endpoint = initClient();

JsonObject closeMessage = newMessage("q/close");

given()
.contentType(ContentType.JSON)
.when()
.body(closeMessage.encode())
.post(endpoint)
.then()
.statusCode(200);

given()
.contentType(ContentType.JSON)
.when()
.body(closeMessage.encode())
.post(endpoint)
.then()
.statusCode(400);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ private void initializing(JsonObject message, RoutingContext ctx, McpConnectionI
private static final String TOOLS_LIST = "tools/list";
private static final String TOOLS_CALL = "tools/call";
private static final String PING = "ping";
// non-standard messages
private static final String Q_CLOSE = "q/close";

private void operation(JsonObject message, RoutingContext ctx, McpConnectionImpl connection) {
String method = message.getString("method");
Expand All @@ -132,6 +134,7 @@ private void operation(JsonObject message, RoutingContext ctx, McpConnectionImpl
case TOOLS_LIST -> toolHandler.toolsList(message, ctx);
case TOOLS_CALL -> toolHandler.toolsCall(message, ctx, connection);
case PING -> ping(message, ctx);
case Q_CLOSE -> close(ctx, connection);
default -> throw new IllegalArgumentException("Unsupported method: " + method);
}
}
Expand All @@ -144,6 +147,16 @@ private void ping(JsonObject message, RoutingContext ctx) {
ctx.end(newResult(id, new JsonObject()).encode());
}

private void close(RoutingContext ctx, McpConnectionImpl connection) {
if (connectionManager.remove(connection.id())) {
LOG.infof("Connection %s closed", connection.id());
ctx.end();
} else {
LOG.errorf("Unable to close connection %s", connection.id());
ctx.fail(400);
}
}

private InitializeRequest decodeInitializeRequest(JsonObject params) {
JsonObject clientInfo = params.getJsonObject("clientInfo");
Implementation implementation = new Implementation(clientInfo.getString("name"), clientInfo.getString("version"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkiverse.mcp.server.runtime;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Base64;
import java.util.UUID;
import java.util.function.Consumer;
Expand All @@ -10,8 +12,11 @@
import io.quarkus.arc.ArcContainer;
import io.quarkus.runtime.annotations.Recorder;
import io.vertx.core.Handler;
import io.vertx.core.http.HttpConnection;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.net.impl.ConnectionBase;
import io.vertx.ext.web.Route;
import io.vertx.ext.web.RoutingContext;

Expand Down Expand Up @@ -53,6 +58,8 @@ public void handle(RoutingContext ctx) {

McpConnectionImpl connection = new McpConnectionImpl(id, response);
connectionManager.add(connection);
// TODO we cannot override the close handler set/used by Quarkus HTTP
setCloseHandler(ctx.request(), id, connectionManager);

// /mcp/messages?id=generatedId
String endpointPath = mcpPath + "/messages/" + id;
Expand All @@ -63,6 +70,33 @@ public void handle(RoutingContext ctx) {
};
}

private void setCloseHandler(HttpServerRequest request, String connectionId, ConnectionManager connectionManager) {
HttpConnection connection = request.connection();
if (connection instanceof ConnectionBase base) {
try {
MethodHandles.Lookup lookup = MethodHandles.privateLookupIn(ConnectionBase.class, MethodHandles.lookup());
VarHandle varHandle = lookup.findVarHandle(ConnectionBase.class, "closeHandler", Handler.class);
Handler<Void> closeHandler = (Handler<Void>) varHandle.get(base);
base.closeHandler(new Handler<Void>() {
@Override
public void handle(Void event) {
if (closeHandler != null) {
closeHandler.handle(event);
}
if (connectionManager.remove(connectionId)) {
LOG.infof("Connection %s closed", connectionId);
}
// Connection may have been removed earlier...
}
});
} catch (Exception e) {
LOG.warnf(e, "Unable to set close handler - client should close the connection [%s] explicitly", connectionId);
}
} else {
LOG.warnf("Unable to set close handler - client should close the connection [%s] explicitly", connectionId);
}
}

public Consumer<Route> addBodyHandler(Handler<RoutingContext> bodyHandler) {
return new Consumer<Route>() {

Expand Down
Loading