Skip to content

Commit

Permalink
Handle empty memory type in agent configuration
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <[email protected]>
  • Loading branch information
arjunkumargiri committed Dec 10, 2023
1 parent 44118ef commit f998c9f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
Expand Down Expand Up @@ -109,17 +110,18 @@ public void execute(Input input, ActionListener<Output> listener) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLAgent mlAgent = MLAgent.parse(parser);
String memoryType = mlAgent.getMemory().getType();
MLMemorySpec memorySpec = mlAgent.getMemory();
String memoryId = inputDataSet.getParameters().get(MEMORY_ID);
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
String appType = mlAgent.getAppType();
String question = inputDataSet.getParameters().get(QUESTION);

if (memoryType != null
&& memoryFactoryMap.containsKey(memoryType)
if (memorySpec != null
&& memorySpec.getType() != null
&& memoryFactoryMap.containsKey(memorySpec.getType())
&& (memoryId == null || parentInteractionId == null)) {
ConversationIndexMemory.Factory conversationIndexMemoryFactory =
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType());
conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> {
inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId());
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.output.model.ModelTensor;
Expand Down Expand Up @@ -81,7 +82,7 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
return;
}

String memoryType = mlAgent.getMemory().getType();
MLMemorySpec memorySpec = mlAgent.getMemory();
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
String parentInteractionId = params.get(MLAgentExecutor.PARENT_INTERACTION_ID);

Expand Down Expand Up @@ -121,7 +122,7 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
additionalInfo.put(outputKey, outputResponse);

if (finalI == toolSpecs.size()) {
updateMemory(additionalInfo, memoryType, memoryId, parentInteractionId);
updateMemory(additionalInfo, memorySpec, memoryId, parentInteractionId);
listener.onResponse(flowAgentOutput);
return;
}
Expand All @@ -146,11 +147,12 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
}
}

private void updateMemory(Map<String, Object> additionalInfo, String memoryType, String memoryId, String interactionId) {
if (memoryId == null || interactionId == null) {
private void updateMemory(Map<String, Object> additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId) {
if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) {
return;
}
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap
.get(memorySpec.getType());
conversationIndexMemoryFactory.create(memoryId, ActionListener.<ConversationIndexMemory>wrap(memory -> {
updateInteraction(additionalInfo, interactionId, memory);
}, e -> { log.error("Failed create memory from id: " + memoryId, e); }));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,25 @@ public void testRunWithIncludeOutputSet() {
Assert.assertEquals("First tool response", agentOutput.get(0).getResult());
Assert.assertEquals("Second tool response", agentOutput.get(1).getResult());
}

@Test
public void testWithMemoryNotSet() {
final Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.MEMORY_ID, "memoryId");
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.memory(null)
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();
mlFlowAgentRunner.run(mlAgent, params, agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
List<ModelTensor> agentOutput = (List<ModelTensor>) objectCaptor.getValue();
Assert.assertEquals(1, agentOutput.size());
// Respond with last tool output
Assert.assertEquals(SECOND_TOOL, agentOutput.get(0).getName());
Assert.assertEquals("Second tool response", agentOutput.get(0).getResult());
}
}

0 comments on commit f998c9f

Please sign in to comment.