From f998c9f7be399fc66d2dc3673d0a5fc662fdb780 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Sun, 10 Dec 2023 15:44:32 -0800 Subject: [PATCH] Handle empty memory type in agent configuration Signed-off-by: Arjun kumar Giri --- .../algorithms/agent/MLAgentExecutor.java | 10 +++++---- .../algorithms/agent/MLFlowAgentRunner.java | 12 ++++++----- .../agent/MLFlowAgentRunnerTest.java | 21 +++++++++++++++++++ 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index a0d66f9163..bb13fa060b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -31,6 +31,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; @@ -109,17 +110,18 @@ public void execute(Input input, ActionListener listener) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLAgent mlAgent = MLAgent.parse(parser); - String memoryType = mlAgent.getMemory().getType(); + MLMemorySpec memorySpec = mlAgent.getMemory(); String memoryId = inputDataSet.getParameters().get(MEMORY_ID); String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); String appType = mlAgent.getAppType(); String question = inputDataSet.getParameters().get(QUESTION); - if (memoryType != null - && memoryFactoryMap.containsKey(memoryType) + if (memorySpec != null + && memorySpec.getType() != null + && memoryFactoryMap.containsKey(memorySpec.getType()) && (memoryId == null || parentInteractionId == null)) { ConversationIndexMemory.Factory conversationIndexMemoryFactory = - (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); + (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> { inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index a64508c7ee..95c50dd304 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -26,6 +26,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.output.model.ModelTensor; @@ -81,7 +82,7 @@ public void run(MLAgent mlAgent, Map params, ActionListener params, ActionListener params, ActionListener additionalInfo, String memoryType, String memoryId, String interactionId) { - if (memoryId == null || interactionId == null) { + private void updateMemory(Map additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId) { + if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { return; } - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); + ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap + .get(memorySpec.getType()); conversationIndexMemoryFactory.create(memoryId, ActionListener.wrap(memory -> { updateInteraction(additionalInfo, interactionId, memory); }, e -> { log.error("Failed create memory from id: " + memoryId, e); })); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index a828f71c12..c4129e9f77 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -150,4 +150,25 @@ public void testRunWithIncludeOutputSet() { Assert.assertEquals("First tool response", agentOutput.get(0).getResult()); Assert.assertEquals("Second tool response", agentOutput.get(1).getResult()); } + + @Test + public void testWithMemoryNotSet() { + final Map params = new HashMap<>(); + params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .memory(null) + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + mlFlowAgentRunner.run(mlAgent, params, agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + List agentOutput = (List) objectCaptor.getValue(); + Assert.assertEquals(1, agentOutput.size()); + // Respond with last tool output + Assert.assertEquals(SECOND_TOOL, agentOutput.get(0).getName()); + Assert.assertEquals("Second tool response", agentOutput.get(0).getResult()); + } }