Skip to content

Commit

Permalink
Merge branch 'main' into main-mingshl
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Dec 16, 2023
2 parents 883e912 + 00b5fea commit 3b94bf2
Show file tree
Hide file tree
Showing 31 changed files with 3,346 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
.gradle/
build/
.idea/
.project
.classpath
.settings
client/build/
common/build/
ml-algorithms/build/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public class ActionConstants {
public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations";
/** name of list on interactions in all responses */
public final static String RESPONSE_INTERACTION_LIST_FIELD = "interactions";
/** name of list on traces in all responses */
public final static String RESPONSE_TRACES_LIST_FIELD = "traces";
/** name of interaction Id field in all responses */
public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id";

Expand Down Expand Up @@ -56,20 +58,27 @@ public class ActionConstants {
public final static String SUCCESS_FIELD = "success";

private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation";
private final static String BASE_REST_INTERACTION_PATH = "/_plugins/_ml/memory/interaction";
/** path for create conversation */
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create";
/** path for get conversations */
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
/** path for update conversations */
public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update";
/** path for create interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
/** path for get traces */
public final static String GET_TRACES_REST_PATH = "/_plugins/_ml/memory/trace" + "/{interaction_id}/_list";
/** path for delete conversation */
public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete";
/** path for search conversations */
public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search";
/** path for search interactions */
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search";
/** path for update interactions */
public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_INTERACTION_PATH + "/{interaction_id}/_update";
/** path for get conversation */
public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}";
/** path for get interaction */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;

/**
* Action to return the traces associated with an interaction
*/
public class GetTracesAction extends ActionType<GetTracesResponse> {
/** Instance of this */
public static final GetTracesAction INSTANCE = new GetTracesAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get";

private GetTracesAction() {
super(NAME, GetTracesResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

import lombok.Getter;

/**
* ActionRequest for get traces
*/
public class GetTracesRequest extends ActionRequest {
@Getter
private String interactionId;
@Getter
private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS;
@Getter
private int from = 0;

/**
* Constructor
* @param interactionId UID of the interaction to get traces from
*/
public GetTracesRequest(String interactionId) {
this.interactionId = interactionId;
}

/**
* Constructor
* @param interactionId UID of the conversation to get interactions from
* @param maxResults number of interactions to retrieve
*/
public GetTracesRequest(String interactionId, int maxResults) {
this.interactionId = interactionId;
this.maxResults = maxResults;
}

/**
* Constructor
* @param interactionId UID of the conversation to get interactions from
* @param maxResults number of interactions to retrieve
* @param from position of first interaction to retrieve
*/
public GetTracesRequest(String interactionId, int maxResults, int from) {
this.interactionId = interactionId;
this.maxResults = maxResults;
this.from = from;
}

/**
* Constructor
* @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo
* @throws IOException if there wasn't a GIR in the stream
*/
public GetTracesRequest(StreamInput in) throws IOException {
super(in);
this.interactionId = in.readString();
this.maxResults = in.readInt();
this.from = in.readInt();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(interactionId);
out.writeInt(maxResults);
out.writeInt(from);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (interactionId == null) {
exception = addValidationError("Traces must be retrieved from an interaction", exception);
}
if (maxResults <= 0) {
exception = addValidationError("The number of traces to retrieve must be positive", exception);
}
if (from < 0) {
exception = addValidationError("The starting position must be nonnegative", exception);
}

return exception;
}

/**
* Makes a GetTracesRequest out of a RestRequest
* @param request Rest Request representing a get traces request
* @return a new GetTracesRequest
* @throws IOException if something goes wrong
*/
public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException {
String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD);
if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) {
int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD));
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD));
return new GetTracesRequest(cid, maxResults, from);
} else {
return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from);
}
} else {
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD));
return new GetTracesRequest(cid, maxResults);
} else {
return new GetTracesRequest(cid);
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import java.io.IOException;
import java.util.List;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.conversation.Interaction;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;

/**
* Action Response for get traces for an interaction
*/
@AllArgsConstructor
public class GetTracesResponse extends ActionResponse implements ToXContentObject {
@Getter
@NonNull
private List<Interaction> traces;
@Getter
private int nextToken;
private boolean hasMoreTokens;

/**
* Constructor
* @param in stream input; assumes GetTracesResponse.writeTo was called
* @throws IOException if there's not a G.I.R. in the stream
*/
public GetTracesResponse(StreamInput in) throws IOException {
super(in);
traces = in.readList(Interaction::fromStream);
nextToken = in.readInt();
hasMoreTokens = in.readBoolean();
}

public void writeTo(StreamOutput out) throws IOException {
out.writeList(traces);
out.writeInt(nextToken);
out.writeBoolean(hasMoreTokens);
}

/**
* Are there more pages in this search results
* @return whether there are more traces in this search
*/
public boolean hasMorePages() {
return hasMoreTokens;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD);
for (Interaction trace : traces) {
trace.toXContent(builder, params);
}
builder.endArray();
if (hasMoreTokens) {
builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken);
}
builder.endObject();
return builder;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import java.util.List;

import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class GetTracesTransportAction extends HandledTransportAction<GetTracesRequest, GetTracesResponse> {
private Client client;
private ConversationalMemoryHandler cmHandler;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
*/
@Inject
public GetTracesTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
) {
super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new);
this.client = client;
this.cmHandler = cmHandler;
}

@Override
public void doExecute(Task task, GetTracesRequest request, ActionListener<GetTracesResponse> actionListener) {
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
// TODO: check this newStoredContext() method and remove it if it's redundant
ActionListener<GetTracesResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<List<Interaction>> al = ActionListener.wrap(tracesList -> {
internalListener.onResponse(new GetTracesResponse(tracesList, from + maxResults, tracesList.size() == maxResults));
}, e -> { internalListener.onFailure(e); });
cmHandler.getTraces(request.getInteractionId(), from, maxResults, al);
} catch (Exception e) {
log.error("Failed to get traces for conversation " + request.getInteractionId(), e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class UpdateConversationAction extends ActionType<UpdateResponse> {
public static final UpdateConversationAction INSTANCE = new UpdateConversationAction();
public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/update";

private UpdateConversationAction() {
super(NAME, UpdateResponse::new);
}
}
Loading

0 comments on commit 3b94bf2

Please sign in to comment.