From 8f6f2efc595072cc92fb0ad19b50d9d529a4dbca Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 23 Jan 2024 13:25:24 -0800 Subject: [PATCH] bug fix - tool parameters missing Signed-off-by: Jing Zhang --- .../algorithms/agent/MLChatAgentRunner.java | 1 + .../agent/MLChatAgentRunnerTest.java | 49 +++++++++++++++++-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index d2c513eabd..a2a9107c73 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -491,6 +491,7 @@ private void runReAct( llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); tools.get(action).run(llmToolTmpParameters, toolListener); // run tool } else { + toolParams.putAll(toolSpecMap.get(action).getParameters()); tools.get(action).run(toolParams, toolListener); // run tool } } catch (Exception e) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 2247bf00c5..8af4370a62 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -212,8 +212,20 @@ public void testRunWithIncludeOutputNotSet() { @Test public void testRunWithIncludeOutputSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(false).build(); - MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .includeOutputInAgentResponse(false) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .build(); + MLToolSpec secondToolSpec = MLToolSpec + .builder() + .name(SECOND_TOOL) + .type(SECOND_TOOL) + .includeOutputInAgentResponse(true) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .build(); final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") @@ -471,10 +483,41 @@ public void testToolThrowException() { assertNotNull(modelTensorOutput); } + @Test + public void testToolParameters() { + // Mock tool validation to return false. + when(firstTool.validate(any())).thenReturn(true); + + // Create an MLAgent with a tool including two parameters. + MLAgent mlAgent = createMLAgentWithTools(); + + // Create parameters for the agent. + Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + + // Run the MLChatAgentRunner. + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that the tool's run method was called. + verify(firstTool).run(any(), any()); + // Verify the size of parameters passed in the tool run method. + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(firstTool).run((Map) argumentCaptor.capture(), any()); + assertEquals(3, ((Map) argumentCaptor.getValue()).size()); + + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + assertNotNull(modelTensorOutput); + } + // Helper methods to create MLAgent and parameters private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .build(); return MLAgent.builder().name("TestAgent").tools(Arrays.asList(firstToolSpec)).memory(mlMemorySpec).llm(llmSpec).build(); }