Skip to content

Commit

Permalink
refactor memory layer APIs (#1877)
Browse files Browse the repository at this point in the history
* refactor memory layer APIs

Signed-off-by: Xun Zhang <[email protected]>

* remove printout leftovers in tests

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Jan 18, 2024
1 parent 38b51f2 commit b3c033f
Show file tree
Hide file tree
Showing 44 changed files with 236 additions and 308 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
public class ActionConstants {

/** name of conversation Id field in all responses */
public final static String CONVERSATION_ID_FIELD = "conversation_id";
public final static String CONVERSATION_ID_FIELD = "memory_id";

/** name of list of conversations in all responses */
public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations";
Expand All @@ -32,7 +32,7 @@ public class ActionConstants {
/** name of list on traces in all responses */
public final static String RESPONSE_TRACES_LIST_FIELD = "traces";
/** name of interaction Id field in all responses */
public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id";
public final static String RESPONSE_INTERACTION_ID_FIELD = "message_id";

/** name of conversation name in all requests */
public final static String REQUEST_CONVERSATION_NAME_FIELD = "name";
Expand All @@ -51,38 +51,41 @@ public class ActionConstants {
/** name of metadata field in all requests */
public final static String ADDITIONAL_INFO_FIELD = "additional_info";
/** name of metadata field in all requests */
public final static String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id";
public final static String PARENT_INTERACTION_ID_FIELD = "parent_message_id";
/** name of metadata field in all requests */
public final static String TRACE_NUMBER_FIELD = "trace_number";
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation";
private final static String BASE_REST_INTERACTION_PATH = "/_plugins/_ml/memory/interaction";
/** parameter for memory_id in URLs */
public final static String MEMORY_ID = "memory_id";
/** parameter for message_id in URLs */
public final static String MESSAGE_ID = "message_id";
private final static String BASE_REST_PATH = "/_plugins/_ml/memory";
/** path for create conversation */
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create";
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH;
/** path for get conversations */
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH;
/** path for update conversations */
public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update";
public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}";
/** path for create interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{memory_id}/messages";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}/messages";
/** path for get traces */
public final static String GET_TRACES_REST_PATH = "/_plugins/_ml/memory/trace" + "/{interaction_id}/_list";
public final static String GET_TRACES_REST_PATH = BASE_REST_PATH + "/message/{message_id}/traces";
/** path for delete conversation */
public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete";
public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{memory_id}";
/** path for search conversations */
public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search";
/** path for search interactions */
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search";
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}/_search";
/** path for update interactions */
public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_INTERACTION_PATH + "/{interaction_id}/_update";
public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/message/{message_id}";
/** path for get conversation */
public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}";
public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{memory_id}";
/** path for get interaction */
public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}";
public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/message/{message_id}";

/** default max results returned by get operations */
public final static int DEFAULT_MAX_RESULTS = 10;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public class ConversationalIndexConstants {
/** Name of the conversational interactions index */
public final static String INTERACTIONS_INDEX_NAME = ".plugins-ml-memory-message";
/** Name of the interaction field for the conversation Id */
public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "conversation_id";
public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "memory_id";
/** Name of the interaction field for the human input */
public final static String INTERACTIONS_INPUT_FIELD = "input";
/** Name of the interaction field for the prompt template */
Expand All @@ -80,7 +80,7 @@ public class ConversationalIndexConstants {
/** Name of the interaction field for the timestamp */
public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time";
/** Name of the interaction id */
public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_interaction_id";
public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_message_id";
/** The trace number of an interaction */
public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number";
/** Mappings for the interactions index */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public void test_ToXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
conversationMeta.toXContent(builder, EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
assertEquals(content, "{\"conversation_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}");
assertEquals(content, "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void test_ToXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
interaction.toXContent(builder, EMPTY_PARAMS);
String interactionContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"conversation_id\":\"conversation id\",\"interaction_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_interaction_id\":\"parant id\",\"trace_number\":1}", interactionContent);
assertEquals("{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,18 +269,16 @@ public void createInteraction(

/**
* Get a single interaction
* @param conversationId id of the conversation this interaction belongs to
* @param interactionId id of this interaction
* @param listener receives the interaction
*/
public void getInteraction(String conversationId, String interactionId, ActionListener<Interaction> listener);
public void getInteraction(String interactionId, ActionListener<Interaction> listener);

/**
* Get a single interaction
* @param conversationId id of the conversation this interaction belongs to
* @param interactionId id of this interaction
* @return ActionFuture for the interaction
*/
public ActionFuture<Interaction> getInteraction(String conversationId, String interactionId);
public ActionFuture<Interaction> getInteraction(String interactionId);

}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public ActionRequestValidationException validate() {
* @throws IOException if something goes wrong reading from request
*/
public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException {
String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String cid = request.param(ActionConstants.MEMORY_ID);
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public ActionRequestValidationException validate() {
* @throws IOException if something goes wrong in translation
*/
public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException {
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String conversationId = request.param(ActionConstants.MEMORY_ID);
return new GetConversationRequest(conversationId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
*/
@AllArgsConstructor
public class GetInteractionRequest extends ActionRequest {
@Getter
private String conversationId;
@Getter
private String interactionId;

Expand All @@ -48,23 +46,18 @@ public class GetInteractionRequest extends ActionRequest {
*/
public GetInteractionRequest(StreamInput in) throws IOException {
super(in);
this.conversationId = in.readString();
this.interactionId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.conversationId);
out.writeString(this.interactionId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (conversationId == null) {
exception = addValidationError("Get Interaction Request must have a conversation id", exception);
}
if (interactionId == null) {
exception = addValidationError("Get Interaction Request must have an interaction id", exception);
}
Expand All @@ -78,8 +71,7 @@ public ActionRequestValidationException validate() {
* @throws IOException if something goes wrong reading from the rest request
*/
public static GetInteractionRequest fromRestRequest(RestRequest request) throws IOException {
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String interactionId = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD);
return new GetInteractionRequest(conversationId, interactionId);
String interactionId = request.param(ActionConstants.MESSAGE_ID);
return new GetInteractionRequest(interactionId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,15 @@ public void doExecute(Task task, GetInteractionRequest request, ActionListener<G
);
return;
}
String conversationId = request.getConversationId();
String interactionId = request.getInteractionId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<GetInteractionResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<Interaction> al = ActionListener.wrap(interaction -> {
internalListener.onResponse(new GetInteractionResponse(interaction));
}, e -> { internalListener.onFailure(e); });
cmHandler.getInteraction(conversationId, interactionId, al);
cmHandler.getInteraction(interactionId, al);
} catch (Exception e) {
log.error("Failed to get interaction " + interactionId + " in conversation " + conversationId, e);
log.error("Failed to get interaction " + interactionId, e);
actionListener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public ActionRequestValidationException validate() {
* @throws IOException if something goes wrong
*/
public static GetInteractionsRequest fromRestRequest(RestRequest request) throws IOException {
String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String cid = request.param(ActionConstants.MEMORY_ID);
if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) {
int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD));
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,11 +491,10 @@ public void searchInteractions(String conversationId, SearchRequest request, Act

/**
* Get a single interaction
* @param conversationId id of the conversation this interaction belongs to
* @param interactionId id of this interaction
* @param listener receives the interaction
*/
public void getInteraction(String conversationId, String interactionId, ActionListener<Interaction> listener) {
public void getInteraction(String interactionId, ActionListener<Interaction> listener) {
if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) {
listener
.onFailure(
Expand All @@ -506,39 +505,25 @@ public void getInteraction(String conversationId, String interactionId, ActionLi
);
return;
}
conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Interaction> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId);
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
// If the conversation doesn't exist, fail
if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) {
throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found");
}
Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap());
internalListener.onResponse(interaction);
}, e -> { internalListener.onFailure(e); });
client
.admin()
.indices()
.refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
client.get(request, al);
}, e -> {
log.error("Failed to refresh interactions index during get interaction ", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Interaction> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId);
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
// If the conversation doesn't exist, fail
if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) {
throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found");
}
} else {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
}
}, e -> { listener.onFailure(e); }));
Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap());
internalListener.onResponse(interaction);
}, e -> { internalListener.onFailure(e); });
client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
client.get(request, al);
}, e -> {
log.error("Failed to refresh interactions index during get interaction ", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -421,23 +421,21 @@ public ActionFuture<ConversationMeta> getConversation(String conversationId) {

/**
* Get a single interaction
* @param conversationId id of the conversation this interaction belongs to
* @param interactionId id of this interaction
* @param listener receives the interaction
*/
public void getInteraction(String conversationId, String interactionId, ActionListener<Interaction> listener) {
interactionsIndex.getInteraction(conversationId, interactionId, listener);
public void getInteraction(String interactionId, ActionListener<Interaction> listener) {
interactionsIndex.getInteraction(interactionId, listener);
}

/**
* Get a single interaction
* @param conversationId id of the conversation this interaction belongs to
* @param interactionId id of this interaction
* @return ActionFuture for the interaction
*/
public ActionFuture<Interaction> getInteraction(String conversationId, String interactionId) {
public ActionFuture<Interaction> getInteraction(String interactionId) {
PlainActionFuture<Interaction> fut = PlainActionFuture.newFuture();
getInteraction(conversationId, interactionId, fut);
getInteraction(interactionId, fut);
return fut;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void testCreateConversationResponseStreaming() throws IOException {
public void testToXContent() throws IOException {
CreateConversationResponse response = new CreateConversationResponse("createme");
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
String expected = "{\"conversation_id\":\"createme\"}";
String expected = "{\"memory_id\":\"createme\"}";
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
assert (result.equals(expected));
Expand Down
Loading

0 comments on commit b3c033f

Please sign in to comment.