Skip to content

Commit

Permalink
separate get Trace actions into a new rest API (#1672)
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Nov 22, 2023
1 parent becfbe0 commit 8e36de3
Show file tree
Hide file tree
Showing 10 changed files with 391 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public class ActionConstants {
public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations";
/** name of list on interactions in all responses */
public final static String RESPONSE_INTERACTION_LIST_FIELD = "interactions";
/** name of list on interactions in all responses */
public final static String RESPONSE_TRACES_LIST_FIELD = "traces";
/** name of interaction Id field in all responses */
public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id";

Expand Down Expand Up @@ -67,6 +69,8 @@ public class ActionConstants {
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
/** path for get interactions */
public final static String GET_TRACES_REST_PATH = "/_plugins/_ml/memory/trace" + "/{interaction_id}/_list";
/** path for delete conversation */
public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete";
/** path for search conversations */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ public ActionFuture<String> createInteraction(
*/
public void getInteractions(String conversationId, int from, int maxResults, ActionListener<List<Interaction>> listener);

public void getTraces(String interactionId, int from, int maxResults, ActionListener<List<Interaction>> listener);

/**
* Get the interactions associate with this conversation, sorted by recency
* @param conversationId the conversation whose interactions to get
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;

/**
* Action to return the traces associated with an interaction
*/
public class GetTracesAction extends ActionType<GetTracesResponse> {
/** Instance of this */
public static final GetTracesAction INSTANCE = new GetTracesAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get";

private GetTracesAction() {
super(NAME, GetTracesResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

import lombok.Getter;

/**
* ActionRequest for get traces
*/
public class GetTracesRequest extends ActionRequest {
@Getter
private String interactionId;
@Getter
private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS;
@Getter
private int from = 0;

/**
* Constructor
* @param interactionId UID of the interaction to get traces from
*/
public GetTracesRequest(String interactionId) {
this.interactionId = interactionId;
}

/**
* Constructor
* @param interactionId UID of the conversation to get interactions from
* @param maxResults number of interactions to retrieve
*/
public GetTracesRequest(String interactionId, int maxResults) {
this.interactionId = interactionId;
this.maxResults = maxResults;
}

/**
* Constructor
* @param interactionId UID of the conversation to get interactions from
* @param maxResults number of interactions to retrieve
* @param from position of first interaction to retrieve
*/
public GetTracesRequest(String interactionId, int maxResults, int from) {
this.interactionId = interactionId;
this.maxResults = maxResults;
this.from = from;
}

/**
* Constructor
* @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo
* @throws IOException if there wasn't a GIR in the stream
*/
public GetTracesRequest(StreamInput in) throws IOException {
super(in);
this.interactionId = in.readString();
this.maxResults = in.readInt();
this.from = in.readInt();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(interactionId);
out.writeInt(maxResults);
out.writeInt(from);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (interactionId == null) {
exception = addValidationError("Traces must be retrieved from an interaction", exception);
}
if (maxResults <= 0) {
exception = addValidationError("The number of traces to retrieve must be positive", exception);
}
if (from < 0) {
exception = addValidationError("The starting position must be nonnegative", exception);
}

return exception;
}

/**
* Makes a GetTracesRequest out of a RestRequest
* @param request Rest Request representing a get traces request
* @return a new GetTracesRequest
* @throws IOException if something goes wrong
*/
public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException {
String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD);
if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) {
int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD));
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD));
return new GetTracesRequest(cid, maxResults, from);
} else {
return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from);
}
} else {
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD));
return new GetTracesRequest(cid, maxResults);
} else {
return new GetTracesRequest(cid);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import java.io.IOException;
import java.util.List;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.conversation.Interaction;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
* Action Response for get traces for an interaction
*/
@AllArgsConstructor
public class GetTracesResponse extends ActionResponse implements ToXContentObject {
@Getter
private List<Interaction> traces;

/**
* Constructor
* @param in stream input; assumes GetTracesResponse.writeTo was called
* @throws IOException if there's not a G.I.R. in the stream
*/
public GetTracesResponse(StreamInput in) throws IOException {
super(in);
traces = in.readList(Interaction::fromStream);
}

public void writeTo(StreamOutput out) throws IOException {
out.writeList(traces);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD);
for (Interaction trace : traces) {
trace.toXContent(builder, params);
}
builder.endArray();
builder.endObject();
return builder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.memory.action.conversation;

import java.util.List;

import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class GetTracesTransportAction extends HandledTransportAction<GetTracesRequest, GetTracesResponse> {
private Client client;
private ConversationalMemoryHandler cmHandler;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetTracesTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client,
ClusterService clusterService
) {
super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new);
this.client = client;
this.cmHandler = cmHandler;
}

@Override
public void doExecute(Task task, GetTracesRequest request, ActionListener<GetTracesResponse> actionListener) {
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<GetTracesResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<List<Interaction>> al = ActionListener
.wrap(traces -> { internalListener.onResponse(new GetTracesResponse(traces)); }, e -> {
internalListener.onFailure(e);
});
cmHandler.getTraces(request.getInteractionId(), from, maxResults, al);
} catch (Exception e) {
log.error("Failed to get traces for conversation " + request.getInteractionId(), e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -265,8 +268,70 @@ public void getInteractions(String conversationId, int from, int maxResults, Act
@VisibleForTesting
void innerGetInteractions(String conversationId, int from, int maxResults, ActionListener<List<Interaction>> listener) {
SearchRequest request = Requests.searchRequest(indexName);
TermQueryBuilder builder = new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId);
request.source().query(builder);

// Build the query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

// Add the ExistsQueryBuilder for checking null values
ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
boolQueryBuilder.mustNot(existsQueryBuilder);

// Add the TermQueryBuilder for another field
TermQueryBuilder termQueryBuilder = QueryBuilders
.termQuery(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId);
boolQueryBuilder.must(termQueryBuilder);

// Set the query to the search source
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(boolQueryBuilder);

request.source(searchSourceBuilder);
request.source().from(from).size(maxResults);
request.source().sort(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, SortOrder.DESC);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<List<Interaction>> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<SearchResponse> al = ActionListener.wrap(response -> {
List<Interaction> result = new LinkedList<Interaction>();
for (SearchHit hit : response.getHits()) {
result.add(Interaction.fromSearchHit(hit));
}
internalListener.onResponse(result);
}, e -> { internalListener.onFailure(e); });
client
.admin()
.indices()
.refresh(Requests.refreshRequest(indexName), ActionListener.wrap(r -> { client.search(request, al); }, e -> {
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
}
}

/**
* Gets a list of interactions belonging to a conversation
* @param interactionId the interaction to read from
* @param from where to start in the reading
* @param maxResults how many interactions to return
* @param listener gets the list, sorted by recency, of interactions
*/
public void getTraces(String interactionId, int from, int maxResults, ActionListener<List<Interaction>> listener) {
SearchRequest request = Requests.searchRequest(indexName);
// Build the query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

// Add the ExistsQueryBuilder for checking null values
ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
boolQueryBuilder.must(existsQueryBuilder);

// Add the TermQueryBuilder for another field
TermQueryBuilder termQueryBuilder = QueryBuilders
.termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, interactionId);
boolQueryBuilder.must(termQueryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(boolQueryBuilder);

request.source(searchSourceBuilder);
request.source().from(from).size(maxResults);
request.source().sort(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, SortOrder.DESC);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,4 +379,8 @@ public ActionFuture<SearchResponse> searchInteractions(String conversationId, Se
return fut;
}

public void getTraces(String interactionId, int from, int maxResults, ActionListener<List<Interaction>> listener) {
interactionsIndex.getTraces(interactionId, from, maxResults, listener);
}

}
Loading

0 comments on commit 8e36de3

Please sign in to comment.