Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Nov 30, 2023
1 parent 59d1488 commit 35759a2
Showing 1 changed file with 31 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.MLMemoryManager;

import software.amazon.awssdk.utils.ImmutableMap;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;

public class MLChatAgentRunnerTest {
public static final String FIRST_TOOL = "firstTool";
Expand Down Expand Up @@ -194,6 +193,36 @@ public void testRunWithIncludeOutputSet() {
Assert.assertEquals("Second tool response", additionalInfos.get(String.format("%s.output", SECOND_TOOL)).get(0));
}

@Test
public void testChatHistoryExcludeOngoingQuestion() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(false).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.memory(mlMemorySpec)
.llm(llmSpec)
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();

doAnswer(invocation -> {
ActionListener<List<Interaction>> listener = invocation.getArgument(0);
List<Interaction> interactionList = generateInteractions(2);
Interaction inProgressInteraction = Interaction.builder().id("interaction-99").input("input-99").response(null).build();
interactionList.add(inProgressInteraction);
listener.onResponse(interactionList);
return null;
}).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture());

HashMap<String, String> params = new HashMap<>();
mlChatAgentRunner.run(mlAgent, params, agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY);
Assert.assertFalse(chatHistory.contains("input-99"));

}

private List<Interaction> generateInteractions(int size) {
return IntStream
.range(1, size + 1)
Expand Down

0 comments on commit 35759a2

Please sign in to comment.