Skip to content

Commit

Permalink
bug fix - tool parameters missing
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Jan 23, 2024
1 parent d0895bb commit 8f6f2ef
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ private void runReAct(
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
tools.get(action).run(llmToolTmpParameters, toolListener); // run tool
} else {
toolParams.putAll(toolSpecMap.get(action).getParameters());
tools.get(action).run(toolParams, toolListener); // run tool
}
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,20 @@ public void testRunWithIncludeOutputNotSet() {
@Test
public void testRunWithIncludeOutputSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(false).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build();
MLToolSpec firstToolSpec = MLToolSpec
.builder()
.name(FIRST_TOOL)
.type(FIRST_TOOL)
.includeOutputInAgentResponse(false)
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
.build();
MLToolSpec secondToolSpec = MLToolSpec
.builder()
.name(SECOND_TOOL)
.type(SECOND_TOOL)
.includeOutputInAgentResponse(true)
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
.build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
Expand Down Expand Up @@ -471,10 +483,41 @@ public void testToolThrowException() {
assertNotNull(modelTensorOutput);
}

@Test
public void testToolParameters() {
// Mock tool validation to return false.
when(firstTool.validate(any())).thenReturn(true);

// Create an MLAgent with a tool including two parameters.
MLAgent mlAgent = createMLAgentWithTools();

// Create parameters for the agent.
Map<String, String> params = createAgentParamsWithAction(FIRST_TOOL, "someInput");

// Run the MLChatAgentRunner.
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Verify that the tool's run method was called.
verify(firstTool).run(any(), any());
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(3, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
assertNotNull(modelTensorOutput);
}

// Helper methods to create MLAgent and parameters
private MLAgent createMLAgentWithTools() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
MLToolSpec firstToolSpec = MLToolSpec
.builder()
.name(FIRST_TOOL)
.type(FIRST_TOOL)
.parameters(ImmutableMap.of("key1", "value1", "key2", "value2"))
.build();
return MLAgent.builder().name("TestAgent").tools(Arrays.asList(firstToolSpec)).memory(mlMemorySpec).llm(llmSpec).build();
}

Expand Down

0 comments on commit 8f6f2ef

Please sign in to comment.