Skip to content

Commit

Permalink
add register action request/response
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 17, 2023
1 parent 4d8d32e commit e38c989
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class LLMSpec implements ToXContentObject {
public static final String MODEL_ID_FIELD = "model_id";
Expand Down
25 changes: 14 additions & 11 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -26,7 +27,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLAgent implements ToXContentObject, Writeable {
public static final String AGENT_NAME_FIELD = "name";
Expand Down Expand Up @@ -99,15 +100,17 @@ public MLAgent(StreamInput input) throws IOException{
if (input.readBoolean()) {
memory = new MLMemorySpec(input);
}
createdTime = input.readInstant();
lastUpdateTime = input.readInstant();
appType = input.readString();
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
appType = input.readOptionalString();
if (!"flow".equals(type)) {
Set<String> toolNames = new HashSet<>();
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
if (tools != null) {
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);

Check warning on line 112 in common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java#L112

Added line #L112 was not covered by tests
}
}
}
}
Expand Down Expand Up @@ -144,9 +147,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeInstant(createdTime);
out.writeInstant(lastUpdateTime);
out.writeString(appType);
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalString(appType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -18,7 +19,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


@EqualsAndHashCode
@Getter
public class MLMemorySpec implements ToXContentObject {
public static final String MEMORY_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLToolSpec implements ToXContentObject {
public static final String TOOL_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.transport.agent;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
Expand All @@ -20,6 +21,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLAgentGetResponse extends ActionResponse implements ToXContentObject {
MLAgent mlAgent;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import org.opensearch.action.ActionType;

public class MLRegisterAgentAction extends ActionType<MLRegisterAgentResponse> {
public static MLRegisterAgentAction INSTANCE = new MLRegisterAgentAction();
public static final String NAME = "cluster:admin/opensearch/ml/agents/register";

private MLRegisterAgentAction() {
super(NAME, MLRegisterAgentResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.agent.MLAgent;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLRegisterAgentRequest extends ActionRequest {

MLAgent mlAgent;

@Builder
public MLRegisterAgentRequest(MLAgent mlAgent) {
this.mlAgent = mlAgent;
}

public MLRegisterAgentRequest(StreamInput in) throws IOException {
super(in);
this.mlAgent = new MLAgent(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (mlAgent == null) {
exception = addValidationError("ML agent can't be null", exception);
}

return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.mlAgent.writeTo(out);
}

public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLRegisterAgentRequest) {
return (MLRegisterAgentRequest) actionRequest;

Check warning on line 63 in common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java#L63

Added line #L63 was not covered by tests
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelRequest", e);

Check warning on line 73 in common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java#L72-L73

Added lines #L72 - L73 were not covered by tests
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject {
public static final String AGENT_ID_FIELD = "agent_id";

private String agentId;

public MLRegisterAgentResponse(StreamInput in) throws IOException {
super(in);
this.agentId = in.readString();
}

public MLRegisterAgentResponse(String agentId) {
this.agentId= agentId;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(agentId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(AGENT_ID_FIELD, agentId);
builder.endObject();
return builder;
}

public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterAgentResponse) {
return (MLRegisterAgentResponse) actionResponse;

Check warning on line 52 in common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java#L52

Added line #L52 was not covered by tests
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e);

Check warning on line 62 in common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java#L61-L62

Added lines #L61 - L62 were not covered by tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -49,4 +54,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
return builder;
}

public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLUndeployModelsResponse) {
return (MLUndeployModelsResponse) actionResponse;

Check warning on line 60 in common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java#L60

Added line #L60 was not covered by tests
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUndeployModelsResponse(input);

Check warning on line 67 in common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java#L63-L67

Added lines #L63 - L67 were not covered by tests
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLUndeployModelsResponse", e);

Check warning on line 70 in common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java#L69-L70

Added lines #L69 - L70 were not covered by tests
}
}
}
Loading

0 comments on commit e38c989

Please sign in to comment.