Skip to content

Commit

Permalink
fix duplicate trace record may caused by code merge
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Jan 31, 2024
1 parent 45e5199 commit d449cf7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -374,21 +374,6 @@ private void runReAct(
.build();
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null);
}
if (finalAnswer != null) {
finalAnswer = finalAnswer.trim();
if (conversationIndexMemory != null) {
// Create final trace message.
ConversationIndexMessage msgTemp = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type("ReAct")
.question(question)
.response(finalAnswer)
.finalAnswer(true)
.sessionId(sessionId)
.build();
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null);
}
}
if (finalAnswer != null) {
finalAnswer = finalAnswer.trim();
if (conversationIndexMemory != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -52,6 +55,7 @@
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.engine.memory.MLMemoryManager;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap;

Expand Down Expand Up @@ -212,6 +216,11 @@ public void testRunWithIncludeOutputNotSet() {
.build();
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
// 2 tools thought + 1 final answer thought + 1 final answer
Mockito.verify(conversationIndexMemory, times(4)).save(any(ConversationIndexMessage.class), any(), anyInt(), eq(null));
// two tools response
Mockito.verify(conversationIndexMemory, times(1)).save(any(ConversationIndexMessage.class), any(), anyInt(), eq(FIRST_TOOL));
Mockito.verify(conversationIndexMemory, times(1)).save(any(ConversationIndexMessage.class), any(), anyInt(), eq(SECOND_TOOL));
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
assertEquals(1, agentOutput.size());
Expand Down

0 comments on commit d449cf7

Please sign in to comment.