From 8dba50b4ab5592e9b97e884394de7bd9ff451260 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Sat, 9 Dec 2023 02:54:38 +0800 Subject: [PATCH] Fix sort on get trace API and flaky test of conversation (#1733) * Add trace number into sort Signed-off-by: Hailong Cui * refresh interaction immediately so that get interaction list able to fetch the latest data Signed-off-by: Hailong Cui * update test Signed-off-by: Hailong Cui * Add more logs Signed-off-by: Hailong Cui * Add more debug logs Signed-off-by: Hailong Cui * Add delay for creating two conversations Signed-off-by: Hailong Cui * Using trace number as sort key Signed-off-by: Hailong Cui --------- Signed-off-by: Hailong Cui --- .../ml/memory/index/InteractionsIndex.java | 2 +- .../opensearch/ml/engine/memory/MLMemoryManager.java | 2 ++ .../ml/rest/RestMemoryGetConversationsActionIT.java | 12 ++++++++++-- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 573e4a6b93..903f30e810 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -337,7 +337,7 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList request.source(searchSourceBuilder); request.source().from(from).size(maxResults); - request.source().sort(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, SortOrder.ASC); + request.source().sort(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, SortOrder.ASC); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(response -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index b6620bf8e2..64c608a3e1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -14,6 +14,7 @@ import org.opensearch.action.DocWriteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -266,6 +267,7 @@ public void updateInteraction(String interactionId, Map updateCo UpdateRequest updateRequest = new UpdateRequest(indexName, interactionId); updateRequest.doc(updateContent); updateRequest.docAsUpsert(true); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener al = ActionListener.runBefore(ActionListener.wrap(updateResponse -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java index 6fb30dd426..c213e9daa2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Map; +import java.util.concurrent.TimeUnit; import org.apache.http.HttpEntity; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; +import org.junit.Assert; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; @@ -119,21 +121,26 @@ public void testConversations_MorePages() throws IOException { assert (((Double) map.get("next_token")).intValue() == 1); } - public void testGetConversations_nextPage() throws IOException { + public void testGetConversations_nextPage() throws IOException, InterruptedException { Response ccresponse1 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse1 != null); assert (TestHelper.restStatus(ccresponse1) == RestStatus.OK); HttpEntity cchttpEntity1 = ccresponse1.getEntity(); String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1); + logger.info("ccentityString1={}", ccentityString1); Map ccmap1 = gson.fromJson(ccentityString1, Map.class); assert (ccmap1.containsKey("conversation_id")); String id1 = (String) ccmap1.get("conversation_id"); + // wait for 0.1s to make sure update time is different between conversation 1 and 2 + TimeUnit.MICROSECONDS.sleep(100); + Response ccresponse2 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse2 != null); assert (TestHelper.restStatus(ccresponse2) == RestStatus.OK); HttpEntity cchttpEntity2 = ccresponse2.getEntity(); String ccentityString2 = TestHelper.httpEntityToString(cchttpEntity2); + logger.info("ccentityString2={}", ccentityString2); Map ccmap2 = gson.fromJson(ccentityString2, Map.class); assert (ccmap2.containsKey("conversation_id")); String id2 = (String) ccmap2.get("conversation_id"); @@ -151,6 +158,7 @@ public void testGetConversations_nextPage() throws IOException { assert (TestHelper.restStatus(response1) == RestStatus.OK); HttpEntity httpEntity1 = response1.getEntity(); String entityString1 = TestHelper.httpEntityToString(httpEntity1); + logger.info("entityString1={}", entityString1); Map map1 = gson.fromJson(entityString1, Map.class); assert (map1.containsKey("conversations")); assert (map1.containsKey("next_token")); @@ -158,7 +166,7 @@ public void testGetConversations_nextPage() throws IOException { ArrayList conversations1 = (ArrayList) map1.get("conversations"); assert (conversations1.size() == 1); assert (conversations1.get(0).containsKey("conversation_id")); - assert (((String) conversations1.get(0).get("conversation_id")).equals(id2)); + Assert.assertEquals(conversations1.get(0).get("conversation_id"), id2); assert (((Double) map1.get("next_token")).intValue() == 1); Response response = TestHelper