Skip to content

Commit

Permalink
Detect duplicate prompt/tool names
Browse files Browse the repository at this point in the history
  • Loading branch information
mkouba committed Dec 16, 2024
1 parent d33d901 commit 6d383cc
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkiverse.mcp.server.deployment;

import java.util.Objects;

import org.jboss.jandex.MethodInfo;

import io.quarkus.arc.processor.BeanInfo;
Expand All @@ -17,12 +19,12 @@ final class FeatureMethodBuildItem extends MultiBuildItem {

FeatureMethodBuildItem(BeanInfo bean, MethodInfo method, InvokerInfo invoker, String name, String description,
Feature feature) {
this.bean = bean;
this.method = method;
this.invoker = invoker;
this.name = name;
this.bean = Objects.requireNonNull(bean);
this.method = Objects.requireNonNull(method);
this.invoker = Objects.requireNonNull(invoker);
this.name = Objects.requireNonNull(name);
this.description = description;
this.feature = feature;
this.feature = Objects.requireNonNull(feature);
}

BeanInfo getBean() {
Expand Down Expand Up @@ -57,6 +59,12 @@ boolean isPrompt() {
return feature == Feature.PROMPT;
}

@Override
public String toString() {
return "FeatureMethodBuildItem [name=" + name + ", method=" + method.declaringClass() + "#" + method.name()
+ "(), feature=" + feature + "]";
}

enum Feature {
PROMPT,
TOOL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import jakarta.enterprise.invoke.Invoker;
import jakarta.inject.Singleton;
Expand Down Expand Up @@ -38,6 +42,7 @@
import io.quarkus.arc.deployment.InvokerFactoryBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem;
import io.quarkus.arc.deployment.ValidationPhaseBuildItem.ValidationErrorBuildItem;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.InvokerBuilder;
Expand Down Expand Up @@ -106,7 +111,9 @@ void registerEndpoints(McpBuildTimeConfig config, HttpRootPathBuildItem rootPath

@BuildStep
void collectFeatureMethods(BeanDiscoveryFinishedBuildItem beanDiscovery, InvokerFactoryBuildItem invokerFactory,
BuildProducer<FeatureMethodBuildItem> features) {
BuildProducer<FeatureMethodBuildItem> features, BuildProducer<ValidationErrorBuildItem> errors) {
Map<Feature, List<FeatureMethodBuildItem>> found = new HashMap<>();

for (BeanInfo bean : beanDiscovery.beanStream().classBeans().filter(this::hasFeatureMethod)) {
ClassInfo beanClass = bean.getTarget().get().asClass();
for (MethodInfo method : beanClass.methods()) {
Expand All @@ -123,8 +130,34 @@ void collectFeatureMethods(BeanDiscoveryFinishedBuildItem beanDiscovery, Invoker
String description = descValue != null ? descValue.asString() : "";
InvokerBuilder invokerBuilder = invokerFactory.createInvoker(bean, method)
.withInstanceLookup();
features.produce(
new FeatureMethodBuildItem(bean, method, invokerBuilder.build(), name, description, feature));
FeatureMethodBuildItem fm = new FeatureMethodBuildItem(bean, method, invokerBuilder.build(), name,
description, feature);
features.produce(fm);
found.compute(feature, (f, list) -> {
if (list == null) {
list = new ArrayList<>();
}
list.add(fm);
return list;
});
}
}
}

// Check duplicate names
for (List<FeatureMethodBuildItem> featureMethods : found.values()) {
Map<String, List<FeatureMethodBuildItem>> byName = featureMethods.stream()
.collect(Collectors.toMap(FeatureMethodBuildItem::getName, List::of, (v1, v2) -> {
List<FeatureMethodBuildItem> list = new ArrayList<>();
list.addAll(v1);
list.addAll(v2);
return list;
}));
for (List<FeatureMethodBuildItem> list : byName.values()) {
if (list.size() > 1) {
String message = "Duplicate feature name found:\n\t%s"
.formatted(list.stream().map(Object::toString).collect(Collectors.joining("\n\t")));
errors.produce(new ValidationErrorBuildItem(new IllegalStateException(message)));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.quarkiverse.mcp.server.test.validation;

import static org.junit.jupiter.api.Assertions.fail;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.Prompt;
import io.quarkiverse.mcp.server.PromptResponse;
import io.quarkus.test.QuarkusUnitTest;

public class DuplicatePromptNameTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(InvalidPrompts.class);
})
.setExpectedException(IllegalStateException.class, true);

@Test
public void test() {
fail();
}

public static class InvalidPrompts {

@Prompt
PromptResponse foo() {
return null;
}

@Prompt(name = "foo")
PromptResponse foos() {
return null;
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.quarkiverse.mcp.server.test.validation;

import static org.junit.jupiter.api.Assertions.fail;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.Tool;
import io.quarkiverse.mcp.server.ToolResponse;
import io.quarkus.test.QuarkusUnitTest;

public class DuplicateToolNameTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(InvalidTools.class);
})
.setExpectedException(IllegalStateException.class, true);

@Test
public void test() {
fail();
}

public static class InvalidTools {

@Tool
ToolResponse foo() {
return null;
}

@Tool(name = "foo")
ToolResponse foos() {
return null;
}

}

}
7 changes: 3 additions & 4 deletions runtime/src/main/java/io/quarkiverse/mcp/server/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
* Annotates a business method of a CDI bean as an exposed prompt template.
* <p>
* The result of a "prompt get" operation is always represented as a {@link PromptResponse}. However, the annotated method can
* also return
* other types that are converted according to the following rules.
* also return other types that are converted according to the following rules.
* <ul>
* <li>If the method returns a {@link PromptMessage} then the reponse has no description and contains the single
* message object.</li>
Expand Down Expand Up @@ -45,12 +44,12 @@
String ELEMENT_NAME = "<<element name>>";

/**
*
* Each prompt must have a unique name. By default, the name is derived from the name of the annotated method.
*/
String name() default ELEMENT_NAME;

/**
*
* An optional description.
*/
String description() default "";

Expand Down
4 changes: 2 additions & 2 deletions runtime/src/main/java/io/quarkiverse/mcp/server/Tool.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
String ELEMENT_NAME = "<<element name>>";

/**
*
* Each tool must have a unique name. By default, the name is derived from the name of the annotated method.
*/
String name() default ELEMENT_NAME;

/**
*
* An optional description.
*/
String description() default "";

Expand Down

0 comments on commit 6d383cc

Please sign in to comment.