-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
separate get Trace actions into a new rest API (#1672)
Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
becfbe0
commit 8e36de3
Showing
10 changed files
with
391 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
/** | ||
* Action to return the traces associated with an interaction | ||
*/ | ||
public class GetTracesAction extends ActionType<GetTracesResponse> { | ||
/** Instance of this */ | ||
public static final GetTracesAction INSTANCE = new GetTracesAction(); | ||
/** Name of this action */ | ||
public static final String NAME = "cluster:admin/opensearch/ml/memory/trace/get"; | ||
|
||
private GetTracesAction() { | ||
super(NAME, GetTracesResponse::new); | ||
} | ||
|
||
} |
123 changes: 123 additions & 0 deletions
123
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
import java.io.IOException; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.rest.RestRequest; | ||
|
||
import lombok.Getter; | ||
|
||
/** | ||
* ActionRequest for get traces | ||
*/ | ||
public class GetTracesRequest extends ActionRequest { | ||
@Getter | ||
private String interactionId; | ||
@Getter | ||
private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS; | ||
@Getter | ||
private int from = 0; | ||
|
||
/** | ||
* Constructor | ||
* @param interactionId UID of the interaction to get traces from | ||
*/ | ||
public GetTracesRequest(String interactionId) { | ||
this.interactionId = interactionId; | ||
} | ||
|
||
/** | ||
* Constructor | ||
* @param interactionId UID of the conversation to get interactions from | ||
* @param maxResults number of interactions to retrieve | ||
*/ | ||
public GetTracesRequest(String interactionId, int maxResults) { | ||
this.interactionId = interactionId; | ||
this.maxResults = maxResults; | ||
} | ||
|
||
/** | ||
* Constructor | ||
* @param interactionId UID of the conversation to get interactions from | ||
* @param maxResults number of interactions to retrieve | ||
* @param from position of first interaction to retrieve | ||
*/ | ||
public GetTracesRequest(String interactionId, int maxResults, int from) { | ||
this.interactionId = interactionId; | ||
this.maxResults = maxResults; | ||
this.from = from; | ||
} | ||
|
||
/** | ||
* Constructor | ||
* @param in streaminput to read this from. assumes there was a GetTracesRequest.writeTo | ||
* @throws IOException if there wasn't a GIR in the stream | ||
*/ | ||
public GetTracesRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.interactionId = in.readString(); | ||
this.maxResults = in.readInt(); | ||
this.from = in.readInt(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(interactionId); | ||
out.writeInt(maxResults); | ||
out.writeInt(from); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (interactionId == null) { | ||
exception = addValidationError("Traces must be retrieved from an interaction", exception); | ||
} | ||
if (maxResults <= 0) { | ||
exception = addValidationError("The number of traces to retrieve must be positive", exception); | ||
} | ||
if (from < 0) { | ||
exception = addValidationError("The starting position must be nonnegative", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
/** | ||
* Makes a GetTracesRequest out of a RestRequest | ||
* @param request Rest Request representing a get traces request | ||
* @return a new GetTracesRequest | ||
* @throws IOException if something goes wrong | ||
*/ | ||
public static GetTracesRequest fromRestRequest(RestRequest request) throws IOException { | ||
String cid = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); | ||
if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { | ||
int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); | ||
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { | ||
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); | ||
return new GetTracesRequest(cid, maxResults, from); | ||
} else { | ||
return new GetTracesRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from); | ||
} | ||
} else { | ||
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { | ||
int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); | ||
return new GetTracesRequest(cid, maxResults); | ||
} else { | ||
return new GetTracesRequest(cid); | ||
} | ||
} | ||
} | ||
} |
56 changes: 56 additions & 0 deletions
56
memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContent; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.conversation.ActionConstants; | ||
import org.opensearch.ml.common.conversation.Interaction; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
/** | ||
* Action Response for get traces for an interaction | ||
*/ | ||
@AllArgsConstructor | ||
public class GetTracesResponse extends ActionResponse implements ToXContentObject { | ||
@Getter | ||
private List<Interaction> traces; | ||
|
||
/** | ||
* Constructor | ||
* @param in stream input; assumes GetTracesResponse.writeTo was called | ||
* @throws IOException if there's not a G.I.R. in the stream | ||
*/ | ||
public GetTracesResponse(StreamInput in) throws IOException { | ||
super(in); | ||
traces = in.readList(Interaction::fromStream); | ||
} | ||
|
||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeList(traces); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { | ||
builder.startObject(); | ||
builder.startArray(ActionConstants.RESPONSE_TRACES_LIST_FIELD); | ||
for (Interaction trace : traces) { | ||
trace.toXContent(builder, params); | ||
} | ||
builder.endArray(); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
} |
67 changes: 67 additions & 0 deletions
67
.../src/main/java/org/opensearch/ml/memory/action/conversation/GetTracesTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.memory.action.conversation; | ||
|
||
import java.util.List; | ||
|
||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.common.conversation.Interaction; | ||
import org.opensearch.ml.memory.ConversationalMemoryHandler; | ||
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class GetTracesTransportAction extends HandledTransportAction<GetTracesRequest, GetTracesResponse> { | ||
private Client client; | ||
private ConversationalMemoryHandler cmHandler; | ||
|
||
/** | ||
* Constructor | ||
* @param transportService for inter-node communications | ||
* @param actionFilters for filtering actions | ||
* @param cmHandler Handler for conversational memory operations | ||
* @param client OS Client for dealing with OS | ||
* @param clusterService for some cluster ops | ||
*/ | ||
@Inject | ||
public GetTracesTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
OpenSearchConversationalMemoryHandler cmHandler, | ||
Client client, | ||
ClusterService clusterService | ||
) { | ||
super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new); | ||
this.client = client; | ||
this.cmHandler = cmHandler; | ||
} | ||
|
||
@Override | ||
public void doExecute(Task task, GetTracesRequest request, ActionListener<GetTracesResponse> actionListener) { | ||
int maxResults = request.getMaxResults(); | ||
int from = request.getFrom(); | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { | ||
ActionListener<GetTracesResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); | ||
ActionListener<List<Interaction>> al = ActionListener | ||
.wrap(traces -> { internalListener.onResponse(new GetTracesResponse(traces)); }, e -> { | ||
internalListener.onFailure(e); | ||
}); | ||
cmHandler.getTraces(request.getInteractionId(), from, maxResults, al); | ||
} catch (Exception e) { | ||
log.error("Failed to get traces for conversation " + request.getInteractionId(), e); | ||
actionListener.onFailure(e); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.