Skip to content

Commit

Permalink
Merge pull request #34 from mkouba/issue-32
Browse files Browse the repository at this point in the history
SSE transport: always send server messages as SSE events
  • Loading branch information
mkouba authored Jan 8, 2025
2 parents b491ef6 + cc2a770 commit b86b5b0
Show file tree
Hide file tree
Showing 30 changed files with 383 additions and 288 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ private void initializing(JsonObject message, Responder responder, McpConnection
String method = message.getString("method");
if (NOTIFICATIONS_INITIALIZED.equals(method)) {
if (connection.setInitialized()) {
responder.send(null);
LOG.infof("Client successfully initialized [%s]", connection.id());
}
} else if (PING.equals(method)) {
Expand Down Expand Up @@ -128,7 +127,6 @@ private void ping(JsonObject message, Responder responder) {
private void close(JsonObject message, Responder responder, McpConnection connection) {
if (connectionManager.remove(connection.id())) {
LOG.infof("Connection %s closed", connection.id());
responder.send(null);
} else {
responder.sendError(message.getValue("id"), JsonRPC.INTERNAL_ERROR,
"Unable to obtain the connection to be closed:" + connection.id());
Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<module>core</module>
<module>transports/sse</module>
<module>transports/stdio</module>
<module>test-utils</module>
</modules>

<scm>
Expand Down
25 changes: 25 additions & 0 deletions test-utils/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>io.quarkiverse.mcp</groupId>
<artifactId>quarkus-mcp-server-parent</artifactId>
<version>999-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>quarkus-mcp-server-test-utils</artifactId>
<name>Quarkus MCP Server Test Utils</name>

<dependencies>
<dependency>
<groupId>org.jboss.logging</groupId>
<artifactId>jboss-logging</artifactId>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
</dependency>
</dependencies>

</project>
145 changes: 145 additions & 0 deletions test-utils/src/main/java/io/quarkiverse/mcp/server/test/SseClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package io.quarkiverse.mcp.server.test;

import java.io.EOFException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpClient.Version;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse.BodyHandlers;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicInteger;

import org.awaitility.Awaitility;
import org.jboss.logging.Logger;

public class SseClient {

private static final Logger LOG = Logger.getLogger(SseClient.class);

private final URI testUri;

private final AtomicInteger idGenerator;

public final List<SseEvent> events;

public SseClient(URI uri) {
this.testUri = uri;
this.idGenerator = new AtomicInteger();
this.events = new CopyOnWriteArrayList<>();
}

public int nextId() {
return idGenerator.incrementAndGet();
}

public SseEvent waitForFirstEvent() {
nextId();
Awaitility.await().until(() -> !events.isEmpty());
return events.get(0);
}

public SseEvent waitForLastEvent() {
int lastId = idGenerator.get();
Awaitility.await().until(() -> events.size() >= lastId);
return events.get(lastId - 1);
}

public void connect() {
HttpClient client = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder()
.uri(testUri)
.version(Version.HTTP_1_1)
.header("Accept", "text/event-stream")
.GET()
.build();

client.sendAsync(request, BodyHandlers.fromLineSubscriber(new SseEventSubscriber()))
.thenAccept(response -> {
if (response.statusCode() == 200) {
LOG.infof("Connected to SSE stream: %s", testUri);
} else {
LOG.errorf("Failed to connect %s: %s", response.statusCode(), testUri);
}
})
.exceptionally(e -> {
Throwable root = getRootCause(e);
if (!(root instanceof EOFException)) {
LOG.error(e);
}
return null;
});

}

class SseEventSubscriber implements Flow.Subscriber<String> {

private Flow.Subscription subscription;

private String event = "message";
private StringBuilder dataBuffer = new StringBuilder();

@Override
public void onSubscribe(Flow.Subscription subscription) {
this.subscription = subscription;
subscription.request(1);
}

@Override
public void onNext(String line) {
LOG.debugf("Received next line:\n%s", line);
if (line.startsWith(":")) {
// Skip comments
} else if (line.isBlank()) {
// Flush
events.add(new SseEvent(event, dataBuffer.toString()));
event = "message";
dataBuffer = new StringBuilder();
} else if (line.contains(":")) {
int colon = line.indexOf(":");
String field = line.substring(0, colon).strip();
String value = line.substring(colon + 1).strip();
handleField(field, value);
} else {
// The whole line is the field name
handleField(line, "");
}
subscription.request(1);
}

@Override
public void onError(Throwable t) {
Throwable root = getRootCause(t);
if (!(root instanceof EOFException)) {
LOG.errorf(t, "Error in SSE stream");
}
}

@Override
public void onComplete() {
LOG.debug("SSE stream complete");
}

private void handleField(String field, String value) {
switch (field) {
case "event" -> event = value;
case "data" -> dataBuffer.append(value).append("\n");
}
}
}

public record SseEvent(String name, String data) {
}

private static Throwable getRootCause(Throwable exception) {
final List<Throwable> chain = new ArrayList<>();
Throwable curr = exception;
while (curr != null && !chain.contains(curr)) {
chain.add(curr);
curr = curr.getCause();
}
return chain.isEmpty() ? null : chain.get(chain.size() - 1);
}
}
9 changes: 5 additions & 4 deletions transports/sse/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest-client</artifactId>
<groupId>io.rest-assured</groupId>
<artifactId>rest-assured</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.rest-assured</groupId>
<artifactId>rest-assured</artifactId>
<groupId>io.quarkiverse.mcp</groupId>
<artifactId>quarkus-mcp-server-test-utils</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@

import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import org.awaitility.Awaitility;
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.junit.jupiter.api.AfterEach;

import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.restassured.http.ContentType;
Expand All @@ -28,32 +23,46 @@ public abstract class McpServerTest {
@TestHTTPResource
URI testUri;

SseClient sseClient;

public static QuarkusUnitTest defaultConfig() {
// TODO in theory, we should also add SseClient to all test archives
// but the test CL can see the class and we don't need Quarkus to analyze this util class
QuarkusUnitTest config = new QuarkusUnitTest();
if (System.getProperty("logTraffic") != null) {
config.overrideConfigKey("quarkus.mcp.server.sse.traffic-logging.enabled", "true");
}
return config;
}

protected List<SseEvent<String>> sseMessages;

AtomicInteger idGenerator = new AtomicInteger();
@AfterEach
void cleanup() {
sseClient = null;
}

protected URI initClient() throws URISyntaxException {
return initClient(null);
}

protected JsonObject waitForLastJsonMessage() {
SseClient.SseEvent event = sseClient.waitForLastEvent();
if ("message".equals(event.name())) {
return new JsonObject(event.data());
}
throw new IllegalStateException("Message event not received: " + event);
}

protected URI initClient(Consumer<JsonObject> initResultAssert) throws URISyntaxException {
McpClient mcpClient = QuarkusRestClientBuilder.newBuilder()
.baseUri(testUri)
.build(McpClient.class);

sseMessages = new CopyOnWriteArrayList<>();
mcpClient.init().subscribe().with(s -> sseMessages.add(s), e -> {
});
Awaitility.await().until(() -> !sseMessages.isEmpty());
URI endpoint = new URI(sseMessages.get(0).data());
String testUriStr = testUri.toString();
if (testUriStr.endsWith("/")) {
testUriStr = testUriStr.substring(0, testUriStr.length() - 1);
}
sseClient = new SseClient(URI.create(testUriStr + "/mcp/sse"));
sseClient.connect();
var event = sseClient.waitForFirstEvent();
String messagesUri = testUriStr + event.data().strip();
URI endpoint = URI.create(messagesUri);

LOG.infof("Client received endpoint: %s", endpoint);

JsonObject initMessage = newMessage("initialize")
Expand All @@ -64,14 +73,15 @@ protected URI initClient(Consumer<JsonObject> initResultAssert) throws URISyntax
.put("version", "1.0"))
.put("protocolVersion", "2024-11-05"));

JsonObject initResponse = new JsonObject(given()
given()
.contentType(ContentType.JSON)
.when()
.body(initMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());
.statusCode(200);

JsonObject initResponse = waitForLastJsonMessage();

JsonObject initResult = assertResponseMessage(initMessage, initResponse);
assertNotNull(initResult);
Expand Down Expand Up @@ -102,10 +112,13 @@ protected JsonObject assertResponseMessage(JsonObject message, JsonObject respon
}

protected JsonObject newMessage(String method) {
if (sseClient == null) {
throw new IllegalStateException("SSE client not initialized");
}
return new JsonObject()
.put("jsonrpc", "2.0")
.put("method", method)
.put("id", idGenerator.incrementAndGet());
.put("id", sseClient.nextId());
}

}
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
package io.quarkiverse.mcp.server.test.close;

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.equalTo;

import java.net.URI;
import java.net.URISyntaxException;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.runtime.JsonRPC;
import io.quarkiverse.mcp.server.test.McpClient;
import io.quarkiverse.mcp.server.test.McpServerTest;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.http.ContentType;
Expand All @@ -20,7 +17,7 @@ public class CloseTest extends McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = defaultConfig()
.withApplicationRoot(root -> root.addClasses(McpClient.class));
.withEmptyApplication();

@Test
public void testCloseMessage() throws URISyntaxException {
Expand All @@ -42,7 +39,6 @@ public void testCloseMessage() throws URISyntaxException {
.body(closeMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.body("error.code", equalTo(JsonRPC.INTERNAL_ERROR));
.statusCode(400);
}
}
Loading

0 comments on commit b86b5b0

Please sign in to comment.