forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Hailong Cui <[email protected]>
- Loading branch information
1 parent
c8e2dc6
commit 889d9e2
Showing
1 changed file
with
223 additions
and
0 deletions.
There are no files selected for viewing
223 changes: 223 additions & 0 deletions
223
...rithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
package org.opensearch.ml.engine.algorithms.agent; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Captor; | ||
import org.mockito.Mock; | ||
import org.mockito.Mockito; | ||
import org.mockito.MockitoAnnotations; | ||
import org.mockito.stubbing.Answer; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.StepListener; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.agent.LLMSpec; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.agent.MLMemorySpec; | ||
import org.opensearch.ml.common.agent.MLToolSpec; | ||
import org.opensearch.ml.common.conversation.Interaction; | ||
import org.opensearch.ml.common.output.model.ModelTensor; | ||
import org.opensearch.ml.common.output.model.ModelTensorOutput; | ||
import org.opensearch.ml.common.output.model.ModelTensors; | ||
import org.opensearch.ml.common.spi.memory.Memory; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.transport.MLTaskResponse; | ||
import org.opensearch.ml.engine.memory.ConversationIndexMemory; | ||
import org.opensearch.ml.engine.memory.MLMemoryManager; | ||
import software.amazon.awssdk.utils.ImmutableMap; | ||
|
||
import java.util.Arrays; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.IntStream; | ||
|
||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.ArgumentMatchers.anyString; | ||
import static org.mockito.ArgumentMatchers.isA; | ||
import static org.mockito.Mockito.doAnswer; | ||
import static org.mockito.Mockito.when; | ||
|
||
public class MLChatAgentRunnerTest { | ||
public static final String FIRST_TOOL = "firstTool"; | ||
public static final String SECOND_TOOL = "secondTool"; | ||
|
||
@Mock | ||
private Client client; | ||
|
||
private Settings settings; | ||
|
||
@Mock | ||
private ClusterService clusterService; | ||
|
||
@Mock | ||
private NamedXContentRegistry xContentRegistry; | ||
|
||
private Map<String, Tool.Factory> toolFactories; | ||
|
||
@Mock | ||
private Map<String, Memory.Factory> memoryMap; | ||
|
||
private MLChatAgentRunner mlChatAgentRunner; | ||
|
||
@Mock | ||
private Tool.Factory firstToolFactory; | ||
|
||
@Mock | ||
private Tool.Factory secondToolFactory; | ||
@Mock | ||
private Tool firstTool; | ||
|
||
@Mock | ||
private Tool secondTool; | ||
|
||
@Mock | ||
private ActionListener<Object> agentActionListener; | ||
|
||
@Captor | ||
private ArgumentCaptor<Object> objectCaptor; | ||
|
||
@Captor | ||
private ArgumentCaptor<StepListener<Object>> nextStepListenerCaptor; | ||
|
||
private MLMemorySpec mlMemorySpec; | ||
@Mock | ||
private ConversationIndexMemory conversationIndexMemory; | ||
@Mock | ||
private MLMemoryManager mlMemoryManager; | ||
|
||
@Mock | ||
private ConversationIndexMemory.Factory memoryFactory; | ||
@Captor | ||
private ArgumentCaptor<ActionListener<ConversationIndexMemory>> memoryFactoryCapture; | ||
@Captor | ||
private ArgumentCaptor<ActionListener<List<Interaction>>> memoryInteractionCapture; | ||
|
||
@Before | ||
@SuppressWarnings("unchecked") | ||
public void setup() { | ||
MockitoAnnotations.openMocks(this); | ||
settings = Settings.builder().build(); | ||
toolFactories = ImmutableMap.of(FIRST_TOOL, firstToolFactory, SECOND_TOOL, secondToolFactory); | ||
|
||
// memory | ||
mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); | ||
when(memoryMap.get(anyString())).thenReturn(memoryFactory); | ||
doAnswer(invocation -> { | ||
ActionListener<List<Interaction>> listener = invocation.getArgument(0); | ||
listener.onResponse(generateInteractions(2)); | ||
return null; | ||
}).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture()); | ||
when(conversationIndexMemory.getConversationId()).thenReturn("conversation_id"); | ||
when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); | ||
ArgumentCaptor<ActionListener<Boolean>> argumentCaptor = ArgumentCaptor.forClass(ActionListener.class); | ||
doAnswer(invocation -> { | ||
ActionListener<ConversationIndexMemory> listener = invocation.getArgument(3); | ||
listener.onResponse(conversationIndexMemory); | ||
return null; | ||
}).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture()); | ||
|
||
mlChatAgentRunner = new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap); | ||
when(firstToolFactory.create(Mockito.anyMap())).thenReturn(firstTool); | ||
when(secondToolFactory.create(Mockito.anyMap())).thenReturn(secondTool); | ||
when(firstTool.getName()).thenReturn(FIRST_TOOL); | ||
when(secondTool.getName()).thenReturn(SECOND_TOOL); | ||
when(firstTool.validate(Mockito.anyMap())).thenReturn(true); | ||
when(secondTool.validate(Mockito.anyMap())).thenReturn(true); | ||
Mockito | ||
.doAnswer(generateToolResponse("First tool response")) | ||
.when(firstTool) | ||
.run(Mockito.anyMap(), nextStepListenerCaptor.capture()); | ||
Mockito | ||
.doAnswer(generateToolResponse("Second tool response")) | ||
.when(secondTool) | ||
.run(Mockito.anyMap(), nextStepListenerCaptor.capture()); | ||
|
||
Mockito | ||
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "thought 1", "action", FIRST_TOOL))) | ||
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "thought 2", "action", SECOND_TOOL))) | ||
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "thought 3", "final_answer", "This is the final answer"))) | ||
.when(client) | ||
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); | ||
} | ||
|
||
@Test | ||
public void testRunWithIncludeOutputNotSet() { | ||
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); | ||
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); | ||
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); | ||
final MLAgent mlAgent = MLAgent | ||
.builder() | ||
.name("TestAgent") | ||
.llm(llmSpec) | ||
.memory(mlMemorySpec) | ||
.tools(Arrays.asList(firstToolSpec, secondToolSpec)) | ||
.build(); | ||
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); | ||
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); | ||
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); | ||
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors(); | ||
Assert.assertEquals(1, agentOutput.size()); | ||
// Respond with last tool output | ||
Assert.assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); | ||
} | ||
|
||
@Test | ||
public void testRunWithIncludeOutputSet() { | ||
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); | ||
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(false).build(); | ||
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build(); | ||
final MLAgent mlAgent = MLAgent | ||
.builder() | ||
.name("TestAgent") | ||
.memory(mlMemorySpec) | ||
.llm(llmSpec) | ||
.tools(Arrays.asList(firstToolSpec, secondToolSpec)) | ||
.build(); | ||
HashMap<String, String> params = new HashMap<>(); | ||
mlChatAgentRunner.run(mlAgent, params, agentActionListener); | ||
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); | ||
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); | ||
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors(); | ||
Assert.assertEquals(1, agentOutput.size()); | ||
// Respond with last tool output | ||
Assert.assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); | ||
Map<String, List<String>> additionalInfos = (Map<String, List<String>>)agentOutput.get(0).getDataAsMap().get("additional_info"); | ||
Assert.assertEquals("Second tool response", additionalInfos.get(String.format("%s.output", SECOND_TOOL)).get(0)); | ||
} | ||
|
||
private List<Interaction> generateInteractions(int size) { | ||
return IntStream | ||
.range(1, size + 1) | ||
.mapToObj(i -> Interaction.builder().id("interaction-" + i).input("input-" + i).response("response-" + i).build()) | ||
.collect(Collectors.toList()); | ||
} | ||
|
||
private Answer getLLMAnswer(Map<String, String> llmResponse) { | ||
return invocation -> { | ||
ActionListener<Object> listener = invocation.getArgument(2); | ||
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(llmResponse).build(); | ||
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); | ||
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); | ||
MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); | ||
listener.onResponse(mlTaskResponse); | ||
return null; | ||
}; | ||
} | ||
|
||
private Answer generateToolResponse(String response) { | ||
return invocation -> { | ||
ActionListener<Object> listener = invocation.getArgument(1); | ||
listener.onResponse(response); | ||
return null; | ||
}; | ||
} | ||
|
||
} |