Skip to content

Commit

Permalink
flow agent suggestions missing
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Feb 6, 2024
1 parent b62b0de commit bffe1ea
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,18 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
}

if (finalI == toolSpecs.size()) {
updateMemory(additionalInfo, memorySpec, memoryId, parentInteractionId);
listener.onResponse(flowAgentOutput);
if (memoryId == null || parentInteractionId == null || memorySpec == null || memorySpec.getType() == null) {
listener.onResponse(flowAgentOutput);
} else {
ActionListener updateListener = ActionListener.<UpdateResponse>wrap(updateResponse -> {
log.info("Updated additional info for interaction ID: " + updateResponse.getId() + " in the flow agent.");
listener.onResponse(flowAgentOutput);
}, e -> {
log.error("Failed to update root interaction", e);
listener.onResponse(flowAgentOutput);
});
updateMemoryWithListener(additionalInfo, memorySpec, memoryId, parentInteractionId, updateListener);
}
return;
}

Expand Down Expand Up @@ -168,6 +178,30 @@ void updateMemory(Map<String, Object> additionalInfo, MLMemorySpec memorySpec, S
);
}

@VisibleForTesting
void updateMemoryWithListener(
Map<String, Object> additionalInfo,
MLMemorySpec memorySpec,
String memoryId,
String interactionId,
ActionListener listener
) {
if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) {
return;
}
ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap
.get(memorySpec.getType());
conversationIndexMemoryFactory
.create(
memoryId,
ActionListener
.wrap(
memory -> updateInteractionWithListener(additionalInfo, interactionId, memory, listener),
e -> log.error("Failed create memory from id: " + memoryId, e)
)
);
}

@VisibleForTesting
void updateInteraction(Map<String, Object> additionalInfo, String interactionId, ConversationIndexMemory memory) {
memory
Expand All @@ -181,6 +215,18 @@ void updateInteraction(Map<String, Object> additionalInfo, String interactionId,
);
}

@VisibleForTesting
void updateInteractionWithListener(
Map<String, Object> additionalInfo,
String interactionId,
ConversationIndexMemory memory,
ActionListener listener
) {
memory
.getMemoryManager()
.updateInteraction(interactionId, ImmutableMap.of(ActionConstants.ADDITIONAL_INFO_FIELD, additionalInfo), listener);
}

@VisibleForTesting
String parseResponse(Object output) throws IOException {
if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -35,11 +36,15 @@
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.stubbing.Answer;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.StepListener;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
Expand Down Expand Up @@ -162,12 +167,18 @@ public void testRunWithIncludeOutputNotSet() {
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build();
ConversationIndexMemory memory = mock(ConversationIndexMemory.class);
Mockito.doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
listener.onResponse(new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED));
return null;
}).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any());
doReturn(memoryManager).when(memory).getMemoryManager();
Mockito.doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(memory);
return null;
}).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any());
Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager);

final MLAgent mlAgent = MLAgent
.builder()
Expand Down Expand Up @@ -210,12 +221,18 @@ public void testRunWithIncludeOutputSet() {
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).includeOutputInAgentResponse(true).build();
MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build();
ConversationIndexMemory memory = mock(ConversationIndexMemory.class);
Mockito.doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
listener.onResponse(new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED));
return null;
}).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any());
doReturn(memoryManager).when(memory).getMemoryManager();
Mockito.doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(memory);
return null;
}).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any());
Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager);
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
Expand Down Expand Up @@ -415,4 +432,45 @@ public void testUpdateMemory() {
verify(memoryManager).updateInteraction(anyString(), anyMap(), any(ActionListener.class));
}

@Test
public void testRunWithUpdateFailure() {
final Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.MEMORY_ID, "memoryId");
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "interaction_id");
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build();
ConversationIndexMemory memory = mock(ConversationIndexMemory.class);
Mockito.doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
listener.onFailure(new IllegalArgumentException("input error"));
return null;
}).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any());
doReturn(memoryManager).when(memory).getMemoryManager();
Mockito.doAnswer(invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(memory);
return null;
}).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any());

final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.memory(mlMemorySpec)
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();
mlFlowAgentRunner.run(mlAgent, params, agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
List<ModelTensor> agentOutput = (List<ModelTensor>) objectCaptor.getValue();
assertEquals(1, agentOutput.size());
// Respond with last tool output
assertEquals(SECOND_TOOL, agentOutput.get(0).getName());
assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(0).getResult());

verify(memoryManager).updateInteraction(anyString(), memoryMapCaptor.capture(), any(ActionListener.class));
Map<String, Object> additionalInfo = (Map<String, Object>) memoryMapCaptor.getValue().get("additional_info");
assertEquals(1, additionalInfo.size());
assertNotNull(additionalInfo.get(SECOND_TOOL + ".output"));
}

}

0 comments on commit bffe1ea

Please sign in to comment.