Skip to content

Commit

Permalink
add action input as parameters for tool execution in conversational a…
Browse files Browse the repository at this point in the history
…gent (#3200)

* add llm generated action input as parameters for tool execution in conversational agent

Signed-off-by: Jing Zhang <[email protected]>

* add UT for null action input

Signed-off-by: Jing Zhang <[email protected]>

* change llm_generated_action_input to llm_generated_input

Signed-off-by: Jing Zhang <[email protected]>

---------

Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es authored Dec 31, 2024
1 parent df1b1ef commit c850eef
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class AgentUtils {
public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix";
public static final String DISABLE_TRACE = "disable_trace";
public static final String VERBOSE = "verbose";
public static final String LLM_GEN_INPUT = "llm_generated_input";

public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
Map<String, String> examplesMap = new HashMap<>();
Expand Down Expand Up @@ -472,6 +473,11 @@ public static Map<String, String> constructToolParams(
if (toolSpecConfigMap != null) {
toolParams.putAll(toolSpecConfigMap);
}
toolParams.put(LLM_GEN_INPUT, actionInput);
if (isJson(actionInput)) {
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
toolParams.putAll(params);
}
if (tools.get(action).useOriginalInput()) {
toolParams.put("input", question);
lastActionInput.set(question);
Expand All @@ -486,10 +492,6 @@ public static Map<String, String> constructToolParams(
}
} else {
toolParams.put("input", actionInput);
if (isJson(actionInput)) {
Map<String, String> params = getParameterMap(gson.fromJson(actionInput, Map.class));
toolParams.putAll(params);
}
}
return toolParams;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_GEN_INPUT;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
Expand Down Expand Up @@ -603,11 +604,24 @@ public void testConstructToolParams() {
String question = "dummy question";
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(4, toolParams.size());
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(actionInput, toolParams.get("input"));
Assert.assertEquals("abc", toolParams.get("detectorName"));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
});
}

@Test
public void testConstructToolParamsNullActionInput() {
String question = "dummy question";
String actionInput = null;
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(3, toolParams.size());
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertNull(toolParams.get(LLM_GEN_INPUT));
Assert.assertNull(toolParams.get("input"));
});
}

Expand All @@ -617,12 +631,65 @@ public void testConstructToolParams_UseOriginalInput() {
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
when(tool1.useOriginalInput()).thenReturn(true);
verifyConstructToolParams(question, actionInput, (toolParams) -> {
Assert.assertEquals(2, toolParams.size());
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(question, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
Assert.assertEquals("sample-data", toolParams.get("indices"));
Assert.assertEquals("abc", toolParams.get("detectorName"));
});
}

@Test
public void testConstructToolParams_PlaceholderConfigInput() {
String question = "dummy question";
String actionInput = "action input";
String preConfigInputStr = "Config Input: ";
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"tool1",
MLToolSpec
.builder()
.type("tool1")
.parameters(Map.of("key1", "value1"))
.configMap(Map.of("input", preConfigInputStr + "${parameters.llm_generated_input}"))
.build()
);
AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "tool1";
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
Assert.assertEquals(3, toolParams.size());
Assert.assertEquals(preConfigInputStr + actionInput, toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
}

@Test
public void testConstructToolParams_PlaceholderConfigInputJson() {
String question = "dummy question";
String actionInput = "{'detectorName': 'abc', 'indices': 'sample-data' }";
String preConfigInputStr = "Config Input: ";
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
.of(
"tool1",
MLToolSpec
.builder()
.type("tool1")
.parameters(Map.of("key1", "value1"))
.configMap(Map.of("input", preConfigInputStr + "${parameters.detectorName}"))
.build()
);
AtomicReference<String> lastActionInput = new AtomicReference<>();
String action = "tool1";
Map<String, String> toolParams = AgentUtils.constructToolParams(tools, toolSpecMap, question, lastActionInput, action, actionInput);
Assert.assertEquals(5, toolParams.size());
Assert.assertEquals(preConfigInputStr + "abc", toolParams.get("input"));
Assert.assertEquals("value1", toolParams.get("key1"));
Assert.assertEquals(actionInput, toolParams.get(LLM_GEN_INPUT));
}

private void verifyConstructToolParams(String question, String actionInput, Consumer<Map<String, String>> verify) {
Map<String, Tool> tools = Map.of("tool1", tool1);
Map<String, MLToolSpec> toolSpecMap = Map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ public void testToolParameters() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(14, ((Map) argumentCaptor.getValue()).size());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
Expand Down Expand Up @@ -734,7 +734,7 @@ public void testToolUseOriginalInput() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
Expand Down Expand Up @@ -763,7 +763,7 @@ public void testToolConfig() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
// The value of input should be "config_value".
assertEquals("config_value", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Expand Down Expand Up @@ -793,7 +793,7 @@ public void testToolConfigWithInputPlaceholder() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
// The value of input should be replaced with the value associated with the key "key2" of the first tool.
assertEquals("value2", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Expand Down

0 comments on commit c850eef

Please sign in to comment.