Skip to content

Commit

Permalink
Merge pull request #14 from mkouba/add-basic-integration-test
Browse files Browse the repository at this point in the history
Add basic integration test
  • Loading branch information
mkouba authored Dec 16, 2024
2 parents 3968a7e + 315b2df commit 605ee9b
Show file tree
Hide file tree
Showing 12 changed files with 428 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@
import org.jboss.jandex.DotName;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;

import org.jboss.jandex.Type.Kind;

import io.quarkiverse.mcp.server.Content;
import io.quarkiverse.mcp.server.ImageContent;
import io.quarkiverse.mcp.server.PromptMessage;
import io.quarkiverse.mcp.server.PromptResponse;
import io.quarkiverse.mcp.server.ResourceContent;
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.ToolResponse;
import io.quarkiverse.mcp.server.deployment.FeatureMethodBuildItem.Feature;
import io.quarkiverse.mcp.server.runtime.ExecutionModel;
import io.quarkiverse.mcp.server.runtime.FeatureArgument;
Expand Down Expand Up @@ -54,6 +62,8 @@
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveHierarchyBuildItem;
import io.quarkus.deployment.recording.RecorderContext;
import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.ClassCreator;
Expand Down Expand Up @@ -163,6 +173,68 @@ void collectFeatureMethods(BeanDiscoveryFinishedBuildItem beanDiscovery, Invoker
}
}

@Record(RUNTIME_INIT)
@BuildStep
void generateMetadata(McpServerRecorder recorder, RecorderContext recorderContext,
List<FeatureMethodBuildItem> featureMethods,
TransformedAnnotationsBuildItem transformedAnnotations,
BuildProducer<GeneratedClassBuildItem> generatedClasses, BuildProducer<SyntheticBeanBuildItem> syntheticBeans) {

// Note that the generated McpMetadata impl must be considered an application class
// so that it can see the generated invokers
ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClasses, true);

String metadataClassName = "io.quarkiverse.mcp.server.runtime.McpMetadata_Impl";
ClassCreator metadataCreator = ClassCreator.builder().classOutput(classOutput)
.className(metadataClassName)
.interfaces(McpMetadata.class)
.build();

// io.quarkiverse.mcp.server.runtime.McpMetadata.prompts()
MethodCreator promptsMethod = metadataCreator.getMethodCreator("prompts", List.class);
ResultHandle retPrompts = Gizmo.newArrayList(promptsMethod);
for (FeatureMethodBuildItem prompt : featureMethods.stream().filter(FeatureMethodBuildItem::isPrompt).toList()) {
processFeatureMethod(promptsMethod, prompt, retPrompts, transformedAnnotations, DotNames.PROMPT_ARG);
}
promptsMethod.returnValue(retPrompts);

// io.quarkiverse.mcp.server.runtime.McpMetadata.tools()
MethodCreator toolsMethod = metadataCreator.getMethodCreator("tools", List.class);
ResultHandle retTools = Gizmo.newArrayList(toolsMethod);
for (FeatureMethodBuildItem tool : featureMethods.stream().filter(FeatureMethodBuildItem::isTool).toList()) {
processFeatureMethod(toolsMethod, tool, retTools, transformedAnnotations, DotNames.TOOL_ARG);
}
toolsMethod.returnValue(retTools);

metadataCreator.close();

syntheticBeans.produce(SyntheticBeanBuildItem.configure(McpMetadata.class)
.scope(Singleton.class)
.setRuntimeInit()
.runtimeValue(recorderContext.newInstance(metadataClassName))
.done());
}

@BuildStep
void registerForReflection(List<FeatureMethodBuildItem> featureMethods,
BuildProducer<ReflectiveClassBuildItem> reflectiveClasses,
BuildProducer<ReflectiveHierarchyBuildItem> reflectiveHierarchies) {
// FIXME this is not ideal, JsonObject.encode() may use Jackson under the hood which requires reflection
for (FeatureMethodBuildItem m : featureMethods) {
for (org.jboss.jandex.Type paramType : m.getMethod().parameterTypes()) {
if (paramType.kind() == Kind.PRIMITIVE) {
continue;
}
reflectiveHierarchies.produce(ReflectiveHierarchyBuildItem.builder(paramType).build());
}
}
reflectiveClasses.produce(ReflectiveClassBuildItem.builder(Content.class, TextContent.class, ImageContent.class,
ResourceContent.class, PromptResponse.class, PromptMessage.class, ToolResponse.class, FeatureMethodInfo.class,
FeatureArgument.class).methods().build());
reflectiveHierarchies.produce(ReflectiveHierarchyBuildItem.builder(List.class).build());
reflectiveHierarchies.produce(ReflectiveHierarchyBuildItem.builder(Map.class).build());
}

private void validateFeatureMethod(MethodInfo method, Feature feature) {
if (Modifier.isStatic(method.flags())) {
throw new IllegalStateException("MCP feature method must not be static: " + method);
Expand Down Expand Up @@ -215,48 +287,6 @@ private boolean hasFeatureMethod(BeanInfo bean) {
return beanClass.hasAnnotation(DotNames.PROMPT) || beanClass.hasAnnotation(DotNames.TOOL);
}

@Record(RUNTIME_INIT)
@BuildStep
void generateMetadata(McpServerRecorder recorder, RecorderContext recorderContext,
List<FeatureMethodBuildItem> featureMethods,
TransformedAnnotationsBuildItem transformedAnnotations,
BuildProducer<GeneratedClassBuildItem> generatedClasses, BuildProducer<SyntheticBeanBuildItem> syntheticBeans) {

// Note that the generated McpMetadata impl must be considered an application class
// so that it can see the generated invokers
ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClasses, true);

String metadataClassName = "io.quarkiverse.mcp.server.runtime.McpMetadata_Impl";
ClassCreator metadataCreator = ClassCreator.builder().classOutput(classOutput)
.className(metadataClassName)
.interfaces(McpMetadata.class)
.build();

// io.quarkiverse.mcp.server.runtime.McpMetadata.prompts()
MethodCreator promptsMethod = metadataCreator.getMethodCreator("prompts", List.class);
ResultHandle retPrompts = Gizmo.newArrayList(promptsMethod);
for (FeatureMethodBuildItem prompt : featureMethods.stream().filter(FeatureMethodBuildItem::isPrompt).toList()) {
processFeatureMethod(promptsMethod, prompt, retPrompts, transformedAnnotations, DotNames.PROMPT_ARG);
}
promptsMethod.returnValue(retPrompts);

// io.quarkiverse.mcp.server.runtime.McpMetadata.tools()
MethodCreator toolsMethod = metadataCreator.getMethodCreator("tools", List.class);
ResultHandle retTools = Gizmo.newArrayList(toolsMethod);
for (FeatureMethodBuildItem tool : featureMethods.stream().filter(FeatureMethodBuildItem::isTool).toList()) {
processFeatureMethod(toolsMethod, tool, retTools, transformedAnnotations, DotNames.TOOL_ARG);
}
toolsMethod.returnValue(retTools);

metadataCreator.close();

syntheticBeans.produce(SyntheticBeanBuildItem.configure(McpMetadata.class)
.scope(Singleton.class)
.setRuntimeInit()
.runtimeValue(recorderContext.newInstance(metadataClassName))
.done());
}

private void processFeatureMethod(MethodCreator method, FeatureMethodBuildItem featureMethod, ResultHandle retList,
TransformedAnnotationsBuildItem transformedAnnotations, DotName argAnnotationName) {
ResultHandle args = Gizmo.newArrayList(method);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public abstract class McpServerTest {
@TestHTTPResource
URI testUri;

List<SseEvent<String>> sseMessages;
protected List<SseEvent<String>> sseMessages;

AtomicInteger idGenerator = new AtomicInteger();

Expand Down
8 changes: 8 additions & 0 deletions integration-tests/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest-client</artifactId>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.mcp</groupId>
<artifactId>quarkus-mcp-server</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkiverse.mcp.server.it;

import jakarta.enterprise.context.RequestScoped;

@RequestScoped
public class CodeService {

public String assist(String language) {
return switch (language) {
case "java" -> "System.out.println(\"Hello world!\");";
default -> throw new IllegalArgumentException("Unexpected value: " + language);
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.mcp.server.it;

import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE;
import static jakarta.ws.rs.core.MediaType.SERVER_SENT_EVENTS;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;

import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.resteasy.reactive.client.SseEvent;

import io.smallrye.mutiny.Multi;

@Path("mcp")
@RegisterRestClient
public interface McpClient {

@GET
@Path("sse")
@ClientHeaderParam(name = CONTENT_TYPE, value = SERVER_SENT_EVENTS)
@Produces(SERVER_SENT_EVENTS)
Multi<SseEvent<String>> init();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.quarkiverse.mcp.server.it;

import java.net.URI;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

import jakarta.enterprise.event.Observes;

import org.awaitility.Awaitility;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.jboss.resteasy.reactive.client.SseEvent;

import io.quarkus.logging.Log;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.vertx.ext.web.Router;

// A workaround to init the MCP client via SSE in the test
public class McpClientInit {

void initRoute(@Observes Router router, @ConfigProperty(name = "quarkus.http.host") String host,
@ConfigProperty(name = "quarkus.http.test-port") int port,
@ConfigProperty(name = "quarkus.http.root-path") String root) {
router.route("/test-init-mcp-client").blockingHandler(rc -> {
try {
URI baseUri = new URI("http://" + host + ":" + port + root);
Log.infof("Test base URI: %s", baseUri);
List<SseEvent<String>> sseMessages = new CopyOnWriteArrayList<>();
McpClient mcpClient = QuarkusRestClientBuilder.newBuilder()
.baseUri(baseUri)
.build(McpClient.class);
mcpClient.init().subscribe().with(s -> sseMessages.add(s), e -> {
});
Awaitility.await().until(() -> !sseMessages.isEmpty());
URI endpoint = new URI(sseMessages.get(0).data());
rc.end(endpoint.toString());
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.mcp.server.it;

import jakarta.inject.Inject;

import io.quarkiverse.mcp.server.Prompt;
import io.quarkiverse.mcp.server.PromptArg;
import io.quarkiverse.mcp.server.PromptMessage;
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.Tool;

public class ServerFeatures {

@Inject
CodeService codeService;

@Tool
TextContent toLowerCase(String value) {
return new TextContent(value.toLowerCase());
}

@Prompt(name = "code_assist")
PromptMessage codeAssist(@PromptArg(name = "lang") String language) {
return PromptMessage.withUserRole(new TextContent(codeService.assist(language)));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.quarkiverse.mcp.server.it;

import io.quarkus.test.junit.QuarkusIntegrationTest;

@QuarkusIntegrationTest
public class ServerFeaturesIT extends ServerFeaturesTest {

}
Loading

0 comments on commit 605ee9b

Please sign in to comment.