Skip to content

Commit

Permalink
Add unit tests for Get and Delete APIs
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Dec 14, 2023
1 parent 9fbc0ea commit 418ab64
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.opensearch.ml.common.transport.agent;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class MLAgentDeleteActionTest {
@Test
public void testMLAgentDeleteActionInstance() {
assertNotNull(MLAgentDeleteAction.INSTANCE);
assertEquals("cluster:admin/opensearch/ml/agents/delete", MLAgentDeleteAction.NAME);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.opensearch.ml.common.transport.agent;

import org.junit.Test;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;

import java.io.IOException;

import static org.junit.Assert.assertEquals;
import static org.opensearch.action.ValidateActions.addValidationError;

public class MLAgentDeleteRequestTest {
String agentId;

@Test
public void constructor_AgentId() {
agentId = "test-abc";
MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId);
assertEquals(mLAgentDeleteRequest.agentId,agentId);
}

@Test
public void writeTo() throws IOException {
agentId = "test-hij";

MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId);
BytesStreamOutput output = new BytesStreamOutput();
mLAgentDeleteRequest.writeTo(output);

MLAgentDeleteRequest mLAgentDeleteRequest1 = new MLAgentDeleteRequest(output.bytes().streamInput());

assertEquals(mLAgentDeleteRequest.agentId, mLAgentDeleteRequest1.agentId);
assertEquals(agentId, mLAgentDeleteRequest1.agentId);
}

@Test
public void validate_Success() {
agentId = "not-null";
MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId);

assertEquals(null, mLAgentDeleteRequest.validate());
}

@Test
public void validate_Failure() {
agentId = null;
MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId);
assertEquals(null,mLAgentDeleteRequest.agentId);

ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null);
mLAgentDeleteRequest.validate().equals(exception) ;
}

@Test
public void fromActionRequest() throws IOException {
agentId = "test-lmn";
MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId);
assertEquals(mLAgentDeleteRequest.fromActionRequest(mLAgentDeleteRequest), mLAgentDeleteRequest);

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.opensearch.ml.common.transport.agent;

import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class MLAgentGetActionTest {

@Test
public void testMLAgentGetActionInstance() {
assertNotNull(MLAgentGetAction.INSTANCE);
assertEquals("cluster:admin/opensearch/ml/agents/get", MLAgentGetAction.NAME);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.opensearch.ml.common.transport.agent;

import org.junit.Test;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;

import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.opensearch.action.ValidateActions.addValidationError;

public class MLAgentGetRequestTest {
String agentId;

@Test
public void constructor_AgentId() {
agentId = "test-abc";
MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest("test-abc");
assertEquals(mLAgentGetRequest.getAgentId(),agentId);
}

@Test
public void writeTo() throws IOException {
agentId = "test-hij";

MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId);
BytesStreamOutput output = new BytesStreamOutput();
mLAgentGetRequest.writeTo(output);

MLAgentGetRequest mLAgentGetRequest1 = new MLAgentGetRequest(output.bytes().streamInput());

assertEquals(mLAgentGetRequest1.getAgentId(), mLAgentGetRequest.getAgentId());
assertEquals(mLAgentGetRequest1.getAgentId(), agentId);
}

@Test
public void validate_Success() {
agentId = "not-null";
MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId);

assertEquals(null, mLAgentGetRequest.validate());
}

@Test
public void validate_Failure() {
agentId = null;
MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId);
assertEquals(null,mLAgentGetRequest.agentId);

ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null);
mLAgentGetRequest.validate().equals(exception) ;
}
@Test
public void fromActionRequest() throws IOException {
agentId = "test-lmn";
MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId);
assertEquals(mLAgentGetRequest.fromActionRequest(mLAgentGetRequest), mLAgentGetRequest);

}
}


Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package org.opensearch.ml.common.transport.agent;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.common.io.stream.*;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;

import java.io.*;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

public class MLAgentGetResponseTest {

MLAgent mlAgent;

@Test
public void Create_MLAgentResponse_With_StreamInput() throws IOException {
// Create a BytesStreamOutput to simulate the StreamOutput
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();

//create a test agent using input
bytesStreamOutput.writeString("Test Agent");
bytesStreamOutput.writeString("flow");
bytesStreamOutput.writeBoolean(false);
bytesStreamOutput.writeBoolean(false);
bytesStreamOutput.writeBoolean(false);
bytesStreamOutput.writeBoolean(false);
bytesStreamOutput.writeBoolean(false);
bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z"));
bytesStreamOutput.writeInstant(Instant.parse("2023-12-31T12:00:00Z"));
bytesStreamOutput.writeString("test");

StreamInput testInputStream = bytesStreamOutput.bytes().streamInput();

MLAgentGetResponse mlAgentGetResponse = new MLAgentGetResponse(testInputStream);
MLAgent testMlAgent = mlAgentGetResponse.mlAgent;
assertEquals("flow",testMlAgent.getType());
assertEquals("Test Agent",testMlAgent.getName());
assertEquals("test",testMlAgent.getAppType());
}

@Test
public void mLAgentGetResponse_Builder() throws IOException {

MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder()
.mlAgent(mlAgent)
.build();

assertEquals(mlAgentGetResponse.mlAgent, mlAgent);
}
@Test
public void writeTo() throws IOException {
//create ml agent using MLAgent and mlAgentGetResponse
mlAgent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test");
MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder()
.mlAgent(mlAgent)
.build();
//use write out for both agents
BytesStreamOutput output = new BytesStreamOutput();
mlAgent.writeTo(output);
mlAgentGetResponse.writeTo(output);
MLAgent agent1 = mlAgentGetResponse.mlAgent;

assertEquals(mlAgent.getAppType(), agent1.getAppType());
assertEquals(mlAgent.getDescription(), agent1.getDescription());
assertEquals(mlAgent.getCreatedTime(), agent1.getCreatedTime());
assertEquals(mlAgent.getName(), agent1.getName());
assertEquals(mlAgent.getParameters(), agent1.getParameters());
assertEquals(mlAgent.getType(), agent1.getType());
}
@Test
public void toXContent() throws IOException {
mlAgent = new MLAgent("mock", "flow", "test", null, null, null, null, null, null, "test");
MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder()
.mlAgent(mlAgent)
.build();
XContentBuilder builder = XContentFactory.jsonBuilder();
ToXContent.Params params = EMPTY_PARAMS;
XContentBuilder getResponseXContentBuilder = mlAgentGetResponse.toXContent(builder, params);
assertEquals(getResponseXContentBuilder, mlAgent.toXContent(builder, params));
}

@Test
public void FromActionResponse() throws IOException {
MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder()
.mlAgent(mlAgent)
.build();
assertEquals(mlAgentGetResponse.fromActionResponse(mlAgentGetResponse), mlAgentGetResponse);

}
}

0 comments on commit 418ab64

Please sign in to comment.