From 57225a19e29b67dd174336afea98b721617f57c5 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sat, 16 Dec 2023 16:03:36 -0800 Subject: [PATCH] add register action request/response Signed-off-by: Yaliang Wu --- .../opensearch/ml/common/agent/LLMSpec.java | 3 +- .../opensearch/ml/common/agent/MLAgent.java | 25 ++--- .../ml/common/agent/MLMemorySpec.java | 3 +- .../ml/common/agent/MLToolSpec.java | 3 +- .../transport/agent/MLAgentGetResponse.java | 2 + .../agent/MLRegisterAgentAction.java | 18 ++++ .../agent/MLRegisterAgentRequest.java | 77 ++++++++++++++++ .../agent/MLRegisterAgentResponse.java | 65 +++++++++++++ .../undeploy/MLUndeployModelsResponse.java | 21 +++++ .../agent/MLAgentGetResponseTest.java | 45 ++++----- .../agent/MLRegisterAgentActionTest.java | 20 ++++ .../agent/MLRegisterAgentRequestTest.java | 91 +++++++++++++++++++ .../agent/MLRegisterAgentResponseTest.java | 75 +++++++++++++++ 13 files changed, 413 insertions(+), 35 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java index 6c0fda289a..561fe81d5f 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -19,7 +20,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; - +@EqualsAndHashCode @Getter public class LLMSpec implements ToXContentObject { public static final String MODEL_ID_FIELD = "model_id"; diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index ba2f241375..2d2dad77bf 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -26,7 +27,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; - +@EqualsAndHashCode @Getter public class MLAgent implements ToXContentObject, Writeable { public static final String AGENT_NAME_FIELD = "name"; @@ -99,15 +100,17 @@ public MLAgent(StreamInput input) throws IOException{ if (input.readBoolean()) { memory = new MLMemorySpec(input); } - createdTime = input.readInstant(); - lastUpdateTime = input.readInstant(); - appType = input.readString(); + createdTime = input.readOptionalInstant(); + lastUpdateTime = input.readOptionalInstant(); + appType = input.readOptionalString(); if (!"flow".equals(type)) { Set toolNames = new HashSet<>(); - for (MLToolSpec toolSpec : tools) { - String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); - if (toolNames.contains(toolName)) { - throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); + if (tools != null) { + for (MLToolSpec toolSpec : tools) { + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); + if (toolNames.contains(toolName)) { + throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); + } } } } @@ -144,9 +147,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } - out.writeInstant(createdTime); - out.writeInstant(lastUpdateTime); - out.writeString(appType); + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastUpdateTime); + out.writeOptionalString(appType); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java index aa192a7ee2..5d13d5236c 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import org.opensearch.core.common.io.stream.StreamInput; @@ -18,7 +19,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - +@EqualsAndHashCode @Getter public class MLMemorySpec implements ToXContentObject { public static final String MEMORY_TYPE_FIELD = "type"; diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 055c59d449..7b9b640c8a 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.agent; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -19,7 +20,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; - +@EqualsAndHashCode @Getter public class MLToolSpec implements ToXContentObject { public static final String TOOL_TYPE_FIELD = "type"; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java index a437ef0ed8..593e314b31 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.transport.agent; import lombok.Builder; +import lombok.Getter; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -20,6 +21,7 @@ import java.io.IOException; import java.io.UncheckedIOException; +@Getter public class MLAgentGetResponse extends ActionResponse implements ToXContentObject { MLAgent mlAgent; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentAction.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentAction.java new file mode 100644 index 0000000000..c5d1a1232f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.opensearch.action.ActionType; + +public class MLRegisterAgentAction extends ActionType { + public static MLRegisterAgentAction INSTANCE = new MLRegisterAgentAction(); + public static final String NAME = "cluster:admin/opensearch/ml/agents/register"; + + private MLRegisterAgentAction() { + super(NAME, MLRegisterAgentResponse::new); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java new file mode 100644 index 0000000000..00489ccf08 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.agent.MLAgent; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLRegisterAgentRequest extends ActionRequest { + + MLAgent mlAgent; + + @Builder + public MLRegisterAgentRequest(MLAgent mlAgent) { + this.mlAgent = mlAgent; + } + + public MLRegisterAgentRequest(StreamInput in) throws IOException { + super(in); + this.mlAgent = new MLAgent(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (mlAgent == null) { + exception = addValidationError("ML agent can't be null", exception); + } + + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.mlAgent.writeTo(out); + } + + public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLRegisterAgentRequest) { + return (MLRegisterAgentRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterAgentRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelRequest", e); + } + + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java new file mode 100644 index 0000000000..3005739416 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import lombok.Getter; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +@Getter +public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject { + public static final String AGENT_ID_FIELD = "agent_id"; + + private String agentId; + + public MLRegisterAgentResponse(StreamInput in) throws IOException { + super(in); + this.agentId = in.readString(); + } + + public MLRegisterAgentResponse(String agentId) { + this.agentId= agentId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(agentId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(AGENT_ID_FIELD, agentId); + builder.endObject(); + return builder; + } + + public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterAgentResponse) { + return (MLRegisterAgentResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterAgentResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java index 7534b52187..d86b889ce6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject { @@ -49,4 +54,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } return builder; } + + public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLUndeployModelsResponse) { + return (MLUndeployModelsResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUndeployModelsResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLUndeployModelsResponse", e); + } + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 7d733a4308..b692ce34ac 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -4,9 +4,11 @@ */ package org.opensearch.ml.common.transport.agent; +import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.*; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -17,40 +19,41 @@ import java.io.*; import java.time.Instant; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; public class MLAgentGetResponseTest { MLAgent mlAgent; + @Before + public void setUp() { + mlAgent = MLAgent.builder() + .name("test_agent") + .appType("test_app") + .type("flow") + .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) + .build(); + } + @Test public void Create_MLAgentResponse_With_StreamInput() throws IOException { // Create a BytesStreamOutput to simulate the StreamOutput - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - - //create a test agent using input - bytesStreamOutput.writeString("Test Agent"); - bytesStreamOutput.writeString("flow"); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeBoolean(false); - bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); - bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z")); - bytesStreamOutput.writeString("test"); - - StreamInput testInputStream = bytesStreamOutput.bytes().streamInput(); - - MLAgentGetResponse mlAgentGetResponse = new MLAgentGetResponse(testInputStream); - MLAgent testMlAgent = mlAgentGetResponse.mlAgent; - assertEquals("flow",testMlAgent.getType()); - assertEquals("Test Agent",testMlAgent.getName()); - assertEquals("test",testMlAgent.getAppType()); + MLAgentGetResponse agentGetResponse = new MLAgentGetResponse(mlAgent); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + agentGetResponse.writeTo(out); + } + }; + MLAgentGetResponse parsedResponse = MLAgentGetResponse.fromActionResponse(actionResponse); + assertNotSame(agentGetResponse, parsedResponse); + assertEquals(agentGetResponse.getMlAgent(), parsedResponse.getMlAgent()); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java new file mode 100644 index 0000000000..aa790d0ccd --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MLRegisterAgentActionTest { + + @Test + public void actionInstance() { + assertNotNull(MLRegisterAgentAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/agents/register", MLRegisterAgentAction.NAME); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java new file mode 100644 index 0000000000..2c189690dc --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLToolSpec; + +import java.io.IOException; +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class MLRegisterAgentRequestTest { + + MLAgent mlAgent; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + mlAgent = MLAgent.builder() + .name("test_agent") + .appType("test_app") + .type("flow") + .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) + .build(); + } + + @Test + public void constructor_Agent() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + assertEquals(mlAgent, registerAgentRequest.getMlAgent()); + + ActionRequestValidationException validationException = registerAgentRequest.validate(); + assertNull(validationException); + } + + @Test + public void constructor_NullAgent() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest((MLAgent) null); + assertNull(registerAgentRequest.getMlAgent()); + + ActionRequestValidationException validationException = registerAgentRequest.validate(); + assertNotNull(validationException); + assertTrue(validationException.toString().contains("ML agent can't be null")); + } + + @Test + public void writeTo_Success() throws IOException { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + registerAgentRequest.writeTo(bytesStreamOutput); + MLRegisterAgentRequest parsedRequest = new MLRegisterAgentRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlAgent, parsedRequest.getMlAgent()); + } + + @Test + public void fromActionRequest_Success() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + registerAgentRequest.writeTo(out); + } + }; + MLRegisterAgentRequest parsedRequest = MLRegisterAgentRequest.fromActionRequest(actionRequest); + assertNotSame(registerAgentRequest, parsedRequest); + assertEquals(registerAgentRequest.getMlAgent(), parsedRequest.getMlAgent()); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java new file mode 100644 index 0000000000..6c300e786c --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.agent; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; + +public class MLRegisterAgentResponseTest { + String agentId; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + agentId = "test_agent_id"; + } + + @Test + public void constructor_AgentId() { + MLRegisterAgentResponse response = new MLRegisterAgentResponse(agentId); + assertEquals(agentId, response.getAgentId()); + } + + @Test + public void writeTo_Success() throws IOException { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + registerAgentResponse.writeTo(bytesStreamOutput); + MLRegisterAgentResponse parsedResponse = new MLRegisterAgentResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(agentId, parsedResponse.getAgentId()); + } + + @Test + public void toXContent() throws IOException { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + registerAgentResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"agent_id\":\"test_agent_id\"}", jsonStr); + } + + @Test + public void fromActionResponse_Success() { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + ActionResponse actionResponse = new ActionResponse() { + + @Override + public void writeTo(StreamOutput out) throws IOException { + registerAgentResponse.writeTo(out); + } + }; + MLRegisterAgentResponse parsedResponse = MLRegisterAgentResponse.fromActionResponse(actionResponse); + assertNotSame(registerAgentResponse, parsedResponse); + assertEquals(registerAgentResponse.getAgentId(), parsedResponse.getAgentId()); + } +}