Skip to content

Commit

Permalink
Merge pull request #10 from mkouba/tools-support-more-return-types
Browse files Browse the repository at this point in the history
Tools support more return types
  • Loading branch information
mkouba authored Dec 13, 2024
2 parents 76e5017 + b7eaa3c commit 766974c
Showing 20 changed files with 411 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -4,12 +4,16 @@

import org.jboss.jandex.DotName;

import io.quarkiverse.mcp.server.Content;
import io.quarkiverse.mcp.server.ImageContent;
import io.quarkiverse.mcp.server.McpConnection;
import io.quarkiverse.mcp.server.Prompt;
import io.quarkiverse.mcp.server.PromptArg;
import io.quarkiverse.mcp.server.PromptMessage;
import io.quarkiverse.mcp.server.PromptResponse;
import io.quarkiverse.mcp.server.RequestId;
import io.quarkiverse.mcp.server.ResourceContent;
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.Tool;
import io.quarkiverse.mcp.server.ToolArg;
import io.quarkiverse.mcp.server.ToolResponse;
@@ -37,5 +41,9 @@ class DotNames {
static final DotName TRANSACTIONAL = DotName.createSimple("jakarta.transaction.Transactional");
static final DotName MCP_CONNECTION = DotName.createSimple(McpConnection.class);
static final DotName REQUEST_ID = DotName.createSimple(RequestId.class);
static final DotName CONTENT = DotName.createSimple(Content.class);
static final DotName TEXT_CONTENT = DotName.createSimple(TextContent.class);
static final DotName IMAGE_CONTENT = DotName.createSimple(ImageContent.class);
static final DotName RESOURCE_CONTENT = DotName.createSimple(ResourceContent.class);

}
Original file line number Diff line number Diff line change
@@ -4,8 +4,10 @@
import static io.quarkiverse.mcp.server.deployment.FeatureMethodBuildItem.Feature.TOOL;
import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT;

import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

import jakarta.enterprise.invoke.Invoker;
@@ -14,10 +16,12 @@
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.DotName;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;

import io.quarkiverse.mcp.server.deployment.FeatureMethodBuildItem.Feature;
import io.quarkiverse.mcp.server.runtime.ExecutionModel;
import io.quarkiverse.mcp.server.runtime.FeatureArgument;
import io.quarkiverse.mcp.server.runtime.FeatureMetadata;
@@ -111,21 +115,68 @@ void collectFeatureMethods(BeanDiscoveryFinishedBuildItem beanDiscovery, Invoker
featureAnnotation = method.declaredAnnotation(DotNames.TOOL);
}
if (featureAnnotation != null) {
// TODO validate method
Feature feature = featureAnnotation.name().equals(DotNames.PROMPT) ? PROMPT : TOOL;
validateFeatureMethod(method, feature);
AnnotationValue nameValue = featureAnnotation.value("name");
String name = nameValue != null ? nameValue.asString() : method.name();
AnnotationValue descValue = featureAnnotation.value("description");
String description = descValue != null ? descValue.asString() : "";
InvokerBuilder invokerBuilder = invokerFactory.createInvoker(bean, method)
.withInstanceLookup();
features.produce(
new FeatureMethodBuildItem(bean, method, invokerBuilder.build(), name, description,
featureAnnotation.name().equals(DotNames.PROMPT) ? PROMPT : TOOL));
new FeatureMethodBuildItem(bean, method, invokerBuilder.build(), name, description, feature));
}
}
}
}

private void validateFeatureMethod(MethodInfo method, Feature feature) {
if (Modifier.isStatic(method.flags())) {
throw new IllegalStateException("MCP feature method must not be static: " + method);
}
if (Modifier.isPrivate(method.flags())) {
throw new IllegalStateException("MCP feature method must not be private: " + method);
}
switch (feature) {
case PROMPT -> validatePromptMethod(method);
case TOOL -> validateToolMethod(method);
default -> throw new IllegalArgumentException("Unsupported feature: " + feature);
}
}

private static final Set<org.jboss.jandex.Type> PROMPT_TYPES = Set.of(ClassType.create(DotNames.PROMPT_RESPONSE),
ClassType.create(DotNames.PROMPT_MESSAGE));

private void validatePromptMethod(MethodInfo method) {
org.jboss.jandex.Type type = method.returnType();
if (DotNames.UNI.equals(type.name())) {
type = type.asParameterizedType().arguments().get(0);
}
if (DotNames.LIST.equals(type.name())) {
type = type.asParameterizedType().arguments().get(0);
}
if (!PROMPT_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported prompt method return type: " + method.returnType());
}
}

private static final Set<org.jboss.jandex.Type> TOOL_TYPES = Set.of(ClassType.create(DotNames.TOOL_RESPONSE),
ClassType.create(DotNames.CONTENT), ClassType.create(DotNames.TEXT_CONTENT),
ClassType.create(DotNames.IMAGE_CONTENT), ClassType.create(DotNames.RESOURCE_CONTENT));

private void validateToolMethod(MethodInfo method) {
org.jboss.jandex.Type type = method.returnType();
if (DotNames.UNI.equals(type.name())) {
type = type.asParameterizedType().arguments().get(0);
}
if (DotNames.LIST.equals(type.name())) {
type = type.asParameterizedType().arguments().get(0);
}
if (!TOOL_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported Tool method return type: " + method.returnType());
}
}

private boolean hasFeatureMethod(BeanInfo bean) {
ClassInfo beanClass = bean.getTarget().get().asClass();
return beanClass.hasAnnotation(DotNames.PROMPT) || beanClass.hasAnnotation(DotNames.TOOL);
@@ -195,7 +246,6 @@ private void processFeatureMethod(MethodCreator method, FeatureMethodBuildItem f
required = requiredValue.asBoolean();
}
}
// TODO validate types
ResultHandle type = Types.getTypeHandle(method, pi.type());
ResultHandle provider;
if (pi.type().name().equals(DotNames.MCP_CONNECTION)) {
@@ -257,14 +307,27 @@ private ResultHandle promptMapper(BytecodeCreator bytecode, org.jboss.jandex.Typ

private ResultHandle toolMapper(BytecodeCreator bytecode, org.jboss.jandex.Type returnType) {
if (returnType.name().equals(DotNames.TOOL_RESPONSE)) {
// ToolResponse
return resultMapper(bytecode, "TO_UNI");
} else if (isContent(returnType.name())) {
return resultMapper(bytecode, "TOOL_CONTENT");
} else if (returnType.name().equals(DotNames.LIST)) {
return resultMapper(bytecode, "TOOL_LIST_CONTENT");
} else if (returnType.name().equals(DotNames.UNI)) {
// Uni<ToolResponse>
return resultMapper(bytecode, "IDENTITY");
} else {
throw new IllegalArgumentException("Unsupported return type");
org.jboss.jandex.Type typeArg = returnType.asParameterizedType().arguments().get(0);
if (typeArg.name().equals(DotNames.TOOL_RESPONSE)) {
return resultMapper(bytecode, "IDENTITY");
} else if (isContent(typeArg.name())) {
return resultMapper(bytecode, "TOOL_UNI_CONTENT");
} else if (typeArg.name().equals(DotNames.LIST)) {
return resultMapper(bytecode, "TOOL_UNI_LIST_CONTENT");
}
}
throw new IllegalArgumentException("Unsupported return type");
}

private boolean isContent(DotName typeName) {
return DotNames.CONTENT.equals(typeName) || DotNames.TEXT_CONTENT.equals(typeName)
|| DotNames.IMAGE_CONTENT.equals(typeName) || DotNames.RESOURCE_CONTENT.equals(typeName);
}

private ResultHandle resultMapper(BytecodeCreator bytecode, String contantName) {
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.quarkiverse.mcp.server.test;

import io.quarkiverse.mcp.server.McpConnection;
import io.quarkiverse.mcp.server.McpConnection.Status;
import io.quarkiverse.mcp.server.RequestId;
import io.quarkus.arc.Arc;
import io.quarkus.runtime.BlockingOperationControl;
import io.smallrye.common.vertx.VertxContext;

public class Checks {

public static void checkRequestContext() {
if (!Arc.container().requestContext().isActive()) {
throw new IllegalStateException("Request context not active");
}
}

public static void checkExecutionModel(boolean blocking) {
if (BlockingOperationControl.isBlockingAllowed() && !blocking) {
throw new IllegalStateException("Invalid execution model");
}
}

public static void checkDuplicatedContext() {
if (!VertxContext.isOnDuplicatedContext()) {
throw new IllegalStateException("Not on duplicated context");
}
}

public static void checkRequestId(RequestId id) {
if (id == null || id.asInteger() < 1) {
throw new IllegalStateException("Invalid request id: " + id);
}
}

public static void checkMcpConnection(McpConnection connection) {
if (connection == null || connection.status() != Status.IN_OPERATION) {
throw new IllegalStateException("Invalid connection: " + connection);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.quarkiverse.mcp.server.test;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.context.RequestScoped;

@ApplicationScoped
@RequestScoped
public class FooService {

public String ping(String name, int repeat, Options options) {
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package io.quarkiverse.mcp.server.test.prompts;

import static io.quarkiverse.mcp.server.test.Checks.checkDuplicatedContext;
import static io.quarkiverse.mcp.server.test.Checks.checkExecutionModel;
import static io.quarkiverse.mcp.server.test.Checks.checkMcpConnection;
import static io.quarkiverse.mcp.server.test.Checks.checkRequestContext;
import static io.quarkiverse.mcp.server.test.Checks.checkRequestId;

import java.util.List;

import jakarta.inject.Inject;

import io.quarkiverse.mcp.server.McpConnection;
import io.quarkiverse.mcp.server.McpConnection.Status;
import io.quarkiverse.mcp.server.Prompt;
import io.quarkiverse.mcp.server.PromptArg;
import io.quarkiverse.mcp.server.PromptMessage;
@@ -14,9 +19,6 @@
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.test.FooService;
import io.quarkiverse.mcp.server.test.Options;
import io.quarkus.arc.Arc;
import io.quarkus.runtime.BlockingOperationControl;
import io.smallrye.common.vertx.VertxContext;
import io.smallrye.mutiny.Uni;

public class MyPrompts {
@@ -75,34 +77,4 @@ Uni<PromptResponse> uni_response(String val) {
.item(new PromptResponse("My description", List.of(PromptMessage.user(new TextContent(val.toUpperCase())))));
}

private void checkRequestContext() {
if (!Arc.container().requestContext().isActive()) {
throw new IllegalStateException("Request context not active");
}
}

private void checkExecutionModel(boolean blocking) {
if (BlockingOperationControl.isBlockingAllowed() && !blocking) {
throw new IllegalStateException("Invalid execution model");
}
}

private void checkDuplicatedContext() {
if (!VertxContext.isOnDuplicatedContext()) {
throw new IllegalStateException("Not on duplicated context");
}
}

private void checkRequestId(RequestId id) {
if (id == null || id.asInteger() < 1) {
throw new IllegalStateException("Invalid request id: " + id);
}
}

private void checkMcpConnection(McpConnection connection) {
if (connection == null || connection.status() != Status.IN_OPERATION) {
throw new IllegalStateException("Invalid connection: " + connection);
}
}

}
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.test.Checks;
import io.quarkiverse.mcp.server.test.FooService;
import io.quarkiverse.mcp.server.test.McpClient;
import io.quarkiverse.mcp.server.test.McpServerTest;
@@ -24,7 +25,8 @@ public class PromptsTest extends McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class, FooService.class, Options.class, MyPrompts.class));
.withApplicationRoot(
root -> root.addClasses(McpClient.class, FooService.class, Options.class, Checks.class, MyPrompts.class));

@Test
public void testPrompts() throws URISyntaxException {
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package io.quarkiverse.mcp.server.test.tools;

import static io.quarkiverse.mcp.server.test.Checks.checkDuplicatedContext;
import static io.quarkiverse.mcp.server.test.Checks.checkExecutionModel;
import static io.quarkiverse.mcp.server.test.Checks.checkRequestContext;

import jakarta.inject.Inject;

import io.quarkiverse.mcp.server.Content;
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.Tool;
import io.quarkiverse.mcp.server.ToolArg;
import io.quarkiverse.mcp.server.ToolResponse;
import io.quarkiverse.mcp.server.test.FooService;
import io.quarkiverse.mcp.server.test.Options;
import io.smallrye.mutiny.Uni;

public class MyTools {

@@ -16,13 +22,35 @@ public class MyTools {

@Tool
ToolResponse alpha(@ToolArg(description = "Define the price...") int price) {
checkExecutionModel(true);
checkDuplicatedContext();
checkRequestContext();
return ToolResponse.success(
new TextContent(fooService.ping(price + "", 1, new Options(true))));
}

@Tool
ToolResponse uni_alpha(@ToolArg(name = "uni_price") double price) {
return ToolResponse.success(
new TextContent(fooService.ping(price + "", 1, new Options(true))));
Uni<ToolResponse> uni_alpha(@ToolArg(name = "uni_price") double price) {
checkExecutionModel(false);
checkDuplicatedContext();
checkRequestContext();
return Uni.createFrom().item(ToolResponse.success(
new TextContent(fooService.ping(price + "", 1, new Options(true)))));
}

@Tool
TextContent bravo(int price) {
checkExecutionModel(true);
checkDuplicatedContext();
checkRequestContext();
return new TextContent(fooService.ping(price + "", 1, new Options(true)));
}

@Tool
Uni<Content> uni_bravo(int price) {
checkExecutionModel(false);
checkDuplicatedContext();
checkRequestContext();
return Uni.createFrom().item(new TextContent(fooService.ping(price + "", 1, new Options(true))));
}
}
Loading

0 comments on commit 766974c

Please sign in to comment.