From bffe1ea0287c9a4d0b80c33b42f99aaa1c141dd4 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 5 Feb 2024 21:55:57 -0800 Subject: [PATCH] flow agent suggestions missing Signed-off-by: Jing Zhang --- .../algorithms/agent/MLFlowAgentRunner.java | 50 ++++++++++++++- .../agent/MLFlowAgentRunnerTest.java | 62 ++++++++++++++++++- 2 files changed, 108 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index ab03859271..1d574abc86 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -125,8 +125,18 @@ public void run(MLAgent mlAgent, Map params, ActionListenerwrap(updateResponse -> { + log.info("Updated additional info for interaction ID: " + updateResponse.getId() + " in the flow agent."); + listener.onResponse(flowAgentOutput); + }, e -> { + log.error("Failed to update root interaction", e); + listener.onResponse(flowAgentOutput); + }); + updateMemoryWithListener(additionalInfo, memorySpec, memoryId, parentInteractionId, updateListener); + } return; } @@ -168,6 +178,30 @@ void updateMemory(Map additionalInfo, MLMemorySpec memorySpec, S ); } + @VisibleForTesting + void updateMemoryWithListener( + Map additionalInfo, + MLMemorySpec memorySpec, + String memoryId, + String interactionId, + ActionListener listener + ) { + if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { + return; + } + ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap + .get(memorySpec.getType()); + conversationIndexMemoryFactory + .create( + memoryId, + ActionListener + .wrap( + memory -> updateInteractionWithListener(additionalInfo, interactionId, memory, listener), + e -> log.error("Failed create memory from id: " + memoryId, e) + ) + ); + } + @VisibleForTesting void updateInteraction(Map additionalInfo, String interactionId, ConversationIndexMemory memory) { memory @@ -181,6 +215,18 @@ void updateInteraction(Map additionalInfo, String interactionId, ); } + @VisibleForTesting + void updateInteractionWithListener( + Map additionalInfo, + String interactionId, + ConversationIndexMemory memory, + ActionListener listener + ) { + memory + .getMemoryManager() + .updateInteraction(interactionId, ImmutableMap.of(ActionConstants.ADDITIONAL_INFO_FIELD, additionalInfo), listener); + } + @VisibleForTesting String parseResponse(Object output) throws IOException { if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index 9aaf76c2f4..68cae251b9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -14,6 +14,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -35,11 +36,15 @@ import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.stubbing.Answer; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.StepListener; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -162,12 +167,18 @@ public void testRunWithIncludeOutputNotSet() { MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); ConversationIndexMemory memory = mock(ConversationIndexMemory.class); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + listener.onResponse(new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED)); + return null; + }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); + doReturn(memoryManager).when(memory).getMemoryManager(); Mockito.doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(memory); return null; }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); - Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); final MLAgent mlAgent = MLAgent .builder() @@ -210,12 +221,18 @@ public void testRunWithIncludeOutputSet() { MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build(); MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); ConversationIndexMemory memory = mock(ConversationIndexMemory.class); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + listener.onResponse(new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED)); + return null; + }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); + doReturn(memoryManager).when(memory).getMemoryManager(); Mockito.doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(memory); return null; }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); - Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") @@ -415,4 +432,45 @@ public void testUpdateMemory() { verify(memoryManager).updateInteraction(anyString(), anyMap(), any(ActionListener.class)); } + @Test + public void testRunWithUpdateFailure() { + final Map params = new HashMap<>(); + params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "interaction_id"); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); + MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); + ConversationIndexMemory memory = mock(ConversationIndexMemory.class); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("input error")); + return null; + }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); + doReturn(memoryManager).when(memory).getMemoryManager(); + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(memory); + return null; + }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); + + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + mlFlowAgentRunner.run(mlAgent, params, agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + List agentOutput = (List) objectCaptor.getValue(); + assertEquals(1, agentOutput.size()); + // Respond with last tool output + assertEquals(SECOND_TOOL, agentOutput.get(0).getName()); + assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(0).getResult()); + + verify(memoryManager).updateInteraction(anyString(), memoryMapCaptor.capture(), any(ActionListener.class)); + Map additionalInfo = (Map) memoryMapCaptor.getValue().get("additional_info"); + assertEquals(1, additionalInfo.size()); + assertNotNull(additionalInfo.get(SECOND_TOOL + ".output")); + } + }