From 826b13039e0ae21e1c436bcc72be9eaf18fba865 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Tue, 16 Jan 2024 16:45:35 -0800 Subject: [PATCH] refactor memory layer APIs Signed-off-by: Xun Zhang --- .../common/conversation/ActionConstants.java | 35 ++++--- .../ConversationalIndexConstants.java | 4 +- .../conversation/ConversationMetaTests.java | 2 +- .../common/conversation/InteractionTests.java | 2 +- .../memory/ConversationalMemoryHandler.java | 6 +- .../CreateInteractionRequest.java | 2 +- .../conversation/GetConversationRequest.java | 2 +- .../conversation/GetInteractionRequest.java | 12 +-- .../GetInteractionTransportAction.java | 5 +- .../conversation/GetInteractionsRequest.java | 2 +- .../ml/memory/index/InteractionsIndex.java | 55 ++++------- ...OpenSearchConversationalMemoryHandler.java | 10 +- .../CreateConversationResponseTests.java | 2 +- .../CreateInteractionRequestTests.java | 6 +- .../CreateInteractionResponseTests.java | 2 +- .../GetConversationRequestTests.java | 2 +- .../GetConversationResponseTests.java | 2 +- .../GetConversationsResponseTests.java | 4 +- .../GetInteractionRequestTests.java | 24 +---- .../GetInteractionResponseTests.java | 2 +- .../GetInteractionTransportActionTests.java | 12 +-- .../GetInteractionsRequestTests.java | 16 +-- .../GetInteractionsResponseTests.java | 4 +- .../conversation/GetTracesResponseTests.java | 4 +- .../index/InteractionsIndexITTests.java | 4 +- .../memory/index/InteractionsIndexTests.java | 21 +--- ...earchConversationalMemoryHandlerTests.java | 6 +- .../algorithms/agent/MLAgentExecutor.java | 2 +- .../engine/memory/MLMemoryManagerTests.java | 2 +- .../RestMemoryUpdateConversationAction.java | 3 +- .../RestMemoryUpdateInteractionAction.java | 3 +- .../RestMemoryCreateConversationActionIT.java | 6 +- .../RestMemoryCreateInteractionActionIT.java | 10 +- .../RestMemoryDeleteConversationActionIT.java | 49 ++++----- .../RestMemoryGetConversationActionIT.java | 2 +- .../RestMemoryGetConversationsActionIT.java | 51 +++++----- .../RestMemoryGetInteractionActionIT.java | 11 +-- .../RestMemoryGetInteractionActionTests.java | 3 +- .../RestMemoryGetInteractionsActionIT.java | 99 +++++++++---------- ...RestMemorySearchConversationsActionIT.java | 4 +- .../RestMemorySearchInteractionsActionIT.java | 18 ++-- .../RestMemoryUpdateConversationTests.java | 15 +-- ...estMemoryUpdateInteractionActionTests.java | 15 +-- 43 files changed, 237 insertions(+), 304 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 119d5a6659..6dc1f1ba6a 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -23,7 +23,7 @@ public class ActionConstants { /** name of conversation Id field in all responses */ - public final static String CONVERSATION_ID_FIELD = "conversation_id"; + public final static String CONVERSATION_ID_FIELD = "memory_id"; /** name of list of conversations in all responses */ public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations"; @@ -32,7 +32,7 @@ public class ActionConstants { /** name of list on traces in all responses */ public final static String RESPONSE_TRACES_LIST_FIELD = "traces"; /** name of interaction Id field in all responses */ - public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id"; + public final static String RESPONSE_INTERACTION_ID_FIELD = "message_id"; /** name of conversation name in all requests */ public final static String REQUEST_CONVERSATION_NAME_FIELD = "name"; @@ -51,38 +51,41 @@ public class ActionConstants { /** name of metadata field in all requests */ public final static String ADDITIONAL_INFO_FIELD = "additional_info"; /** name of metadata field in all requests */ - public final static String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id"; + public final static String PARENT_INTERACTION_ID_FIELD = "parent_message_id"; /** name of metadata field in all requests */ public final static String TRACE_NUMBER_FIELD = "trace_number"; /** name of success field in all requests */ public final static String SUCCESS_FIELD = "success"; - private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; - private final static String BASE_REST_INTERACTION_PATH = "/_plugins/_ml/memory/interaction"; + /** parameter for memory_id in URLs */ + public final static String MEMORY_ID = "memory_id"; + /** parameter for message_id in URLs */ + public final static String MESSAGE_ID = "message_id"; + private final static String BASE_REST_PATH = "/_plugins/_ml/memory"; /** path for create conversation */ - public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create"; + public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH; /** path for get conversations */ - public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list"; + public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH; /** path for update conversations */ - public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update"; + public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}"; /** path for create interaction */ - public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create"; + public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{memory_id}/messages"; /** path for get interactions */ - public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list"; + public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}/messages"; /** path for get traces */ - public final static String GET_TRACES_REST_PATH = "/_plugins/_ml/memory/trace" + "/{interaction_id}/_list"; + public final static String GET_TRACES_REST_PATH = BASE_REST_PATH + "/message/{message_id}/traces"; /** path for delete conversation */ - public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete"; + public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{memory_id}"; /** path for search conversations */ public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; /** path for search interactions */ - public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search"; + public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}/_search"; /** path for update interactions */ - public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_INTERACTION_PATH + "/{interaction_id}/_update"; + public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/message/{message_id}"; /** path for get conversation */ - public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{memory_id}"; /** path for get interaction */ - public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}"; + public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/message/{message_id}"; /** default max results returned by get operations */ public final static int DEFAULT_MAX_RESULTS = 10; diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index cb44ade93c..23cad0064d 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -66,7 +66,7 @@ public class ConversationalIndexConstants { /** Name of the conversational interactions index */ public final static String INTERACTIONS_INDEX_NAME = ".plugins-ml-memory-message"; /** Name of the interaction field for the conversation Id */ - public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "conversation_id"; + public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "memory_id"; /** Name of the interaction field for the human input */ public final static String INTERACTIONS_INPUT_FIELD = "input"; /** Name of the interaction field for the prompt template */ @@ -80,7 +80,7 @@ public class ConversationalIndexConstants { /** Name of the interaction field for the timestamp */ public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time"; /** Name of the interaction id */ - public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_interaction_id"; + public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_message_id"; /** The trace number of an interaction */ public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number"; /** Mappings for the interactions index */ diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java index febb29fbf1..304703d34f 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java @@ -89,7 +89,7 @@ public void test_ToXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); conversationMeta.toXContent(builder, EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - assertEquals(content, "{\"conversation_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}"); + assertEquals(content, "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}"); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index c704547050..128d9449ea 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -127,7 +127,7 @@ public void test_ToXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); interaction.toXContent(builder, EMPTY_PARAMS); String interactionContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"conversation_id\":\"conversation id\",\"interaction_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_interaction_id\":\"parant id\",\"trace_number\":1}", interactionContent); + assertEquals("{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent); } @Test diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index 0a439fe7e0..a48cc6ed17 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -269,18 +269,16 @@ public void createInteraction( /** * Get a single interaction - * @param conversationId id of the conversation this interaction belongs to * @param interactionId id of this interaction * @param listener receives the interaction */ - public void getInteraction(String conversationId, String interactionId, ActionListener listener); + public void getInteraction(String interactionId, ActionListener listener); /** * Get a single interaction - * @param conversationId id of the conversation this interaction belongs to * @param interactionId id of this interaction * @return ActionFuture for the interaction */ - public ActionFuture getInteraction(String conversationId, String interactionId); + public ActionFuture getInteraction(String interactionId); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index 5f9f4a8128..fe4a05bc0c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -126,7 +126,7 @@ public ActionRequestValidationException validate() { * @throws IOException if something goes wrong reading from request */ public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException { - String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); + String cid = request.param(ActionConstants.MEMORY_ID); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java index c5a6f6dd0e..6bc47f6df0 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java @@ -71,7 +71,7 @@ public ActionRequestValidationException validate() { * @throws IOException if something goes wrong in translation */ public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException { - String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + String conversationId = request.param(ActionConstants.MEMORY_ID); return new GetConversationRequest(conversationId); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java index 6808857c40..b494ad160c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java @@ -36,8 +36,6 @@ */ @AllArgsConstructor public class GetInteractionRequest extends ActionRequest { - @Getter - private String conversationId; @Getter private String interactionId; @@ -48,23 +46,18 @@ public class GetInteractionRequest extends ActionRequest { */ public GetInteractionRequest(StreamInput in) throws IOException { super(in); - this.conversationId = in.readString(); this.interactionId = in.readString(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(this.conversationId); out.writeString(this.interactionId); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; - if (conversationId == null) { - exception = addValidationError("Get Interaction Request must have a conversation id", exception); - } if (interactionId == null) { exception = addValidationError("Get Interaction Request must have an interaction id", exception); } @@ -78,8 +71,7 @@ public ActionRequestValidationException validate() { * @throws IOException if something goes wrong reading from the rest request */ public static GetInteractionRequest fromRestRequest(RestRequest request) throws IOException { - String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); - String interactionId = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); - return new GetInteractionRequest(conversationId, interactionId); + String interactionId = request.param(ActionConstants.MESSAGE_ID); + return new GetInteractionRequest(interactionId); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java index 16205ec8b9..4ba7c74b88 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java @@ -79,16 +79,15 @@ public void doExecute(Task task, GetInteractionRequest request, ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener.wrap(interaction -> { internalListener.onResponse(new GetInteractionResponse(interaction)); }, e -> { internalListener.onFailure(e); }); - cmHandler.getInteraction(conversationId, interactionId, al); + cmHandler.getInteraction(interactionId, al); } catch (Exception e) { - log.error("Failed to get interaction " + interactionId + " in conversation " + conversationId, e); + log.error("Failed to get interaction " + interactionId, e); actionListener.onFailure(e); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequest.java index 4554300f1c..1d77ce453e 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequest.java @@ -113,7 +113,7 @@ public ActionRequestValidationException validate() { * @throws IOException if something goes wrong */ public static GetInteractionsRequest fromRestRequest(RestRequest request) throws IOException { - String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); + String cid = request.param(ActionConstants.MEMORY_ID); if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 7c522f957e..efd4b562d4 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -491,11 +491,10 @@ public void searchInteractions(String conversationId, SearchRequest request, Act /** * Get a single interaction - * @param conversationId id of the conversation this interaction belongs to * @param interactionId id of this interaction * @param listener receives the interaction */ - public void getInteraction(String conversationId, String interactionId, ActionListener listener) { + public void getInteraction(String interactionId, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { listener .onFailure( @@ -506,39 +505,25 @@ public void getInteraction(String conversationId, String interactionId, ActionLi ); return; } - conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { - if (access) { - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); - ActionListener al = ActionListener.wrap(getResponse -> { - // If the conversation doesn't exist, fail - if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { - throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); - } - Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); - internalListener.onResponse(interaction); - }, e -> { internalListener.onFailure(e); }); - client - .admin() - .indices() - .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { - client.get(request, al); - }, e -> { - log.error("Failed to refresh interactions index during get interaction ", e); - internalListener.onFailure(e); - })); - } catch (Exception e) { - listener.onFailure(e); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the conversation doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { + throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); } - } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); - throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); - } - }, e -> { listener.onFailure(e); })); + Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); + internalListener.onResponse(interaction); + }, e -> { internalListener.onFailure(e); }); + client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, al); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index b2b753651a..64d39991df 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -421,23 +421,21 @@ public ActionFuture getConversation(String conversationId) { /** * Get a single interaction - * @param conversationId id of the conversation this interaction belongs to * @param interactionId id of this interaction * @param listener receives the interaction */ - public void getInteraction(String conversationId, String interactionId, ActionListener listener) { - interactionsIndex.getInteraction(conversationId, interactionId, listener); + public void getInteraction(String interactionId, ActionListener listener) { + interactionsIndex.getInteraction(interactionId, listener); } /** * Get a single interaction - * @param conversationId id of the conversation this interaction belongs to * @param interactionId id of this interaction * @return ActionFuture for the interaction */ - public ActionFuture getInteraction(String conversationId, String interactionId) { + public ActionFuture getInteraction(String interactionId) { PlainActionFuture fut = PlainActionFuture.newFuture(); - getInteraction(conversationId, interactionId, fut); + getInteraction(interactionId, fut); return fut; } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java index 75f256ceca..542a1f652d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java @@ -46,7 +46,7 @@ public void testCreateConversationResponseStreaming() throws IOException { public void testToXContent() throws IOException { CreateConversationResponse response = new CreateConversationResponse("createme"); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - String expected = "{\"conversation_id\":\"createme\"}"; + String expected = "{\"memory_id\":\"createme\"}"; response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); assert (result.equals(expected)); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index fae2984af9..155edbc3c8 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -104,10 +104,12 @@ public void testFromRestRequest() throws IOException { ); RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid")) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) .build(); CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + System.out.println(request.getConversationId()); + System.out.println(request.getInput()); assert (request.validate() == null); assert (request.getConversationId().equals("cid")); @@ -138,7 +140,7 @@ public void testFromRestRequest_Trace() throws IOException { ); RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "tid")) + .withParams(Map.of(ActionConstants.MEMORY_ID, "tid")) .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) .build(); CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponseTests.java index 939acc0435..0fa9f2bdbf 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponseTests.java @@ -46,7 +46,7 @@ public void testCreateInteractionResponseStreaming() throws IOException { public void testToXContent() throws IOException { CreateInteractionResponse response = new CreateInteractionResponse("createme"); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - String expected = "{\"interaction_id\":\"createme\"}"; + String expected = "{\"message_id\":\"createme\"}"; response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); assert (result.equals(expected)); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java index cb8b67b44b..5585dcf955 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java @@ -59,7 +59,7 @@ public void testNullConvoId_ThenFail() { } public void testFromRestRequest() throws IOException { - Map params = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "testcid"); + Map params = Map.of(ActionConstants.MEMORY_ID, "testcid"); RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); GetConversationRequest request = GetConversationRequest.fromRestRequest(rrequest); assert (request.validate() == null); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java index abb8d04de9..b3ed2f14ff 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -54,7 +54,7 @@ public void testToXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); - String expected = "{\"conversation_id\":\"cid\",\"create_time\":\"" + String expected = "{\"memory_id\":\"cid\",\"create_time\":\"" + convo.getCreatedTime() + "\",\"updated_time\":\"" + convo.getUpdatedTime() diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java index 4d14e6f703..44883bfcaf 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java @@ -73,7 +73,7 @@ public void testToXContent_MoreTokens() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); - String expected = "{\"conversations\":[{\"conversation_id\":\"0\",\"create_time\":\"" + String expected = "{\"conversations\":[{\"memory_id\":\"0\",\"create_time\":\"" + conversation.getCreatedTime() + "\"updated_time\":\"" + conversation.getUpdatedTime() @@ -93,7 +93,7 @@ public void testToXContent_NoMoreTokens() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); - String expected = "{\"conversations\":[{\"conversation_id\":\"0\",\"create_time\":\"" + String expected = "{\"conversations\":[{\"memory_id\":\"0\",\"create_time\":\"" + conversation.getCreatedTime() + "\"updated_time\":\"" + conversation.getUpdatedTime() diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java index 678004ae09..9adcf281eb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java @@ -36,9 +36,8 @@ public class GetInteractionRequestTests extends OpenSearchTestCase { public void testConstructorAndStreaming() throws IOException { - GetInteractionRequest request = new GetInteractionRequest("cid", "iid"); + GetInteractionRequest request = new GetInteractionRequest("iid"); assert (request.validate() == null); - assert (request.getConversationId().equals("cid")); assert (request.getInteractionId().equals("iid")); BytesStreamOutput outbytes = new BytesStreamOutput(); @@ -47,39 +46,24 @@ public void testConstructorAndStreaming() throws IOException { StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); GetInteractionRequest newRequest = new GetInteractionRequest(in); assert (newRequest.validate() == null); - assert (newRequest.getConversationId().equals("cid")); assert (newRequest.getInteractionId().equals("iid")); } public void testMalformedRequest_ThenInvalid() { - GetInteractionRequest bad1 = new GetInteractionRequest(null, "iid"); - GetInteractionRequest bad2 = new GetInteractionRequest("cid", null); - GetInteractionRequest bad3 = new GetInteractionRequest(null, null); - ActionRequestValidationException exc1 = bad1.validate(); + String nullId = null; + GetInteractionRequest bad2 = new GetInteractionRequest(nullId); ActionRequestValidationException exc2 = bad2.validate(); - ActionRequestValidationException exc3 = bad3.validate(); - - assert (exc1 != null); - assert (exc1.validationErrors().size() == 1); - assert (exc1.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); assert (exc2 != null); assert (exc2.validationErrors().size() == 1); assert (exc2.validationErrors().get(0).equals("Get Interaction Request must have an interaction id")); - - assert (exc3 != null); - assert (exc3.validationErrors().size() == 2); - assert (exc3.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); - assert (exc3.validationErrors().get(1).equals("Get Interaction Request must have an interaction id")); } public void testFromRestRequest() throws IOException { - Map params = Map - .of(ActionConstants.CONVERSATION_ID_FIELD, "testcid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "testiid"); + Map params = Map.of(ActionConstants.MESSAGE_ID, "testiid"); RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); GetInteractionRequest request = GetInteractionRequest.fromRestRequest(rrequest); assert (request.validate() == null); - assert (request.getConversationId().equals("testcid")); assert (request.getInteractionId().equals("testiid")); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java index 5cd79afc4a..2cadd7948a 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java @@ -74,7 +74,7 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); System.out.println(result); - String expected = "{\"conversation_id\":\"cid\",\"interaction_id\":\"iid\",\"create_time\":\"" + String expected = "{\"memory_id\":\"cid\",\"message_id\":\"iid\",\"create_time\":\"" + interaction.getCreateTime() + "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":{\"metadata\":\"some meta\"}}"; // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java index eca0a9251a..fa7fd52918 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java @@ -91,7 +91,7 @@ public void setup() throws IOException { this.actionListener = al; this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); - this.request = new GetInteractionRequest("cid", "iid"); + this.request = new GetInteractionRequest("iid"); Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); @@ -116,10 +116,10 @@ public void testGetInteraction() { Collections.singletonMap("metadata", "some meta") ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(1); listener.onResponse(testInteraction); return null; - }).when(cmHandler).getInteraction(any(), any(), any()); + }).when(cmHandler).getInteraction(any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionResponse.class); verify(actionListener, times(1)).onResponse(argCaptor.capture()); @@ -128,10 +128,10 @@ public void testGetInteraction() { public void testGetInteractionFails_ThenFail() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new Exception("Storage layer failure")); return null; - }).when(cmHandler).getInteraction(any(), any(), any()); + }).when(cmHandler).getInteraction(any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(argCaptor.capture()); @@ -139,7 +139,7 @@ public void testGetInteractionFails_ThenFail() { } public void testHandlerThrows_ThenFail() { - doThrow(new RuntimeException("CMHandler Failure")).when(cmHandler).getInteraction(any(), any(), any()); + doThrow(new RuntimeException("CMHandler Failure")).when(cmHandler).getInteraction(any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(argCaptor.capture()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequestTests.java index e1428d87c3..9ad5c668a9 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequestTests.java @@ -104,19 +104,11 @@ public void testMultipleBadValues_thenFailMultipleWays() { } public void testFromRestRequest() throws IOException { - Map basic = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid1"); - Map maxResOnly = Map - .of(ActionConstants.CONVERSATION_ID_FIELD, "cid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); - Map nextTokOnly = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid3", ActionConstants.NEXT_TOKEN_FIELD, "6"); + Map basic = Map.of(ActionConstants.MEMORY_ID, "cid1"); + Map maxResOnly = Map.of(ActionConstants.MEMORY_ID, "cid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); + Map nextTokOnly = Map.of(ActionConstants.MEMORY_ID, "cid3", ActionConstants.NEXT_TOKEN_FIELD, "6"); Map bothFields = Map - .of( - ActionConstants.CONVERSATION_ID_FIELD, - "cid4", - ActionConstants.REQUEST_MAX_RESULTS_FIELD, - "2", - ActionConstants.NEXT_TOKEN_FIELD, - "7" - ); + .of(ActionConstants.MEMORY_ID, "cid4", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "2", ActionConstants.NEXT_TOKEN_FIELD, "7"); RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(basic).build(); RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build(); RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build(); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java index c1fdfbffac..7f8bdfd27f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java @@ -100,7 +100,7 @@ public void testToXContent_MoreTokens() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); - String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":\"" + String expected = "{\"interactions\":[{\"memory_id\":\"cid\",\"message_id\":\"id0\",\"create_time\":\"" + interaction.getCreateTime() + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"}}],\"next_token\":2}"; log.info(result); @@ -117,7 +117,7 @@ public void testToXContent_NoMoreTokens() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); - String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":\"" + String expected = "{\"interactions\":[{\"memory_id\":\"cid\",\"message_id\":\"id0\",\"create_time\":\"" + interaction.getCreateTime() + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"}}]}"; log.info(result); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java index e013bcc518..87a96a16f3 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetTracesResponseTests.java @@ -94,9 +94,9 @@ public void testToXContent_MoreTokens() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); String result = BytesReference.bytes(builder).utf8ToString(); System.out.println(result); - String expected = "{\"traces\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":" + String expected = "{\"traces\":[{\"memory_id\":\"cid\",\"message_id\":\"id0\",\"create_time\":" + trace.getCreateTime() - + ",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"},\"parent_interaction_id\":\"parent_id\",\"trace_number\":1}],\"next_token\":2}"; + + ",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"},\"parent_message_id\":\"parent_id\",\"trace_number\":1}],\"next_token\":2}"; // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness LevenshteinDistance ld = new LevenshteinDistance(); assert (ld.getDistance(result, expected) > 0.95); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index 133c31971a..157263edb6 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -545,13 +545,13 @@ public void testGetInteractionById() { }); StepListener get1 = new StepListener<>(); - iid2.whenComplete(iid -> { index.getInteraction(conversation, iid1.result(), get1); }, e -> { + iid2.whenComplete(iid -> { index.getInteraction(iid1.result(), get1); }, e -> { cdl.countDown(); log.error(e); }); StepListener get2 = new StepListener<>(); - get1.whenComplete(interaction1 -> { index.getInteraction(conversation, iid2.result(), get2); }, e -> { + get1.whenComplete(interaction1 -> { index.getInteraction(iid2.result(), get2); }, e -> { cdl.countDown(); log.error(e); }); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 007e78019d..41fdb1af41 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -725,7 +725,7 @@ public void testGetSg_NoIndex_ThenFail() { doReturn(false).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); - interactionsIndex.getInteraction("cid", "iid", getListener); + interactionsIndex.getInteraction("iid", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor @@ -747,7 +747,7 @@ public void testGetSg_InteractionNotExist_ThenFail() { }).when(client).get(any(), any()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); - interactionsIndex.getInteraction("cid", "iid", getListener); + interactionsIndex.getInteraction("iid", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Interaction [iid] not found")); @@ -767,7 +767,7 @@ public void testGetSg_WrongId_ThenFail() { }).when(client).get(any(), any()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); - interactionsIndex.getInteraction("cid", "iid", getListener); + interactionsIndex.getInteraction("iid", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Interaction [iid] not found")); @@ -783,7 +783,7 @@ public void testGetSg_RefreshFails_ThenFail() { }).when(indicesAdminClient).refresh(any(), any()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); - interactionsIndex.getInteraction("cid", "iid", getListener); + interactionsIndex.getInteraction("iid", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Failed during Sg Get Refresh")); @@ -795,20 +795,9 @@ public void testGetSg_ClientFails_ThenFail() { doThrow(new RuntimeException("Client Failure in Sg Get")).when(client).admin(); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); - interactionsIndex.getInteraction("cid", "iid", getListener); + interactionsIndex.getInteraction("iid", getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get")); } - - public void testGetSg_NoAccess_ThenFail() { - doReturn(true).when(metadata).hasIndex(anyString()); - setupDenyAccess("Henry"); - @SuppressWarnings("unchecked") - ActionListener getListener = mock(ActionListener.class); - interactionsIndex.getInteraction("cid", "iid", getListener); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); - verify(getListener, times(1)).onFailure(argCaptor.capture()); - assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to conversation cid")); - } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index a979505a52..2c1c28c529 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -322,11 +322,11 @@ public void testGetAnInteraction_Future() { Collections.singletonMap("meta", "some meta") ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(1); listener.onResponse(interaction); return null; - }).when(interactionsIndex).getInteraction(any(), any(), any()); - ActionFuture result = cmHandler.getInteraction("cid", "iid"); + }).when(interactionsIndex).getInteraction(any(), any()); + ActionFuture result = cmHandler.getInteraction("iid"); assert (result.actionGet().equals(interaction)); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index d0129e87ab..52156771bd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -139,7 +139,7 @@ public void execute(Input input, ActionListener listener) { client .execute( GetInteractionAction.INSTANCE, - new GetInteractionRequest(memoryId, regenerateInteractionId), + new GetInteractionRequest(regenerateInteractionId), ActionListener.wrap(interactionRes -> { inputDataSet .getParameters() diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java index 234a8f856e..185b116b85 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -417,7 +417,7 @@ public void testBuildTraceQuery() { String query = Strings.toString(XContentType.JSON, queryBuilder); Assert .assertEquals( - "{\"bool\":{\"should\":[{\"ids\":{\"values\":[\"interaction-id-1\"],\"boost\":1.0}},{\"bool\":{\"must\":[{\"exists\":{\"field\":\"trace_number\",\"boost\":1.0}},{\"term\":{\"parent_interaction_id\":{\"value\":\"interaction-id-1\",\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}}],\"adjust_pure_negative\":true,\"boost\":1.0}}", + "{\"bool\":{\"should\":[{\"ids\":{\"values\":[\"interaction-id-1\"],\"boost\":1.0}},{\"bool\":{\"must\":[{\"exists\":{\"field\":\"trace_number\",\"boost\":1.0}},{\"term\":{\"parent_message_id\":{\"value\":\"interaction-id-1\",\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}}],\"adjust_pure_negative\":true,\"boost\":1.0}}", query ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java index c0934056b6..63ab88d09e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateConversationAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.conversation.ActionConstants.CONVERSATION_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; import java.io.IOException; @@ -49,7 +50,7 @@ private UpdateConversationRequest getRequest(RestRequest request) throws IOExcep throw new OpenSearchParseException("Failed to update conversation: Request body is empty"); } - String conversationId = getParameterId(request, "conversation_id"); + String conversationId = getParameterId(request, CONVERSATION_ID_FIELD); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java index dafc0352ec..dfb315d0d0 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_INTERACTION_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; import java.io.IOException; @@ -49,7 +50,7 @@ private UpdateInteractionRequest getRequest(RestRequest request) throws IOExcept throw new OpenSearchParseException("Failed to update interaction: Request body is empty"); } - String interactionId = getParameterId(request, "interaction_id"); + String interactionId = getParameterId(request, RESPONSE_INTERACTION_ID_FIELD); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java index 23ef15d1fc..d691d24ef5 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.rest; +import static org.opensearch.ml.common.conversation.ActionConstants.CONVERSATION_ID_FIELD; + import java.io.IOException; import java.util.Map; @@ -55,7 +57,7 @@ public void testCreateConversation() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("conversation_id")); + assert (map.containsKey(CONVERSATION_ID_FIELD)); } public void testCreateConversationNamed() throws IOException { @@ -73,6 +75,6 @@ public void testCreateConversationNamed() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("conversation_id")); + assert (map.containsKey(CONVERSATION_ID_FIELD)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java index 1cf0cd34ce..f59f3bf8ec 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.rest; +import static org.opensearch.ml.common.conversation.ActionConstants.CONVERSATION_ID_FIELD; + import java.io.IOException; import java.util.Map; @@ -55,8 +57,8 @@ public void testCreateInteraction() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String id = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String id = (String) ccmap.get(CONVERSATION_ID_FIELD); Map params = Map .of( @@ -75,7 +77,7 @@ public void testCreateInteraction() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", id), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", id), null, gson.toJson(params), null @@ -85,6 +87,6 @@ public void testCreateInteraction() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("interaction_id")); + assert (map.containsKey("message_id")); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java index 2eb7589696..9c6f7c83b5 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java @@ -17,6 +17,11 @@ */ package org.opensearch.ml.rest; +import static org.opensearch.ml.common.conversation.ActionConstants.CONVERSATION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_CONVERSATION_LIST_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_INTERACTION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_INTERACTION_LIST_FIELD; + import java.io.IOException; import java.util.ArrayList; import java.util.Map; @@ -57,18 +62,11 @@ public void testDeleteConversation_ThatExists() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String id = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String id = (String) ccmap.get(CONVERSATION_ID_FIELD); Response response = TestHelper - .makeRequest( - client(), - "DELETE", - ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{conversation_id}", id), - null, - "", - null - ); + .makeRequest(client(), "DELETE", ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{memory_id}", id), null, "", null); assert (response != null); assert (TestHelper.restStatus(response) == RestStatus.OK); HttpEntity httpEntity = response.getEntity(); @@ -83,7 +81,7 @@ public void testDeleteConversation_ThatDoesNotExist() throws IOException { .makeRequest( client(), "DELETE", - ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{conversation_id}", "happybirthday"), + ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{memory_id}", "happybirthday"), null, "", null @@ -104,8 +102,8 @@ public void testDeleteConversation_WithInteractions() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String cid = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(CONVERSATION_ID_FIELD); Map params = Map .of( @@ -124,7 +122,7 @@ public void testDeleteConversation_WithInteractions() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params), null @@ -134,18 +132,11 @@ public void testDeleteConversation_WithInteractions() throws IOException { HttpEntity cihttpEntity = ciresponse.getEntity(); String cientityString = TestHelper.httpEntityToString(cihttpEntity); Map cimap = gson.fromJson(cientityString, Map.class); - assert (cimap.containsKey("interaction_id")); - String iid = (String) cimap.get("interaction_id"); + assert (cimap.containsKey(RESPONSE_INTERACTION_ID_FIELD)); + String iid = (String) cimap.get(RESPONSE_INTERACTION_ID_FIELD); Response dcresponse = TestHelper - .makeRequest( - client(), - "DELETE", - ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{conversation_id}", cid), - null, - "", - null - ); + .makeRequest(client(), "DELETE", ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{memory_id}", cid), null, "", null); assert (dcresponse != null); assert (TestHelper.restStatus(dcresponse) == RestStatus.OK); HttpEntity dchttpEntity = dcresponse.getEntity(); @@ -160,21 +151,21 @@ public void testDeleteConversation_WithInteractions() throws IOException { HttpEntity gchttpEntity = gcresponse.getEntity(); String gcentityString = TestHelper.httpEntityToString(gchttpEntity); Map gcmap = gson.fromJson(gcentityString, Map.class); - assert (gcmap.containsKey("conversations")); + assert (gcmap.containsKey(RESPONSE_CONVERSATION_LIST_FIELD)); assert (!gcmap.containsKey("next_token")); - assert (((ArrayList) gcmap.get("conversations")).size() == 0); + assert (((ArrayList) gcmap.get(RESPONSE_CONVERSATION_LIST_FIELD)).size() == 0); try { Response giresponse = TestHelper - .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{memory_id}", cid), null, "", null); assert (giresponse != null); assert (TestHelper.restStatus(giresponse) == RestStatus.OK); HttpEntity gihttpEntity = giresponse.getEntity(); String gientityString = TestHelper.httpEntityToString(gihttpEntity); Map gimap = gson.fromJson(gientityString, Map.class); - assert (gimap.containsKey("interactions")); + assert (gimap.containsKey(RESPONSE_INTERACTION_LIST_FIELD)); assert (!gimap.containsKey("next_token")); - assert (((ArrayList) gimap.get("interactions")).size() == 0); + assert (((ArrayList) gimap.get(RESPONSE_INTERACTION_LIST_FIELD)).size() == 0); assert (false); } catch (ResponseException e) { assert (TestHelper.restStatus(e.getResponse()) == RestStatus.NOT_FOUND); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java index 5a55b1c301..28df548c6e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java @@ -67,7 +67,7 @@ public void testGetConversation() throws IOException { String id = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); Response gcresponse = TestHelper - .makeRequest(client(), "GET", ActionConstants.GET_CONVERSATION_REST_PATH.replace("{conversation_id}", id), null, "", null); + .makeRequest(client(), "GET", ActionConstants.GET_CONVERSATION_REST_PATH.replace("{memory_id}", id), null, "", null); assert (gcresponse != null); assert (TestHelper.restStatus(gcresponse) == RestStatus.OK); HttpEntity gchttpEntity = gcresponse.getEntity(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java index 2b2f409908..185ea02b39 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java @@ -17,6 +17,9 @@ */ package org.opensearch.ml.rest; +import static org.opensearch.ml.common.conversation.ActionConstants.CONVERSATION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_CONVERSATION_LIST_FIELD; + import java.io.IOException; import java.util.ArrayList; import java.util.Map; @@ -70,8 +73,8 @@ public void testGetConversations_LastPage() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String id = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String id = (String) ccmap.get(CONVERSATION_ID_FIELD); Response response = TestHelper.makeRequest(client(), "GET", ActionConstants.GET_CONVERSATIONS_REST_PATH, null, "", null); assert (response != null); @@ -79,13 +82,13 @@ public void testGetConversations_LastPage() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("conversations")); + assert (map.containsKey(RESPONSE_CONVERSATION_LIST_FIELD)); assert (!map.containsKey("next_token")); @SuppressWarnings("unchecked") - ArrayList conversations = (ArrayList) map.get("conversations"); + ArrayList conversations = (ArrayList) map.get(RESPONSE_CONVERSATION_LIST_FIELD); assert (conversations.size() == 1); - assert (conversations.get(0).containsKey("conversation_id")); - assert (((String) conversations.get(0).get("conversation_id")).equals(id)); + assert (conversations.get(0).containsKey(CONVERSATION_ID_FIELD)); + assert (((String) conversations.get(0).get(CONVERSATION_ID_FIELD)).equals(id)); } public void testConversations_MorePages() throws IOException { @@ -95,8 +98,8 @@ public void testConversations_MorePages() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String id = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String id = (String) ccmap.get(CONVERSATION_ID_FIELD); Response response = TestHelper .makeRequest( @@ -112,13 +115,13 @@ public void testConversations_MorePages() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("conversations")); + assert (map.containsKey(RESPONSE_CONVERSATION_LIST_FIELD)); assert (map.containsKey("next_token")); @SuppressWarnings("unchecked") - ArrayList conversations = (ArrayList) map.get("conversations"); + ArrayList conversations = (ArrayList) map.get(RESPONSE_CONVERSATION_LIST_FIELD); assert (conversations.size() == 1); - assert (conversations.get(0).containsKey("conversation_id")); - assert (((String) conversations.get(0).get("conversation_id")).equals(id)); + assert (conversations.get(0).containsKey(CONVERSATION_ID_FIELD)); + assert (((String) conversations.get(0).get(CONVERSATION_ID_FIELD)).equals(id)); assert (((Double) map.get("next_token")).intValue() == 1); } @@ -129,9 +132,9 @@ public void testGetConversations_nextPage() throws IOException, InterruptedExcep HttpEntity cchttpEntity1 = ccresponse1.getEntity(); String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1); Map ccmap1 = gson.fromJson(ccentityString1, Map.class); - assert (ccmap1.containsKey("conversation_id")); + assert (ccmap1.containsKey(CONVERSATION_ID_FIELD)); logger.info("ccentityString1={}", ccentityString1); - String id1 = (String) ccmap1.get("conversation_id"); + String id1 = (String) ccmap1.get(CONVERSATION_ID_FIELD); // wait for 0.1s to make sure update time is different between conversation 1 and 2 TimeUnit.MICROSECONDS.sleep(100); @@ -142,8 +145,8 @@ public void testGetConversations_nextPage() throws IOException, InterruptedExcep HttpEntity cchttpEntity2 = ccresponse2.getEntity(); String ccentityString2 = TestHelper.httpEntityToString(cchttpEntity2); Map ccmap2 = gson.fromJson(ccentityString2, Map.class); - assert (ccmap2.containsKey("conversation_id")); - String id2 = (String) ccmap2.get("conversation_id"); + assert (ccmap2.containsKey(CONVERSATION_ID_FIELD)); + String id2 = (String) ccmap2.get(CONVERSATION_ID_FIELD); Response response1 = TestHelper .makeRequest( @@ -159,13 +162,13 @@ public void testGetConversations_nextPage() throws IOException, InterruptedExcep HttpEntity httpEntity1 = response1.getEntity(); String entityString1 = TestHelper.httpEntityToString(httpEntity1); Map map1 = gson.fromJson(entityString1, Map.class); - assert (map1.containsKey("conversations")); + assert (map1.containsKey(RESPONSE_CONVERSATION_LIST_FIELD)); assert (map1.containsKey("next_token")); @SuppressWarnings("unchecked") - ArrayList conversations1 = (ArrayList) map1.get("conversations"); + ArrayList conversations1 = (ArrayList) map1.get(RESPONSE_CONVERSATION_LIST_FIELD); assert (conversations1.size() == 1); - assert (conversations1.get(0).containsKey("conversation_id")); - Assert.assertEquals(conversations1.get(0).get("conversation_id"), id2); + assert (conversations1.get(0).containsKey(CONVERSATION_ID_FIELD)); + Assert.assertEquals(conversations1.get(0).get(CONVERSATION_ID_FIELD), id2); assert (((Double) map1.get("next_token")).intValue() == 1); Response response = TestHelper @@ -182,13 +185,13 @@ public void testGetConversations_nextPage() throws IOException, InterruptedExcep HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("conversations")); + assert (map.containsKey(RESPONSE_CONVERSATION_LIST_FIELD)); assert (map.containsKey("next_token")); @SuppressWarnings("unchecked") - ArrayList conversations = (ArrayList) map.get("conversations"); + ArrayList conversations = (ArrayList) map.get(RESPONSE_CONVERSATION_LIST_FIELD); assert (conversations.size() == 1); - assert (conversations.get(0).containsKey("conversation_id")); - assert (((String) conversations.get(0).get("conversation_id")).equals(id1)); + assert (conversations.get(0).containsKey(CONVERSATION_ID_FIELD)); + assert (((String) conversations.get(0).get(CONVERSATION_ID_FIELD)).equals(id1)); assert (((Double) map.get("next_token")).intValue() == 2); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java index da196ad7d8..e711948b6c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java @@ -84,7 +84,7 @@ public void testGetInteraction() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params), null @@ -99,14 +99,7 @@ public void testGetInteraction() throws IOException { String iid = cimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); Response giresponse = TestHelper - .makeRequest( - client(), - "GET", - ActionConstants.GET_INTERACTION_REST_PATH.replace("{conversation_id}", cid).replace("{interaction_id}", iid), - null, - "", - null - ); + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTION_REST_PATH.replace("{message_id}", iid), null, "", null); assert (giresponse != null); assert (TestHelper.restStatus(giresponse) == RestStatus.OK); HttpEntity gihttpEntity = giresponse.getEntity(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java index 9d0cc6515b..cbad8e0843 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java @@ -50,7 +50,7 @@ public void testBasics() { public void testPrepareRequest() throws Exception { RestMemoryGetInteractionAction action = new RestMemoryGetInteractionAction(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid")) + .withParams(Map.of(ActionConstants.MESSAGE_ID, "iid")) .build(); NodeClient client = mock(NodeClient.class); @@ -59,7 +59,6 @@ public void testPrepareRequest() throws Exception { ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionRequest.class); verify(client, times(1)).execute(eq(GetInteractionAction.INSTANCE), argCaptor.capture(), any()); - assert (argCaptor.getValue().getConversationId().equals("cid")); assert (argCaptor.getValue().getInteractionId().equals("iid")); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java index 1c37662218..8acfc4dcf1 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java @@ -17,6 +17,10 @@ */ package org.opensearch.ml.rest; +import static org.opensearch.ml.common.conversation.ActionConstants.CONVERSATION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_INTERACTION_ID_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_INTERACTION_LIST_FIELD; + import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -52,14 +56,7 @@ public void setupFeatureSettings() throws IOException { public void testGetInteractions_NoConversation() throws IOException { Response response = TestHelper - .makeRequest( - client(), - "GET", - ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", "coffee"), - null, - "", - null - ); + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{memory_id}", "coffee"), null, "", null); assert (response != null); assert (TestHelper.restStatus(response) == RestStatus.OK); HttpEntity httpEntity = response.getEntity(); @@ -77,19 +74,19 @@ public void testGetInteractions_NoInteractions() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String cid = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(CONVERSATION_ID_FIELD); Response response = TestHelper - .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{memory_id}", cid), null, "", null); assert (response != null); assert (TestHelper.restStatus(response) == RestStatus.OK); HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("interactions")); + assert (map.containsKey(RESPONSE_INTERACTION_LIST_FIELD)); assert (!map.containsKey("next_token")); - assert (((ArrayList) map.get("interactions")).size() == 0); + assert (((ArrayList) map.get(RESPONSE_INTERACTION_LIST_FIELD)).size() == 0); } public void testGetInteractions_LastPage() throws IOException { @@ -99,8 +96,8 @@ public void testGetInteractions_LastPage() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String cid = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(CONVERSATION_ID_FIELD); Map params = Map .of( @@ -119,7 +116,7 @@ public void testGetInteractions_LastPage() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params), null @@ -129,22 +126,22 @@ public void testGetInteractions_LastPage() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("interaction_id")); - String iid = (String) map.get("interaction_id"); + assert (map.containsKey(RESPONSE_INTERACTION_ID_FIELD)); + String iid = (String) map.get(RESPONSE_INTERACTION_ID_FIELD); Response response1 = TestHelper - .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{memory_id}", cid), null, "", null); assert (response1 != null); assert (TestHelper.restStatus(response1) == RestStatus.OK); HttpEntity httpEntity1 = response1.getEntity(); String entityString1 = TestHelper.httpEntityToString(httpEntity1); Map map1 = gson.fromJson(entityString1, Map.class); - assert (map1.containsKey("interactions")); + assert (map1.containsKey(RESPONSE_INTERACTION_LIST_FIELD)); assert (!map1.containsKey("next_token")); - assert (((ArrayList) map1.get("interactions")).size() == 1); + assert (((ArrayList) map1.get(RESPONSE_INTERACTION_LIST_FIELD)).size() == 1); @SuppressWarnings("unchecked") - ArrayList interactions = (ArrayList) map1.get("interactions"); - assert (((String) interactions.get(0).get("interaction_id")).equals(iid)); + ArrayList interactions = (ArrayList) map1.get(RESPONSE_INTERACTION_LIST_FIELD); + assert (((String) interactions.get(0).get(RESPONSE_INTERACTION_ID_FIELD)).equals(iid)); } public void testGetInteractions_MorePages() throws IOException { @@ -154,8 +151,8 @@ public void testGetInteractions_MorePages() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String cid = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(CONVERSATION_ID_FIELD); Map params = Map .of( @@ -174,7 +171,7 @@ public void testGetInteractions_MorePages() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params), null @@ -184,14 +181,14 @@ public void testGetInteractions_MorePages() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("interaction_id")); - String iid = (String) map.get("interaction_id"); + assert (map.containsKey(RESPONSE_INTERACTION_ID_FIELD)); + String iid = (String) map.get(RESPONSE_INTERACTION_ID_FIELD); Response response1 = TestHelper .makeRequest( client(), "GET", - ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{memory_id}", cid), Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1"), "", null @@ -201,12 +198,12 @@ public void testGetInteractions_MorePages() throws IOException { HttpEntity httpEntity1 = response1.getEntity(); String entityString1 = TestHelper.httpEntityToString(httpEntity1); Map map1 = gson.fromJson(entityString1, Map.class); - assert (map1.containsKey("interactions")); + assert (map1.containsKey(RESPONSE_INTERACTION_LIST_FIELD)); assert (map1.containsKey("next_token")); - assert (((ArrayList) map1.get("interactions")).size() == 1); + assert (((ArrayList) map1.get(RESPONSE_INTERACTION_LIST_FIELD)).size() == 1); @SuppressWarnings("unchecked") - ArrayList interactions = (ArrayList) map1.get("interactions"); - assert (((String) interactions.get(0).get("interaction_id")).equals(iid)); + ArrayList interactions = (ArrayList) map1.get(RESPONSE_INTERACTION_LIST_FIELD); + assert (((String) interactions.get(0).get(RESPONSE_INTERACTION_ID_FIELD)).equals(iid)); assert (((Double) map1.get("next_token")).intValue() == 1); } @@ -217,8 +214,8 @@ public void testGetInteractions_NextPage() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String cid = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey(CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(CONVERSATION_ID_FIELD); Map params = Map .of( @@ -237,7 +234,7 @@ public void testGetInteractions_NextPage() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params), null @@ -247,14 +244,14 @@ public void testGetInteractions_NextPage() throws IOException { HttpEntity httpEntity = response.getEntity(); String entityString = TestHelper.httpEntityToString(httpEntity); Map map = gson.fromJson(entityString, Map.class); - assert (map.containsKey("interaction_id")); - String iid = (String) map.get("interaction_id"); + assert (map.containsKey(RESPONSE_INTERACTION_ID_FIELD)); + String iid = (String) map.get(RESPONSE_INTERACTION_ID_FIELD); Response response2 = TestHelper .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params), null @@ -264,14 +261,14 @@ public void testGetInteractions_NextPage() throws IOException { HttpEntity httpEntity2 = response2.getEntity(); String entityString2 = TestHelper.httpEntityToString(httpEntity2); Map map2 = gson.fromJson(entityString2, Map.class); - assert (map2.containsKey("interaction_id")); - String iid2 = (String) map2.get("interaction_id"); + assert (map2.containsKey(RESPONSE_INTERACTION_ID_FIELD)); + String iid2 = (String) map2.get(RESPONSE_INTERACTION_ID_FIELD); Response response1 = TestHelper .makeRequest( client(), "GET", - ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{memory_id}", cid), Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1"), "", null @@ -281,19 +278,19 @@ public void testGetInteractions_NextPage() throws IOException { HttpEntity httpEntity1 = response1.getEntity(); String entityString1 = TestHelper.httpEntityToString(httpEntity1); Map map1 = gson.fromJson(entityString1, Map.class); - assert (map1.containsKey("interactions")); + assert (map1.containsKey(RESPONSE_INTERACTION_LIST_FIELD)); assert (map1.containsKey("next_token")); - assert (((ArrayList) map1.get("interactions")).size() == 1); + assert (((ArrayList) map1.get(RESPONSE_INTERACTION_LIST_FIELD)).size() == 1); @SuppressWarnings("unchecked") - ArrayList interactions = (ArrayList) map1.get("interactions"); - assert (((String) interactions.get(0).get("interaction_id")).equals(iid)); + ArrayList interactions = (ArrayList) map1.get(RESPONSE_INTERACTION_LIST_FIELD); + assert (((String) interactions.get(0).get(RESPONSE_INTERACTION_ID_FIELD)).equals(iid)); assert (((Double) map1.get("next_token")).intValue() == 1); Response response3 = TestHelper .makeRequest( client(), "GET", - ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{memory_id}", cid), Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1", ActionConstants.NEXT_TOKEN_FIELD, "1"), "", null @@ -303,12 +300,12 @@ public void testGetInteractions_NextPage() throws IOException { HttpEntity httpEntity3 = response3.getEntity(); String entityString3 = TestHelper.httpEntityToString(httpEntity3); Map map3 = gson.fromJson(entityString3, Map.class); - assert (map3.containsKey("interactions")); + assert (map3.containsKey(RESPONSE_INTERACTION_LIST_FIELD)); assert (map3.containsKey("next_token")); - assert (((ArrayList) map3.get("interactions")).size() == 1); + assert (((ArrayList) map3.get(RESPONSE_INTERACTION_LIST_FIELD)).size() == 1); @SuppressWarnings("unchecked") - ArrayList interactions3 = (ArrayList) map3.get("interactions"); - assert (((String) interactions3.get(0).get("interaction_id")).equals(iid2)); + ArrayList interactions3 = (ArrayList) map3.get(RESPONSE_INTERACTION_LIST_FIELD); + assert (((String) interactions3.get(0).get(RESPONSE_INTERACTION_ID_FIELD)).equals(iid2)); assert (((Double) map3.get("next_token")).intValue() == 2); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java index 264ef5ea24..2c2ca0cc94 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java @@ -58,8 +58,8 @@ public void testSearchConversations_Successful() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String id = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey("memory_id")); + String id = (String) ccmap.get("memory_id"); Response scresponse = TestHelper .makeRequest(client(), "POST", ActionConstants.SEARCH_CONVERSATIONS_REST_PATH, null, matchAllSearchQuery(), null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java index 9de93ac103..f2f80b3e60 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java @@ -58,8 +58,8 @@ public void testSearchInteractions_Successfull() throws IOException { HttpEntity cchttpEntity = ccresponse.getEntity(); String ccentityString = TestHelper.httpEntityToString(cchttpEntity); Map ccmap = gson.fromJson(ccentityString, Map.class); - assert (ccmap.containsKey("conversation_id")); - String cid = (String) ccmap.get("conversation_id"); + assert (ccmap.containsKey("memory_id")); + String cid = (String) ccmap.get("memory_id"); Map params1 = Map .of( @@ -78,7 +78,7 @@ public void testSearchInteractions_Successfull() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params1), null @@ -88,8 +88,8 @@ public void testSearchInteractions_Successfull() throws IOException { HttpEntity cihttpEntity1 = ciresponse1.getEntity(); String cientityString1 = TestHelper.httpEntityToString(cihttpEntity1); Map cimap1 = gson.fromJson(cientityString1, Map.class); - assert (cimap1.containsKey("interaction_id")); - String iid1 = (String) cimap1.get("interaction_id"); + assert (cimap1.containsKey("message_id")); + String iid1 = (String) cimap1.get("message_id"); Map params2 = Map .of( @@ -108,7 +108,7 @@ public void testSearchInteractions_Successfull() throws IOException { .makeRequest( client(), "POST", - ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{memory_id}", cid), null, gson.toJson(params2), null @@ -118,14 +118,14 @@ public void testSearchInteractions_Successfull() throws IOException { HttpEntity cihttpEntity2 = ciresponse2.getEntity(); String cientityString2 = TestHelper.httpEntityToString(cihttpEntity2); Map cimap2 = gson.fromJson(cientityString2, Map.class); - assert (cimap2.containsKey("interaction_id")); - String iid2 = (String) cimap2.get("interaction_id"); + assert (cimap2.containsKey("message_id")); + String iid2 = (String) cimap2.get("message_id"); Response siresponse = TestHelper .makeRequest( client(), "POST", - ActionConstants.SEARCH_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + ActionConstants.SEARCH_INTERACTIONS_REST_PATH.replace("{memory_id}", cid), null, matchAllSearchQuery(), null diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java index 539527bdf5..2131b80740 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateConversationTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ActionConstants.CONVERSATION_ID_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; import java.util.HashMap; @@ -93,7 +94,7 @@ public void testRoutes() { assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.PUT, route.getMethod()); - assertEquals("/_plugins/_ml/memory/conversation/{conversation_id}/_update", route.getPath()); + assertEquals("/_plugins/_ml/memory/{memory_id}", route.getPath()); } public void testUpdateConversationRequest() throws Exception { @@ -115,7 +116,7 @@ public void testUpdateConnectorRequestWithEmptyContent() throws Exception { public void testUpdateConnectorRequestWithNullConversationId() throws Exception { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Request should contain conversation_id"); + exceptionRule.expectMessage("Request should contain memory_id"); RestRequest request = getRestRequestWithNullConversationId(); restMemoryUpdateConversationAction.handleRequest(request, channel, client); } @@ -125,10 +126,10 @@ private RestRequest getRestRequest() { final Map updateContent = Map.of(META_NAME_FIELD, "new name"); String requestContent = new Gson().toJson(updateContent); Map params = new HashMap<>(); - params.put("conversation_id", "test_conversationId"); + params.put(CONVERSATION_ID_FIELD, "test_conversationId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withPath("/_plugins/_ml/memory/{memory_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -138,10 +139,10 @@ private RestRequest getRestRequest() { private RestRequest getRestRequestWithEmptyContent() { RestRequest.Method method = RestRequest.Method.POST; Map params = new HashMap<>(); - params.put("conversation_id", "test_conversationId"); + params.put(CONVERSATION_ID_FIELD, "test_conversationId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withPath("/_plugins/_ml/memory/{memory_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build(); @@ -155,7 +156,7 @@ private RestRequest getRestRequestWithNullConversationId() { Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/memory/conversation/{conversation_id}/_update") + .withPath("/_plugins/_ml/memory/{memory_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java index cdfdaa2b3c..6d1b2d6dee 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryUpdateInteractionActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.conversation.ActionConstants.RESPONSE_INTERACTION_ID_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; import java.util.HashMap; @@ -93,7 +94,7 @@ public void testRoutes() { assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.PUT, route.getMethod()); - assertEquals("/_plugins/_ml/memory/interaction/{interaction_id}/_update", route.getPath()); + assertEquals("/_plugins/_ml/memory/message/{message_id}", route.getPath()); } public void testUpdateInteractionRequest() throws Exception { @@ -115,7 +116,7 @@ public void testUpdateInteractionRequestWithEmptyContent() throws Exception { public void testUpdateInteractionRequestWithNullInteractionId() throws Exception { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Request should contain interaction_id"); + exceptionRule.expectMessage("Request should contain message_id"); RestRequest request = getRestRequestWithNullInteractionId(); restMemoryUpdateInteractionAction.handleRequest(request, channel, client); } @@ -125,10 +126,10 @@ private RestRequest getRestRequest() { final Map updateContent = Map.of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!")); String requestContent = new Gson().toJson(updateContent); Map params = new HashMap<>(); - params.put("interaction_id", "test_interactionId"); + params.put(RESPONSE_INTERACTION_ID_FIELD, "test_interactionId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withPath("/_plugins/_ml/memory/message/{message_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -138,10 +139,10 @@ private RestRequest getRestRequest() { private RestRequest getRestRequestWithEmptyContent() { RestRequest.Method method = RestRequest.Method.POST; Map params = new HashMap<>(); - params.put("interaction_id", "test_interactionId"); + params.put(RESPONSE_INTERACTION_ID_FIELD, "test_interactionId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withPath("/_plugins/_ml/memory/message/{message_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build(); @@ -155,7 +156,7 @@ private RestRequest getRestRequestWithNullInteractionId() { Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/memory/interaction/{interaction_id}/_update") + .withPath("/_plugins/_ml/memory/message/{message_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build();