diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 0243923a40..19266d17ae 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -43,8 +43,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; - -import software.amazon.awssdk.utils.ImmutableMap; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; public class MLChatAgentRunnerTest { public static final String FIRST_TOOL = "firstTool"; @@ -194,6 +193,36 @@ public void testRunWithIncludeOutputSet() { Assert.assertEquals("Second tool response", additionalInfos.get(String.format("%s.output", SECOND_TOOL)).get(0)); } + @Test + public void testChatHistoryExcludeOngoingQuestion() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).includeOutputInAgentResponse(false).build(); + MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .memory(mlMemorySpec) + .llm(llmSpec) + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + List interactionList = generateInteractions(2); + Interaction inProgressInteraction = Interaction.builder().id("interaction-99").input("input-99").response(null).build(); + interactionList.add(inProgressInteraction); + listener.onResponse(interactionList); + return null; + }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture()); + + HashMap params = new HashMap<>(); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); + Assert.assertFalse(chatHistory.contains("input-99")); + + } + private List generateInteractions(int size) { return IntStream .range(1, size + 1)