Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Nov 28, 2023
1 parent c8e2dc6 commit 889d9e2
Showing 1 changed file with 223 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package org.opensearch.ml.engine.algorithms.agent;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.stubbing.Answer;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.StepListener;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.MLMemoryManager;
import software.amazon.awssdk.utils.ImmutableMap;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;

public class MLChatAgentRunnerTest {
public static final String FIRST_TOOL = "firstTool";
public static final String SECOND_TOOL = "secondTool";

@Mock
private Client client;

private Settings settings;

@Mock
private ClusterService clusterService;

@Mock
private NamedXContentRegistry xContentRegistry;

private Map<String, Tool.Factory> toolFactories;

@Mock
private Map<String, Memory.Factory> memoryMap;

private MLChatAgentRunner mlChatAgentRunner;

@Mock
private Tool.Factory firstToolFactory;

@Mock
private Tool.Factory secondToolFactory;
@Mock
private Tool firstTool;

@Mock
private Tool secondTool;

@Mock
private ActionListener<Object> agentActionListener;

@Captor
private ArgumentCaptor<Object> objectCaptor;

@Captor
private ArgumentCaptor<StepListener<Object>> nextStepListenerCaptor;

private MLMemorySpec mlMemorySpec;
@Mock
private ConversationIndexMemory conversationIndexMemory;
@Mock
private MLMemoryManager mlMemoryManager;

@Mock
private ConversationIndexMemory.Factory memoryFactory;
@Captor
private ArgumentCaptor<ActionListener<ConversationIndexMemory>> memoryFactoryCapture;
@Captor
private ArgumentCaptor<ActionListener<List<Interaction>>> memoryInteractionCapture;

@Before
@SuppressWarnings("unchecked")
public void setup() {
MockitoAnnotations.openMocks(this);
settings = Settings.builder().build();
toolFactories = ImmutableMap.of(FIRST_TOOL, firstToolFactory, SECOND_TOOL, secondToolFactory);

// memory
mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10);
when(memoryMap.get(anyString())).thenReturn(memoryFactory);
doAnswer(invocation -> {
ActionListener<List<Interaction>> listener = invocation.getArgument(0);
listener.onResponse(generateInteractions(2));
return null;
}).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture());
when(conversationIndexMemory.getConversationId()).thenReturn("conversation_id");
when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager);
ArgumentCaptor<ActionListener<Boolean>> argumentCaptor = ArgumentCaptor.forClass(ActionListener.class);
doAnswer(invocation -> {
ActionListener<ConversationIndexMemory> listener = invocation.getArgument(3);
listener.onResponse(conversationIndexMemory);
return null;
}).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture());

mlChatAgentRunner = new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryMap);
when(firstToolFactory.create(Mockito.anyMap())).thenReturn(firstTool);
when(secondToolFactory.create(Mockito.anyMap())).thenReturn(secondTool);
when(firstTool.getName()).thenReturn(FIRST_TOOL);
when(secondTool.getName()).thenReturn(SECOND_TOOL);
when(firstTool.validate(Mockito.anyMap())).thenReturn(true);
when(secondTool.validate(Mockito.anyMap())).thenReturn(true);
Mockito
.doAnswer(generateToolResponse("First tool response"))
.when(firstTool)
.run(Mockito.anyMap(), nextStepListenerCaptor.capture());
Mockito
.doAnswer(generateToolResponse("Second tool response"))
.when(secondTool)
.run(Mockito.anyMap(), nextStepListenerCaptor.capture());

Mockito
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "thought 1", "action", FIRST_TOOL)))
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "thought 2", "action", SECOND_TOOL)))
.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "thought 3", "final_answer", "This is the final answer")))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));
}

@Test
public void testRunWithIncludeOutputNotSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.llm(llmSpec)
.memory(mlMemorySpec)
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
Assert.assertEquals(1, agentOutput.size());
// Respond with last tool output
Assert.assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
}

@Test
public void testRunWithIncludeOutputSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(false).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.memory(mlMemorySpec)
.llm(llmSpec)
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();
HashMap<String, String> params = new HashMap<>();
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
Assert.assertEquals(1, agentOutput.size());
// Respond with last tool output
Assert.assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
Map<String, List<String>> additionalInfos = (Map<String, List<String>>)agentOutput.get(0).getDataAsMap().get("additional_info");
Assert.assertEquals("Second tool response", additionalInfos.get(String.format("%s.output", SECOND_TOOL)).get(0));
}

private List<Interaction> generateInteractions(int size) {
return IntStream
.range(1, size + 1)
.mapToObj(i -> Interaction.builder().id("interaction-" + i).input("input-" + i).response("response-" + i).build())
.collect(Collectors.toList());
}

private Answer getLLMAnswer(Map<String, String> llmResponse) {
return invocation -> {
ActionListener<Object> listener = invocation.getArgument(2);
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(llmResponse).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build();
listener.onResponse(mlTaskResponse);
return null;
};
}

private Answer generateToolResponse(String response) {
return invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
};
}

}

0 comments on commit 889d9e2

Please sign in to comment.