From cc2a770ae1b71a1d961203e7c4093de53bc8edeb Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Wed, 8 Jan 2025 11:18:11 +0100 Subject: [PATCH] SSE transport: always send server messages as SSE events - also do not require the "text/event-stream" header from the client - introduce SseClient test util class, get rid of quarkus-rest-client in tests - fixes #32 - fixes #31 --- .../mcp/server/runtime/McpMessageHandler.java | 2 - pom.xml | 1 + test-utils/pom.xml | 25 +++ .../mcp/server/test/SseClient.java | 145 ++++++++++++++++++ transports/sse/deployment/pom.xml | 9 +- .../mcp/server/test/McpClient.java | 26 ---- .../mcp/server/test/McpServerTest.java | 57 ++++--- .../mcp/server/test/close/CloseTest.java | 8 +- .../mcp/server/test/ping/PingTest.java | 11 +- .../test/prompts/InvalidPromptNameTest.java | 13 +- .../prompts/MissingPromptArgumentTest.java | 16 +- .../test/prompts/PromptInternalErrorTest.java | 13 +- .../mcp/server/test/prompts/PromptsTest.java | 19 ++- .../resources/InvalidResourceUriTest.java | 13 +- .../resources/ResourceInternalErrorTest.java | 13 +- .../server/test/resources/ResourcesTest.java | 19 ++- .../test/serverinfo/CustomServerInfoTest.java | 3 +- .../serverinfo/DefaultServerInfoTest.java | 3 +- .../test/tools/InvalidToolNameTest.java | 13 +- .../test/tools/MissingToolArgumentTest.java | 13 +- .../test/tools/ToolInternalErrorTest.java | 13 +- .../mcp/server/test/tools/ToolsTest.java | 18 +-- transports/sse/integration-tests/pom.xml | 6 + .../mcp/server/sse/it/McpClient.java | 26 ---- .../mcp/server/sse/it/McpClientInit.java | 42 ----- .../src/main/resources/application.properties | 1 + .../mcp/server/sse/it/ServerFeaturesTest.java | 96 +++++++----- .../server/sse/runtime/SseMcpConnection.java | 2 +- .../sse/runtime/SseMcpMessageHandler.java | 38 ++--- .../sse/runtime/SseMcpServerRecorder.java | 7 - 30 files changed, 383 insertions(+), 288 deletions(-) create mode 100644 test-utils/pom.xml create mode 100644 test-utils/src/main/java/io/quarkiverse/mcp/server/test/SseClient.java delete mode 100644 transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpClient.java delete mode 100644 transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClient.java delete mode 100644 transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClientInit.java diff --git a/core/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessageHandler.java b/core/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessageHandler.java index 5ad11cf..8c32d87 100644 --- a/core/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessageHandler.java +++ b/core/runtime/src/main/java/io/quarkiverse/mcp/server/runtime/McpMessageHandler.java @@ -79,7 +79,6 @@ private void initializing(JsonObject message, Responder responder, McpConnection String method = message.getString("method"); if (NOTIFICATIONS_INITIALIZED.equals(method)) { if (connection.setInitialized()) { - responder.send(null); LOG.infof("Client successfully initialized [%s]", connection.id()); } } else if (PING.equals(method)) { @@ -128,7 +127,6 @@ private void ping(JsonObject message, Responder responder) { private void close(JsonObject message, Responder responder, McpConnection connection) { if (connectionManager.remove(connection.id())) { LOG.infof("Connection %s closed", connection.id()); - responder.send(null); } else { responder.sendError(message.getValue("id"), JsonRPC.INTERNAL_ERROR, "Unable to obtain the connection to be closed:" + connection.id()); diff --git a/pom.xml b/pom.xml index f94c385..c15caf0 100644 --- a/pom.xml +++ b/pom.xml @@ -17,6 +17,7 @@ core transports/sse transports/stdio + test-utils diff --git a/test-utils/pom.xml b/test-utils/pom.xml new file mode 100644 index 0000000..28f2bb3 --- /dev/null +++ b/test-utils/pom.xml @@ -0,0 +1,25 @@ + + + 4.0.0 + + + io.quarkiverse.mcp + quarkus-mcp-server-parent + 999-SNAPSHOT + ../pom.xml + + quarkus-mcp-server-test-utils + Quarkus MCP Server Test Utils + + + + org.jboss.logging + jboss-logging + + + org.awaitility + awaitility + + + + diff --git a/test-utils/src/main/java/io/quarkiverse/mcp/server/test/SseClient.java b/test-utils/src/main/java/io/quarkiverse/mcp/server/test/SseClient.java new file mode 100644 index 0000000..1f6ef04 --- /dev/null +++ b/test-utils/src/main/java/io/quarkiverse/mcp/server/test/SseClient.java @@ -0,0 +1,145 @@ +package io.quarkiverse.mcp.server.test; + +import java.io.EOFException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicInteger; + +import org.awaitility.Awaitility; +import org.jboss.logging.Logger; + +public class SseClient { + + private static final Logger LOG = Logger.getLogger(SseClient.class); + + private final URI testUri; + + private final AtomicInteger idGenerator; + + public final List events; + + public SseClient(URI uri) { + this.testUri = uri; + this.idGenerator = new AtomicInteger(); + this.events = new CopyOnWriteArrayList<>(); + } + + public int nextId() { + return idGenerator.incrementAndGet(); + } + + public SseEvent waitForFirstEvent() { + nextId(); + Awaitility.await().until(() -> !events.isEmpty()); + return events.get(0); + } + + public SseEvent waitForLastEvent() { + int lastId = idGenerator.get(); + Awaitility.await().until(() -> events.size() >= lastId); + return events.get(lastId - 1); + } + + public void connect() { + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = HttpRequest.newBuilder() + .uri(testUri) + .version(Version.HTTP_1_1) + .header("Accept", "text/event-stream") + .GET() + .build(); + + client.sendAsync(request, BodyHandlers.fromLineSubscriber(new SseEventSubscriber())) + .thenAccept(response -> { + if (response.statusCode() == 200) { + LOG.infof("Connected to SSE stream: %s", testUri); + } else { + LOG.errorf("Failed to connect %s: %s", response.statusCode(), testUri); + } + }) + .exceptionally(e -> { + Throwable root = getRootCause(e); + if (!(root instanceof EOFException)) { + LOG.error(e); + } + return null; + }); + + } + + class SseEventSubscriber implements Flow.Subscriber { + + private Flow.Subscription subscription; + + private String event = "message"; + private StringBuilder dataBuffer = new StringBuilder(); + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String line) { + LOG.debugf("Received next line:\n%s", line); + if (line.startsWith(":")) { + // Skip comments + } else if (line.isBlank()) { + // Flush + events.add(new SseEvent(event, dataBuffer.toString())); + event = "message"; + dataBuffer = new StringBuilder(); + } else if (line.contains(":")) { + int colon = line.indexOf(":"); + String field = line.substring(0, colon).strip(); + String value = line.substring(colon + 1).strip(); + handleField(field, value); + } else { + // The whole line is the field name + handleField(line, ""); + } + subscription.request(1); + } + + @Override + public void onError(Throwable t) { + Throwable root = getRootCause(t); + if (!(root instanceof EOFException)) { + LOG.errorf(t, "Error in SSE stream"); + } + } + + @Override + public void onComplete() { + LOG.debug("SSE stream complete"); + } + + private void handleField(String field, String value) { + switch (field) { + case "event" -> event = value; + case "data" -> dataBuffer.append(value).append("\n"); + } + } + } + + public record SseEvent(String name, String data) { + } + + private static Throwable getRootCause(Throwable exception) { + final List chain = new ArrayList<>(); + Throwable curr = exception; + while (curr != null && !chain.contains(curr)) { + chain.add(curr); + curr = curr.getCause(); + } + return chain.isEmpty() ? null : chain.get(chain.size() - 1); + } +} diff --git a/transports/sse/deployment/pom.xml b/transports/sse/deployment/pom.xml index 80c2002..735d818 100644 --- a/transports/sse/deployment/pom.xml +++ b/transports/sse/deployment/pom.xml @@ -32,13 +32,14 @@ test - io.quarkus - quarkus-rest-client + io.rest-assured + rest-assured test - io.rest-assured - rest-assured + io.quarkiverse.mcp + quarkus-mcp-server-test-utils + ${project.version} test diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpClient.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpClient.java deleted file mode 100644 index cf550f4..0000000 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpClient.java +++ /dev/null @@ -1,26 +0,0 @@ -package io.quarkiverse.mcp.server.test; - -import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; -import static jakarta.ws.rs.core.MediaType.SERVER_SENT_EVENTS; - -import jakarta.ws.rs.GET; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.Produces; - -import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; -import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; -import org.jboss.resteasy.reactive.client.SseEvent; - -import io.smallrye.mutiny.Multi; - -@Path("mcp") -@RegisterRestClient -public interface McpClient { - - @GET - @Path("sse") - @ClientHeaderParam(name = CONTENT_TYPE, value = SERVER_SENT_EVENTS) - @Produces(SERVER_SENT_EVENTS) - Multi> init(); - -} diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpServerTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpServerTest.java index 188f6a0..a443cb1 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpServerTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/McpServerTest.java @@ -6,16 +6,11 @@ import java.net.URI; import java.net.URISyntaxException; -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; -import org.awaitility.Awaitility; import org.jboss.logging.Logger; -import org.jboss.resteasy.reactive.client.SseEvent; +import org.junit.jupiter.api.AfterEach; -import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; import io.quarkus.test.QuarkusUnitTest; import io.quarkus.test.common.http.TestHTTPResource; import io.restassured.http.ContentType; @@ -28,7 +23,11 @@ public abstract class McpServerTest { @TestHTTPResource URI testUri; + SseClient sseClient; + public static QuarkusUnitTest defaultConfig() { + // TODO in theory, we should also add SseClient to all test archives + // but the test CL can see the class and we don't need Quarkus to analyze this util class QuarkusUnitTest config = new QuarkusUnitTest(); if (System.getProperty("logTraffic") != null) { config.overrideConfigKey("quarkus.mcp.server.sse.traffic-logging.enabled", "true"); @@ -36,24 +35,34 @@ public static QuarkusUnitTest defaultConfig() { return config; } - protected List> sseMessages; - - AtomicInteger idGenerator = new AtomicInteger(); + @AfterEach + void cleanup() { + sseClient = null; + } protected URI initClient() throws URISyntaxException { return initClient(null); } + protected JsonObject waitForLastJsonMessage() { + SseClient.SseEvent event = sseClient.waitForLastEvent(); + if ("message".equals(event.name())) { + return new JsonObject(event.data()); + } + throw new IllegalStateException("Message event not received: " + event); + } + protected URI initClient(Consumer initResultAssert) throws URISyntaxException { - McpClient mcpClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(testUri) - .build(McpClient.class); - - sseMessages = new CopyOnWriteArrayList<>(); - mcpClient.init().subscribe().with(s -> sseMessages.add(s), e -> { - }); - Awaitility.await().until(() -> !sseMessages.isEmpty()); - URI endpoint = new URI(sseMessages.get(0).data()); + String testUriStr = testUri.toString(); + if (testUriStr.endsWith("/")) { + testUriStr = testUriStr.substring(0, testUriStr.length() - 1); + } + sseClient = new SseClient(URI.create(testUriStr + "/mcp/sse")); + sseClient.connect(); + var event = sseClient.waitForFirstEvent(); + String messagesUri = testUriStr + event.data().strip(); + URI endpoint = URI.create(messagesUri); + LOG.infof("Client received endpoint: %s", endpoint); JsonObject initMessage = newMessage("initialize") @@ -64,14 +73,15 @@ protected URI initClient(Consumer initResultAssert) throws URISyntax .put("version", "1.0")) .put("protocolVersion", "2024-11-05")); - JsonObject initResponse = new JsonObject(given() + given() .contentType(ContentType.JSON) .when() .body(initMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject initResponse = waitForLastJsonMessage(); JsonObject initResult = assertResponseMessage(initMessage, initResponse); assertNotNull(initResult); @@ -102,10 +112,13 @@ protected JsonObject assertResponseMessage(JsonObject message, JsonObject respon } protected JsonObject newMessage(String method) { + if (sseClient == null) { + throw new IllegalStateException("SSE client not initialized"); + } return new JsonObject() .put("jsonrpc", "2.0") .put("method", method) - .put("id", idGenerator.incrementAndGet()); + .put("id", sseClient.nextId()); } } diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/close/CloseTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/close/CloseTest.java index 2a2079b..c19326e 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/close/CloseTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/close/CloseTest.java @@ -1,7 +1,6 @@ package io.quarkiverse.mcp.server.test.close; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; import java.net.URI; import java.net.URISyntaxException; @@ -9,8 +8,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import io.quarkiverse.mcp.server.runtime.JsonRPC; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.restassured.http.ContentType; @@ -20,7 +17,7 @@ public class CloseTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() - .withApplicationRoot(root -> root.addClasses(McpClient.class)); + .withEmptyApplication(); @Test public void testCloseMessage() throws URISyntaxException { @@ -42,7 +39,6 @@ public void testCloseMessage() throws URISyntaxException { .body(closeMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INTERNAL_ERROR)); + .statusCode(400); } } diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/ping/PingTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/ping/PingTest.java index 93dfaca..33ca35f 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/ping/PingTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/ping/PingTest.java @@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.restassured.http.ContentType; @@ -20,7 +19,7 @@ public class PingTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() - .withApplicationRoot(root -> root.addClasses(McpClient.class)); + .withEmptyApplication(); @Test public void testPing() throws URISyntaxException { @@ -28,14 +27,14 @@ public void testPing() throws URISyntaxException { JsonObject pingMessage = newMessage("ping"); - JsonObject pingResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(pingMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject pingResponse = waitForLastJsonMessage(); JsonObject pingResult = assertResponseMessage(pingMessage, pingResponse); assertNotNull(pingResult); diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/InvalidPromptNameTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/InvalidPromptNameTest.java index 745d0f2..b4a7ee2 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/InvalidPromptNameTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/InvalidPromptNameTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.prompts; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -12,7 +12,6 @@ import io.quarkiverse.mcp.server.runtime.JsonRPC; import io.quarkiverse.mcp.server.test.Checks; import io.quarkiverse.mcp.server.test.FooService; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkiverse.mcp.server.test.Options; import io.quarkus.test.QuarkusUnitTest; @@ -24,7 +23,7 @@ public class InvalidPromptNameTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, FooService.class, Options.class, Checks.class, MyPrompts.class)); + root -> root.addClasses(FooService.class, Options.class, Checks.class, MyPrompts.class)); @Test public void testError() throws URISyntaxException { @@ -41,9 +40,11 @@ public void testError() throws URISyntaxException { .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INVALID_PARAMS), "error.message", - equalTo("Invalid prompt name: nonexistent")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.INVALID_PARAMS, response.getJsonObject("error").getInteger("code")); + assertEquals("Invalid prompt name: nonexistent", response.getJsonObject("error").getString("message")); } } diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/MissingPromptArgumentTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/MissingPromptArgumentTest.java index c216a67..acab957 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/MissingPromptArgumentTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/MissingPromptArgumentTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.prompts; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -12,7 +12,6 @@ import io.quarkiverse.mcp.server.runtime.JsonRPC; import io.quarkiverse.mcp.server.test.Checks; import io.quarkiverse.mcp.server.test.FooService; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkiverse.mcp.server.test.Options; import io.quarkus.test.QuarkusUnitTest; @@ -24,7 +23,7 @@ public class MissingPromptArgumentTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, FooService.class, Options.class, Checks.class, MyPrompts.class)); + root -> root.addClasses(FooService.class, Options.class, Checks.class, MyPrompts.class)); @Test public void testError() throws URISyntaxException { @@ -35,15 +34,16 @@ public void testError() throws URISyntaxException { .put("name", "uni_bar") .put("arguments", new JsonObject())); - given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INVALID_PARAMS), "error.message", - equalTo("Missing required argument: val")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.INVALID_PARAMS, response.getJsonObject("error").getInteger("code")); + assertEquals("Missing required argument: val", response.getJsonObject("error").getString("message")); } } diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptInternalErrorTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptInternalErrorTest.java index 510d7fc..4851af0 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptInternalErrorTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptInternalErrorTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.prompts; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -12,7 +12,6 @@ import io.quarkiverse.mcp.server.Prompt; import io.quarkiverse.mcp.server.PromptMessage; import io.quarkiverse.mcp.server.runtime.JsonRPC; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.restassured.http.ContentType; @@ -24,7 +23,7 @@ public class PromptInternalErrorTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, MyPrompts.class)); + root -> root.addClasses(MyPrompts.class)); @Test public void testError() throws URISyntaxException { @@ -41,9 +40,11 @@ public void testError() throws URISyntaxException { .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INTERNAL_ERROR), "error.message", - equalTo("Internal error")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.INTERNAL_ERROR, response.getJsonObject("error").getInteger("code")); + assertEquals("Internal error", response.getJsonObject("error").getString("message")); } public static class MyPrompts { diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptsTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptsTest.java index 693a203..4782ad4 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptsTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/prompts/PromptsTest.java @@ -13,7 +13,6 @@ import io.quarkiverse.mcp.server.test.Checks; import io.quarkiverse.mcp.server.test.FooService; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkiverse.mcp.server.test.Options; import io.quarkus.test.QuarkusUnitTest; @@ -26,7 +25,7 @@ public class PromptsTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, FooService.class, Options.class, Checks.class, MyPrompts.class)); + root -> root.addClasses(FooService.class, Options.class, Checks.class, MyPrompts.class)); @Test public void testPrompts() throws URISyntaxException { @@ -34,14 +33,14 @@ public void testPrompts() throws URISyntaxException { JsonObject promptListMessage = newMessage("prompts/list"); - JsonObject promptListResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(promptListMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject promptListResponse = waitForLastJsonMessage(); JsonObject promptListResult = assertResponseMessage(promptListMessage, promptListResponse); assertNotNull(promptListResult); @@ -91,14 +90,14 @@ private void assertPromptMessage(String expectedText, URI endpoint, String name, .put("name", name) .put("arguments", arguments)); - JsonObject promptGetResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(promptGetMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject promptGetResponse = waitForLastJsonMessage(); JsonObject promptGetResult = assertResponseMessage(promptGetMessage, promptGetResponse); assertNotNull(promptGetResult); diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/InvalidResourceUriTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/InvalidResourceUriTest.java index c4f1596..b78f624 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/InvalidResourceUriTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/InvalidResourceUriTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.resources; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -11,7 +11,6 @@ import io.quarkiverse.mcp.server.runtime.JsonRPC; import io.quarkiverse.mcp.server.test.Checks; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.restassured.http.ContentType; @@ -22,7 +21,7 @@ public class InvalidResourceUriTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, Checks.class, MyResources.class)); + root -> root.addClasses(Checks.class, MyResources.class)); @Test public void testError() throws URISyntaxException { @@ -38,9 +37,11 @@ public void testError() throws URISyntaxException { .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.RESOURCE_NOT_FOUND), "error.message", - equalTo("Invalid resource uri: file:///nonexistent")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.RESOURCE_NOT_FOUND, response.getJsonObject("error").getInteger("code")); + assertEquals("Invalid resource uri: file:///nonexistent", response.getJsonObject("error").getString("message")); } } diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourceInternalErrorTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourceInternalErrorTest.java index 5f0ba4b..eb6e426 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourceInternalErrorTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourceInternalErrorTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.resources; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -12,7 +12,6 @@ import io.quarkiverse.mcp.server.Resource; import io.quarkiverse.mcp.server.ResourceResponse; import io.quarkiverse.mcp.server.runtime.JsonRPC; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.restassured.http.ContentType; @@ -23,7 +22,7 @@ public class ResourceInternalErrorTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, MyResources.class)); + root -> root.addClasses(MyResources.class)); @Test public void testError() throws URISyntaxException { @@ -39,9 +38,11 @@ public void testError() throws URISyntaxException { .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INTERNAL_ERROR), "error.message", - equalTo("Internal error")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.INTERNAL_ERROR, response.getJsonObject("error").getInteger("code")); + assertEquals("Internal error", response.getJsonObject("error").getString("message")); } public static class MyResources { diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourcesTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourcesTest.java index 16c6edd..918fab4 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourcesTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/resources/ResourcesTest.java @@ -11,7 +11,6 @@ import org.junit.jupiter.api.extension.RegisterExtension; import io.quarkiverse.mcp.server.test.Checks; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.restassured.http.ContentType; @@ -23,7 +22,7 @@ public class ResourcesTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, MyResources.class, Checks.class)); + root -> root.addClasses(MyResources.class, Checks.class)); @Test public void testResources() throws URISyntaxException { @@ -31,14 +30,14 @@ public void testResources() throws URISyntaxException { JsonObject resourcesListMessage = newMessage("resources/list"); - JsonObject resourcesListResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(resourcesListMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject resourcesListResponse = waitForLastJsonMessage(); JsonObject resourcesListResult = assertResponseMessage(resourcesListMessage, resourcesListResponse); assertNotNull(resourcesListResult); @@ -73,14 +72,14 @@ private void assertResourceRead(String expectedText, String expectedUri, URI end .put("params", new JsonObject() .put("uri", uri)); - JsonObject resourceReadResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(resourceReadMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject resourceReadResponse = waitForLastJsonMessage(); JsonObject resourceReadResult = assertResponseMessage(resourceReadMessage, resourceReadResponse); assertNotNull(resourceReadResult); diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/CustomServerInfoTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/CustomServerInfoTest.java index 0d3fb69..dd9d88b 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/CustomServerInfoTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/CustomServerInfoTest.java @@ -8,7 +8,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.vertx.core.json.JsonObject; @@ -20,7 +19,7 @@ public class CustomServerInfoTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() - .withApplicationRoot(root -> root.addClasses(McpClient.class)) + .withEmptyApplication() .overrideConfigKey("quarkus.mcp.server.server-info.name", NAME) .overrideConfigKey("quarkus.mcp.server.server-info.version", VERSION); diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/DefaultServerInfoTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/DefaultServerInfoTest.java index 2c518f7..df341c6 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/DefaultServerInfoTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/serverinfo/DefaultServerInfoTest.java @@ -9,7 +9,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.vertx.core.json.JsonObject; @@ -18,7 +17,7 @@ public class DefaultServerInfoTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() - .withApplicationRoot(root -> root.addClasses(McpClient.class)); + .withEmptyApplication(); @Test public void testServerInfo() throws URISyntaxException { diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/InvalidToolNameTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/InvalidToolNameTest.java index 0426cd4..09aa3fc 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/InvalidToolNameTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/InvalidToolNameTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.tools; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -12,7 +12,6 @@ import io.quarkiverse.mcp.server.runtime.JsonRPC; import io.quarkiverse.mcp.server.test.Checks; import io.quarkiverse.mcp.server.test.FooService; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkiverse.mcp.server.test.Options; import io.quarkus.test.QuarkusUnitTest; @@ -24,7 +23,7 @@ public class InvalidToolNameTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, FooService.class, Options.class, Checks.class, MyTools.class)); + root -> root.addClasses(FooService.class, Options.class, Checks.class, MyTools.class)); @Test public void testError() throws URISyntaxException { @@ -41,9 +40,11 @@ public void testError() throws URISyntaxException { .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INVALID_PARAMS), "error.message", - equalTo("Invalid tool name: nonexistent")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.INVALID_PARAMS, response.getJsonObject("error").getInteger("code")); + assertEquals("Invalid tool name: nonexistent", response.getJsonObject("error").getString("message")); } } diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/MissingToolArgumentTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/MissingToolArgumentTest.java index d222f20..59e3085 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/MissingToolArgumentTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/MissingToolArgumentTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.tools; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -12,7 +12,6 @@ import io.quarkiverse.mcp.server.runtime.JsonRPC; import io.quarkiverse.mcp.server.test.Checks; import io.quarkiverse.mcp.server.test.FooService; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkiverse.mcp.server.test.Options; import io.quarkus.test.QuarkusUnitTest; @@ -24,7 +23,7 @@ public class MissingToolArgumentTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, FooService.class, Options.class, Checks.class, MyTools.class)); + root -> root.addClasses(FooService.class, Options.class, Checks.class, MyTools.class)); @Test public void testError() throws URISyntaxException { @@ -41,9 +40,11 @@ public void testError() throws URISyntaxException { .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INVALID_PARAMS), "error.message", - equalTo("Missing required argument: price")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.INVALID_PARAMS, response.getJsonObject("error").getInteger("code")); + assertEquals("Missing required argument: price", response.getJsonObject("error").getString("message")); } } diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolInternalErrorTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolInternalErrorTest.java index 9b82e78..9d0da96 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolInternalErrorTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolInternalErrorTest.java @@ -1,7 +1,7 @@ package io.quarkiverse.mcp.server.test.tools; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; import java.net.URISyntaxException; @@ -12,7 +12,6 @@ import io.quarkiverse.mcp.server.TextContent; import io.quarkiverse.mcp.server.Tool; import io.quarkiverse.mcp.server.runtime.JsonRPC; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkus.test.QuarkusUnitTest; import io.restassured.http.ContentType; @@ -23,7 +22,7 @@ public class ToolInternalErrorTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, MyTools.class)); + root -> root.addClasses(MyTools.class)); @Test public void testError() throws URISyntaxException { @@ -41,9 +40,11 @@ public void testError() throws URISyntaxException { .body(message.encode()) .post(endpoint) .then() - .statusCode(200) - .body("error.code", equalTo(JsonRPC.INTERNAL_ERROR), "error.message", - equalTo("Internal error")); + .statusCode(200); + + JsonObject response = waitForLastJsonMessage(); + assertEquals(JsonRPC.INTERNAL_ERROR, response.getJsonObject("error").getInteger("code")); + assertEquals("Internal error", response.getJsonObject("error").getString("message")); } public static class MyTools { diff --git a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolsTest.java b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolsTest.java index f599083..1b9d334 100644 --- a/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolsTest.java +++ b/transports/sse/deployment/src/test/java/io/quarkiverse/mcp/server/test/tools/ToolsTest.java @@ -13,7 +13,6 @@ import io.quarkiverse.mcp.server.test.Checks; import io.quarkiverse.mcp.server.test.FooService; -import io.quarkiverse.mcp.server.test.McpClient; import io.quarkiverse.mcp.server.test.McpServerTest; import io.quarkiverse.mcp.server.test.Options; import io.quarkus.test.QuarkusUnitTest; @@ -26,7 +25,7 @@ public class ToolsTest extends McpServerTest { @RegisterExtension static final QuarkusUnitTest config = defaultConfig() .withApplicationRoot( - root -> root.addClasses(McpClient.class, FooService.class, Options.class, Checks.class, MyTools.class)); + root -> root.addClasses(FooService.class, Options.class, Checks.class, MyTools.class)); @Test public void testTools() throws URISyntaxException { @@ -34,14 +33,14 @@ public void testTools() throws URISyntaxException { JsonObject toolListMessage = newMessage("tools/list"); - JsonObject toolListResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(toolListMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject toolListResponse = waitForLastJsonMessage(); JsonObject toolListResult = assertResponseMessage(toolListMessage, toolListResponse); assertNotNull(toolListResult); @@ -91,14 +90,15 @@ private void assertToolCall(String expectedText, URI endpoint, String name, Json .put("name", name) .put("arguments", arguments)); - JsonObject toolGetResponse = new JsonObject(given() + given() .contentType(ContentType.JSON) .when() .body(toolGetMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject toolGetResponse = waitForLastJsonMessage(); JsonObject toolGetResult = assertResponseMessage(toolGetMessage, toolGetResponse); assertNotNull(toolGetResult); diff --git a/transports/sse/integration-tests/pom.xml b/transports/sse/integration-tests/pom.xml index 56477dd..10bc59c 100644 --- a/transports/sse/integration-tests/pom.xml +++ b/transports/sse/integration-tests/pom.xml @@ -43,6 +43,12 @@ rest-assured test + + io.quarkiverse.mcp + quarkus-mcp-server-test-utils + ${project.version} + test + diff --git a/transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClient.java b/transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClient.java deleted file mode 100644 index eedcbb8..0000000 --- a/transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClient.java +++ /dev/null @@ -1,26 +0,0 @@ -package io.quarkiverse.mcp.server.sse.it; - -import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; -import static jakarta.ws.rs.core.MediaType.SERVER_SENT_EVENTS; - -import jakarta.ws.rs.GET; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.Produces; - -import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; -import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; -import org.jboss.resteasy.reactive.client.SseEvent; - -import io.smallrye.mutiny.Multi; - -@Path("mcp") -@RegisterRestClient -public interface McpClient { - - @GET - @Path("sse") - @ClientHeaderParam(name = CONTENT_TYPE, value = SERVER_SENT_EVENTS) - @Produces(SERVER_SENT_EVENTS) - Multi> init(); - -} diff --git a/transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClientInit.java b/transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClientInit.java deleted file mode 100644 index 78746f1..0000000 --- a/transports/sse/integration-tests/src/main/java/io/quarkiverse/mcp/server/sse/it/McpClientInit.java +++ /dev/null @@ -1,42 +0,0 @@ -package io.quarkiverse.mcp.server.sse.it; - -import java.net.URI; -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; - -import jakarta.enterprise.event.Observes; - -import org.awaitility.Awaitility; -import org.eclipse.microprofile.config.inject.ConfigProperty; -import org.jboss.resteasy.reactive.client.SseEvent; - -import io.quarkus.logging.Log; -import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; -import io.vertx.ext.web.Router; - -// A workaround to init the MCP client via SSE in the test -public class McpClientInit { - - void initRoute(@Observes Router router, @ConfigProperty(name = "quarkus.http.host") String host, - @ConfigProperty(name = "quarkus.http.test-port") int port, - @ConfigProperty(name = "quarkus.http.root-path") String root) { - router.route("/test-init-mcp-client").blockingHandler(rc -> { - try { - URI baseUri = new URI("http://" + host + ":" + port + root); - Log.infof("Test base URI: %s", baseUri); - List> sseMessages = new CopyOnWriteArrayList<>(); - McpClient mcpClient = QuarkusRestClientBuilder.newBuilder() - .baseUri(baseUri) - .build(McpClient.class); - mcpClient.init().subscribe().with(s -> sseMessages.add(s), e -> { - }); - Awaitility.await().until(() -> !sseMessages.isEmpty()); - URI endpoint = new URI(sseMessages.get(0).data()); - rc.end(endpoint.toString()); - } catch (Throwable e) { - throw new RuntimeException(e); - } - }); - } - -} diff --git a/transports/sse/integration-tests/src/main/resources/application.properties b/transports/sse/integration-tests/src/main/resources/application.properties index e69de29..7fb2b28 100644 --- a/transports/sse/integration-tests/src/main/resources/application.properties +++ b/transports/sse/integration-tests/src/main/resources/application.properties @@ -0,0 +1 @@ +#quarkus.mcp.server.sse.traffic-logging.enabled=true \ No newline at end of file diff --git a/transports/sse/integration-tests/src/test/java/io/quarkiverse/mcp/server/sse/it/ServerFeaturesTest.java b/transports/sse/integration-tests/src/test/java/io/quarkiverse/mcp/server/sse/it/ServerFeaturesTest.java index 9b22631..ef6dbfa 100644 --- a/transports/sse/integration-tests/src/test/java/io/quarkiverse/mcp/server/sse/it/ServerFeaturesTest.java +++ b/transports/sse/integration-tests/src/test/java/io/quarkiverse/mcp/server/sse/it/ServerFeaturesTest.java @@ -7,24 +7,28 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Base64; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import org.jboss.logging.Logger; import org.junit.jupiter.api.Test; +import io.quarkiverse.mcp.server.test.SseClient; import io.quarkus.test.common.http.TestHTTPResource; import io.quarkus.test.junit.QuarkusTest; import io.restassured.http.ContentType; +import io.vertx.core.json.DecodeException; import io.vertx.core.json.JsonArray; import io.vertx.core.json.JsonObject; @QuarkusTest public class ServerFeaturesTest { + private static final Logger LOG = Logger.getLogger(ServerFeaturesTest.class); + @TestHTTPResource URI testUri; - AtomicInteger idGenerator = new AtomicInteger(); + SseClient sseClient; @Test public void testPrompt() throws URISyntaxException { @@ -32,14 +36,14 @@ public void testPrompt() throws URISyntaxException { JsonObject promptListMessage = newMessage("prompts/list"); - JsonObject promptListResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(promptListMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject promptListResponse = lastEventToJson(); JsonObject promptListResult = assertResponseMessage(promptListMessage, promptListResponse); assertNotNull(promptListResult); @@ -62,14 +66,14 @@ public void testTool() throws URISyntaxException { JsonObject toolListMessage = newMessage("tools/list"); - JsonObject toolListResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(toolListMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject toolListResponse = lastEventToJson(); JsonObject toolListResult = assertResponseMessage(toolListMessage, toolListResponse); assertNotNull(toolListResult); @@ -95,14 +99,14 @@ public void testResource() throws URISyntaxException { JsonObject resourceListMessage = newMessage("resources/list"); - JsonObject resourceListResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(resourceListMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject resourceListResponse = lastEventToJson(); JsonObject resourceListResult = assertResponseMessage(resourceListMessage, resourceListResponse); assertNotNull(resourceListResult); @@ -116,13 +120,15 @@ public void testResource() throws URISyntaxException { } protected URI initClient() throws URISyntaxException { - URI endpoint = new URI(given().baseUri(testUri.toString()) - .contentType(ContentType.TEXT) - .when() - .get("test-init-mcp-client") - .then() - .statusCode(200) - .extract().body().asString()); + String testUriStr = testUri.toString(); + if (testUriStr.endsWith("/")) { + testUriStr = testUriStr.substring(0, testUriStr.length() - 1); + } + sseClient = new SseClient(URI.create(testUriStr + "/mcp/sse")); + sseClient.connect(); + var event = sseClient.waitForFirstEvent(); + String messagesUri = testUriStr + event.data().strip(); + URI endpoint = URI.create(messagesUri); JsonObject initMessage = newMessage("initialize") .put("params", @@ -132,14 +138,14 @@ protected URI initClient() throws URISyntaxException { .put("version", "1.0")) .put("protocolVersion", "2024-11-05")); - JsonObject initResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(initMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject initResponse = lastEventToJson(); JsonObject initResult = assertResponseMessage(initMessage, initResponse); assertNotNull(initResult); @@ -159,6 +165,17 @@ protected URI initClient() throws URISyntaxException { return endpoint; } + private JsonObject lastEventToJson() { + SseClient.SseEvent event = null; + try { + event = sseClient.waitForLastEvent(); + return new JsonObject(event.data()); + } catch (DecodeException e) { + LOG.errorf("Error parsing:\n%s", event); + throw e; + } + } + protected JsonObject assertResponseMessage(JsonObject message, JsonObject response) { assertEquals(message.getInteger("id"), response.getInteger("id")); assertEquals("2.0", response.getString("jsonrpc")); @@ -166,10 +183,13 @@ protected JsonObject assertResponseMessage(JsonObject message, JsonObject respon } protected JsonObject newMessage(String method) { + if (sseClient == null) { + throw new IllegalStateException(); + } return new JsonObject() .put("jsonrpc", "2.0") .put("method", method) - .put("id", idGenerator.incrementAndGet()); + .put("id", sseClient.nextId()); } private void assertPrompt(JsonObject prompt, String name, String description, Consumer argumentsAsserter) { @@ -188,14 +208,14 @@ private void assertPromptMessage(String expectedText, URI endpoint, String name, .put("name", name) .put("arguments", arguments)); - JsonObject promptGetResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(promptGetMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject promptGetResponse = lastEventToJson(); JsonObject promptGetResult = assertResponseMessage(promptGetMessage, promptGetResponse); assertNotNull(promptGetResult); @@ -224,14 +244,14 @@ private void assertToolCall(String expectedText, URI endpoint, String name, Json .put("name", name) .put("arguments", arguments)); - JsonObject toolGetResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(toolGetMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject toolGetResponse = lastEventToJson(); JsonObject toolGetResult = assertResponseMessage(toolGetMessage, toolGetResponse); assertNotNull(toolGetResult); @@ -258,14 +278,14 @@ private void assertResourceRead(String expectedBlob, String expectedUri, URI end .put("params", new JsonObject() .put("uri", uri)); - JsonObject resourceReadResponse = new JsonObject(given() - .contentType(ContentType.JSON) + given().contentType(ContentType.JSON) .when() .body(resourceReadMessage.encode()) .post(endpoint) .then() - .statusCode(200) - .extract().body().asString()); + .statusCode(200); + + JsonObject resourceReadResponse = lastEventToJson(); JsonObject resourceReadResult = assertResponseMessage(resourceReadMessage, resourceReadResponse); assertNotNull(resourceReadResult); diff --git a/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpConnection.java b/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpConnection.java index c229067..271ae07 100644 --- a/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpConnection.java +++ b/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpConnection.java @@ -15,7 +15,7 @@ public class SseMcpConnection extends McpConnectionBase { public void sendEvent(String name, String data) { response.write("event: " + name + "\n"); response.write("data: " + data + "\n"); - response.write("\n\n"); + response.write("\n"); } } diff --git a/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpMessageHandler.java b/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpMessageHandler.java index fd0ee65..3525c32 100644 --- a/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpMessageHandler.java +++ b/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpMessageHandler.java @@ -34,24 +34,26 @@ protected SseMcpMessageHandler(McpRuntimeConfig config, ConnectionManager connec @Override public void handle(RoutingContext ctx) { - SseResponder responder = new SseResponder(trafficLogger, ctx); HttpServerRequest request = ctx.request(); String connectionId = ctx.pathParam("id"); if (connectionId == null) { - responder.badRequest("Connection id is missing"); + LOG.errorf("Connection id is missing: %s", ctx.normalizedPath()); + ctx.fail(400); return; } if (request.method() != HttpMethod.POST) { ctx.response().putHeader(HttpHeaders.ALLOW, "POST"); - responder.failure(405, "Invalid HTTP method %s [connectionId: %s]", ctx.request().method(), connectionId); + LOG.errorf("Invalid HTTP method %s [connectionId: %s]", ctx.request().method(), connectionId); + ctx.fail(405); return; } McpConnection connection = connectionManager.get(connectionId); if (connection == null) { - responder.sendError(null, JsonRPC.INTERNAL_ERROR, - "Unable to obtain the connection: " + connectionId); + LOG.errorf("Connection not found: %s", connectionId); + ctx.fail(400); return; } + SseResponder responder = new SseResponder(trafficLogger, (SseMcpConnection) connection); JsonObject message; try { @@ -60,6 +62,7 @@ public void handle(RoutingContext ctx) { String msg = "Unable to parse the JSON message"; LOG.errorf(e, msg); responder.sendError(null, JsonRPC.PARSE_ERROR, msg); + ctx.end(); return; } if (trafficLogger != null) { @@ -68,43 +71,28 @@ public void handle(RoutingContext ctx) { if (JsonRPC.validate(message, responder)) { handle(message, connection, responder); } + ctx.end(); } class SseResponder implements Responder { - final RoutingContext ctx; + final SseMcpConnection connection; final TrafficLogger trafficLogger; - SseResponder(TrafficLogger trafficLogger, RoutingContext ctx) { + SseResponder(TrafficLogger trafficLogger, SseMcpConnection connection) { this.trafficLogger = trafficLogger; - this.ctx = ctx; + this.connection = connection; } @Override public void send(JsonObject message) { if (message == null) { - ctx.end(); return; } if (trafficLogger != null) { trafficLogger.messageSent(message); } - setJsonContentType(ctx); - ctx.end(message.toBuffer()); - } - - public void badRequest(String logMessage, Object... params) { - LOG.errorf(logMessage, params); - ctx.fail(400); - } - - public void failure(int statusCode, String logMessage, Object... params) { - LOG.errorf(logMessage, params); - ctx.fail(statusCode); - } - - private void setJsonContentType(RoutingContext ctx) { - ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + connection.sendEvent("message", message.encode()); } } diff --git a/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpServerRecorder.java b/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpServerRecorder.java index c85b599..4aa9765 100644 --- a/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpServerRecorder.java +++ b/transports/sse/runtime/src/main/java/io/quarkiverse/mcp/server/sse/runtime/SseMcpServerRecorder.java @@ -49,13 +49,6 @@ public Handler createSseEndpointHandler(String mcpPath) { @Override public void handle(RoutingContext ctx) { - String contentType = ctx.request().getHeader(HttpHeaders.CONTENT_TYPE); - if (!"text/event-stream".equals(contentType)) { - LOG.errorf("Invalid content type: %s", contentType); - ctx.fail(400); - return; - } - HttpServerResponse response = ctx.response(); response.setChunked(true); response.headers().add(HttpHeaders.TRANSFER_ENCODING, "chunked");