Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register agent API support for MLClient #1656

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -337,4 +339,21 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
* @param listener action listener
*/
void getTool(String toolName, ActionListener<ToolMetadata> listener);

/**
* Registers new agent and returns ActionFuture.
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
*/
default ActionFuture<MLRegisterAgentResponse> registerAgent(MLAgent mlAgent) {
PlainActionFuture<MLRegisterAgentResponse> actionFuture = PlainActionFuture.newFuture();
registerAgent(mlAgent, actionFuture);
return actionFuture;
}

/**
* Registers new agent and returns agent ID in response
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
*/
void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
Expand Down Expand Up @@ -253,6 +257,27 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
}

@Override
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build();
client
.execute(
MLRegisterAgentAction.INSTANCE,
mlRegisterAgentRequest,
ActionListener.wrap(listener::onResponse, listener::onFailure)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ActionListener.wrap(listener::onResponse, listener::onFailure)
getMLRegisterAgentResponseActionListener(listener)

);
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
ActionListener<MLRegisterAgentResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res);
return mlRegisterAgentResponse;
});
return actionListener;
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -40,6 +41,7 @@
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -82,6 +84,9 @@ public class MachineLearningClientTest {
@Mock
MLRegisterModelGroupResponse registerModelGroupResponse;

@Mock
MLRegisterAgentResponse registerAgentResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -178,6 +183,11 @@ public void listTools(ActionListener<List<ToolMetadata>> listener) {
public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
listener.onResponse(null);
}

@Override
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
listener.onResponse(registerAgentResponse);
}
};
}

Expand Down Expand Up @@ -365,4 +375,10 @@ public void createConnector() {

assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet());
}

@Test
public void testRegisterAgent() {
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -66,6 +67,9 @@
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
Expand Down Expand Up @@ -152,6 +156,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLRegisterModelGroupResponse> registerModelGroupResponseActionListener;

@Mock
ActionListener<MLRegisterAgentResponse> registerAgentResponseActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -676,6 +683,27 @@ public void createConnector() {

}

@Test
public void testRegisterAgent() {
String agentId = "agentId";

doAnswer(invocation -> {
ActionListener<MLRegisterAgentResponse> actionListener = invocation.getArgument(2);
MLRegisterAgentResponse output = new MLRegisterAgentResponse(agentId);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any());

ArgumentCaptor<MLRegisterAgentResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class);
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();

machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener);

verify(client).execute(eq(MLRegisterAgentAction.INSTANCE), isA(MLRegisterAgentRequest.class), any());
verify(registerAgentResponseActionListener).onResponse(argumentCaptor.capture());
assertEquals(agentId, (argumentCaptor.getValue()).getAgentId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand All @@ -701,4 +729,5 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti
SearchResponse.Clusters.EMPTY
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -41,4 +46,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterAgentResponse) {
return (MLRegisterAgentResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.function.Consumer;

import org.junit.Before;
import org.junit.Ignore;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.StepListener;
Expand All @@ -46,6 +47,7 @@
@Log4j2
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2)
@Ignore
public class ConversationalMemoryHandlerITTests extends OpenSearchIntegTestCase {

private Client client;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.function.Consumer;

import org.junit.Before;
import org.junit.Ignore;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.StepListener;
Expand All @@ -47,6 +48,7 @@
@Log4j2
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2)
@Ignore
public class ConversationMetaIndexITTests extends OpenSearchIntegTestCase {

private ClusterService clusterService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.CountDownLatch;

import org.junit.Before;
import org.junit.Ignore;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.StepListener;
import org.opensearch.client.Client;
Expand All @@ -39,6 +40,7 @@
@Log4j2
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2)
@Ignore
public class InteractionsIndexITTests extends OpenSearchIntegTestCase {

private Client client;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.utils.TestHelper;

@Ignore
public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {

private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
Expand Down