diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 76dc55e7e3..cf308f1d8d 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -30,7 +30,8 @@ public enum FunctionName { SPARSE_TOKENIZE, TEXT_SIMILARITY, QUESTION_ANSWERING, - AGENT; + AGENT, + CONNECTOR; public static FunctionName from(String value) { try { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index ae43c10867..e424914b4f 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { } public enum ActionType { - PREDICT + PREDICT, + EXECUTE } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java new file mode 100644 index 0000000000..02e1c59cb4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.opensearch.action.ActionType; +import org.opensearch.ml.common.transport.MLTaskResponse; + +public class MLExecuteConnectorAction extends ActionType { + public static final MLExecuteConnectorAction INSTANCE = new MLExecuteConnectorAction(); + public static final String NAME = "cluster:admin/opensearch/ml/connectors/execute"; + + private MLExecuteConnectorAction() { + super(NAME, MLTaskResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java new file mode 100644 index 0000000000..ab7ffa9c9f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.transport.MLTaskRequest; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(level = AccessLevel.PRIVATE) +@ToString +public class MLExecuteConnectorRequest extends MLTaskRequest { + + String connectorId; + MLInput mlInput; + + @Builder + public MLExecuteConnectorRequest(String connectorId, MLInput mlInput, boolean dispatchTask) { + super(dispatchTask); + this.mlInput = mlInput; + this.connectorId = connectorId; + } + + public MLExecuteConnectorRequest(String connectorId, MLInput mlInput) { + this(connectorId, mlInput, true); + } + + public MLExecuteConnectorRequest(StreamInput in) throws IOException { + super(in); + this.connectorId = in.readString(); + this.mlInput = new MLInput(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.connectorId); + this.mlInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.mlInput == null) { + exception = addValidationError("ML input can't be null", exception); + } else if (this.mlInput.getInputDataset() == null) { + exception = addValidationError("input data can't be null", exception); + } + + return exception; + } + + + public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLExecuteConnectorRequest) { + return (MLExecuteConnectorRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLExecuteConnectorRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLPredictionTaskRequest", e); + } + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java new file mode 100644 index 0000000000..bdbe0bebfc --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public class MLExecuteConnectorRequestTests { + private MLExecuteConnectorRequest mlExecuteConnectorRequest; + private MLInput mlInput; + private String connectorId; + private String action; + + @Before + public void setUp(){ + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build(); + connectorId = "test_connector"; + action = "execute"; + mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build(); + mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).mlInput(mlInput).build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + mlExecuteConnectorRequest.writeTo(output); + MLExecuteConnectorRequest parsedRequest = new MLExecuteConnectorRequest(output.bytes().streamInput()); + assertEquals(mlExecuteConnectorRequest.getConnectorId(), parsedRequest.getConnectorId()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getAlgorithm(), parsedRequest.getMlInput().getAlgorithm()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), parsedRequest.getMlInput().getInputDataset().getInputDataType()); + assertEquals("hello", ((RemoteInferenceInputDataSet)parsedRequest.getMlInput().getInputDataset()).getParameters().get("input")); + } + + @Test + public void validateSuccess() { + assertNull(mlExecuteConnectorRequest.validate()); + } + + @Test + public void testConstructor() { + MLExecuteConnectorRequest executeConnectorRequest = new MLExecuteConnectorRequest(connectorId, mlInput); + assertTrue(executeConnectorRequest.isDispatchTask()); + } + + @Test + public void validateWithNullMLInputException() { + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder() + .build(); + ActionRequestValidationException exception = executeConnectorRequest.validate(); + assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLInputDataSetException() { + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput()) + .build(); + ActionRequestValidationException exception = executeConnectorRequest.validate(); + assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequestWithMLExecuteConnectorRequestSuccess() { + assertSame(MLExecuteConnectorRequest.fromActionRequest(mlExecuteConnectorRequest), mlExecuteConnectorRequest); + } + + @Test + public void fromActionRequestWithNonMLExecuteConnectorRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlExecuteConnectorRequest.writeTo(out); + } + }; + MLExecuteConnectorRequest result = MLExecuteConnectorRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlExecuteConnectorRequest); + assertEquals(mlExecuteConnectorRequest.getConnectorId(), result.getConnectorId()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getFunctionName(), result.getMlInput().getFunctionName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLExecuteConnectorRequest.fromActionRequest(actionRequest); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java new file mode 100644 index 0000000000..7c1b78f83d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports running connector. + */ +@Log4j2 +@ToolAnnotation(ConnectorTool.TYPE) +public class ConnectorTool implements Tool { + public static final String TYPE = "ConnectorTool"; + public static final String CONNECTOR_ID = "connector_id"; + public static final String CONNECTOR_ACTION = "connector_action"; + + @Setter + @Getter + private String name = ConnectorTool.TYPE; + @Getter + @Setter + private String description = Factory.DEFAULT_DESCRIPTION; + @Getter + private String version; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + private Client client; + private String connectorId; + + public ConnectorTool(Client client, String connectorId, String connectorAction) { + this.client = client; + if (connectorId == null) { + throw new IllegalArgumentException("connector_id can't be null"); + } + this.connectorId = connectorId; + + outputParser = new Parser() { + @Override + public Object parse(Object o) { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build(); + ActionRequest request = new MLExecuteConnectorRequest(connectorId, mlInput); + + client.execute(MLExecuteConnectorAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); + modelTensorOutput.getMlModelOutputs(); + if (outputParser == null) { + listener.onResponse((T) modelTensorOutput.getMlModelOutputs()); + } else { + listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + } + }, e -> { + log.error("Failed to run model " + connectorId, e); + listener.onFailure(e); + })); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + return true; + } + + public static class Factory implements Tool.Factory { + public static final String TYPE = "ConnectorTool"; + public static final String DEFAULT_DESCRIPTION = "This tool will invoke external service."; + private Client client; + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (ConnectorTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public ConnectorTool create(Map map) { + return new ConnectorTool(client, (String) map.get(CONNECTOR_ID), (String) map.get(CONNECTOR_ACTION)); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java new file mode 100644 index 0000000000..5bc507869e --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +public class ConnectorToolTests { + + @Mock + private Client client; + private Map otherParams; + + @Mock + private Parser mockOutputParser; + + @Mock + private ActionListener listener; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + ConnectorTool.Factory.getInstance().init(client); + + otherParams = Map.of("other", "[\"bar\"]"); + } + + @Test + public void testConnectorTool_NullConnectorId() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> ConnectorTool.Factory.getInstance().create(Map.of("connector_action", "execute")) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("connector_id can't be null")); + } + + @Test + public void testConnectorTool_DefaultOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test_connector", "connector_action", "execute")); + tool.run(null, ActionListener.wrap(r -> { assertEquals("response 1", r); }, e -> { throw new RuntimeException("Test failed"); })); + } + + @Test + public void testConnectorTool_NullOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test_connector", "connector_action", "execute")); + tool.setOutputParser(null); + + tool.run(null, ActionListener.wrap(r -> { + List response = (List) r; + assertEquals(1, response.size()); + assertEquals(1, ((ModelTensors) response.get(0)).getMlModelTensors().size()); + ModelTensor modelTensor1 = ((ModelTensors) response.get(0)).getMlModelTensors().get(0); + assertEquals(2, modelTensor1.getDataAsMap().size()); + assertEquals("response 1", modelTensor1.getDataAsMap().get("response")); + assertEquals("action1", modelTensor1.getDataAsMap().get("action")); + }, e -> { throw new RuntimeException("Test failed"); })); + } + + @Test + public void testConnectorTool_NotNullParameters() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertTrue(tool.validate(Map.of("key1", "value1"))); + } + + @Test + public void testConnectorTool_NullParameters() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertFalse(tool.validate(Map.of())); + } + + @Test + public void testConnectorTool_EmptyParameters() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertFalse(tool.validate(null)); + } + + @Test + public void testConnectorTool_GetType() { + ConnectorTool.Factory.getInstance().init(client); + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertEquals("ConnectorTool", tool.getType()); + } + + @Test + public void testRunWithError() { + // Mocking the client.execute to simulate an error + String errorMessage = "Test Exception"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(client).execute(any(), any(), any()); + + // Running the test + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + tool.setOutputParser(mockOutputParser); + tool.run(otherParams, listener); + + // Verifying that onFailure was called + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testTool() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertEquals(ConnectorTool.TYPE, tool.getName()); + assertEquals(ConnectorTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertTrue(tool.validate(otherParams)); + assertEquals(ConnectorTool.Factory.DEFAULT_DESCRIPTION, tool.getDescription()); + assertEquals(ConnectorTool.Factory.DEFAULT_DESCRIPTION, ConnectorTool.Factory.getInstance().getDefaultDescription()); + assertEquals(ConnectorTool.TYPE, ConnectorTool.Factory.getInstance().getDefaultType()); + assertNull(ConnectorTool.Factory.getInstance().getDefaultVersion()); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java new file mode 100644 index 0000000000..497e6768d8 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; + +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.script.ScriptService; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ExecuteConnectorTransportAction extends HandledTransportAction { + + Client client; + ClusterService clusterService; + ScriptService scriptService; + NamedXContentRegistry xContentRegistry; + + ConnectorAccessControlHelper connectorAccessControlHelper; + EncryptorImpl encryptor; + + @Inject + public ExecuteConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + ConnectorAccessControlHelper connectorAccessControlHelper, + EncryptorImpl encryptor + ) { + super(MLExecuteConnectorAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new); + this.client = client; + this.clusterService = clusterService; + this.scriptService = scriptService; + this.xContentRegistry = xContentRegistry; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.encryptor = encryptor; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request); + String connectorId = executeConnectorRequest.getConnectorId(); + String connectorAction = ConnectorAction.ActionType.EXECUTE.name(); + + if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener.wrap(connector -> { + if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + connector.decrypt(connectorAction, (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor + .executeAction(connectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> { + actionListener.onResponse(taskResponse); + }, e -> { actionListener.onFailure(e); })); + } + }, e -> { + log.error("Failed to get connector " + connectorId, e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.runBefore(listener, threadContext::restore)); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + connectorId)); + } + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java index b50f935774..9337653fc4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java @@ -42,8 +42,8 @@ public TransportExecuteTaskAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - MLExecuteTaskRequest mlPredictionTaskRequest = MLExecuteTaskRequest.fromActionRequest(request); - FunctionName functionName = mlPredictionTaskRequest.getFunctionName(); - mlExecuteTaskRunner.run(functionName, mlPredictionTaskRequest, transportService, listener); + MLExecuteTaskRequest mlExecuteTaskRequest = MLExecuteTaskRequest.fromActionRequest(request); + FunctionName functionName = mlExecuteTaskRequest.getFunctionName(); + mlExecuteTaskRunner.run(functionName, mlExecuteTaskRequest, transportService, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 89b812b613..e9a79236b1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -44,6 +44,7 @@ import org.opensearch.ml.action.agents.TransportSearchAgentAction; import org.opensearch.ml.action.config.GetConfigTransportAction; import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; +import org.opensearch.ml.action.connector.ExecuteConnectorTransportAction; import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; @@ -118,6 +119,7 @@ import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerGetAction; @@ -168,6 +170,7 @@ import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.engine.tools.AgentTool; import org.opensearch.ml.engine.tools.CatIndexTool; +import org.opensearch.ml.engine.tools.ConnectorTool; import org.opensearch.ml.engine.tools.IndexMappingTool; import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.engine.tools.SearchIndexTool; @@ -398,6 +401,7 @@ public MachineLearningPlugin(Settings settings) { new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class), new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class), new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class), + new ActionHandler<>(MLExecuteConnectorAction.INSTANCE, ExecuteConnectorTransportAction.class), new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class), new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class), new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class), @@ -579,6 +583,7 @@ public Collection createComponents( IndexMappingTool.Factory.getInstance().init(client); SearchIndexTool.Factory.getInstance().init(client, xContentRegistry); VisualizationsTool.Factory.getInstance().init(client); + ConnectorTool.Factory.getInstance().init(client); toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance()); toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance()); @@ -586,6 +591,7 @@ public Collection createComponents( toolFactories.put(IndexMappingTool.TYPE, IndexMappingTool.Factory.getInstance()); toolFactories.put(SearchIndexTool.TYPE, SearchIndexTool.Factory.getInstance()); toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance()); + toolFactories.put(ConnectorTool.TYPE, ConnectorTool.Factory.getInstance()); if (externalToolFactories != null) { toolFactories.putAll(externalToolFactories); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java new file mode 100644 index 0000000000..0753eac8a8 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java @@ -0,0 +1,158 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.action.connector; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.script.ScriptService; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class ExecuteConnectorTransportActionTests extends OpenSearchTestCase { + + private ExecuteConnectorTransportAction action; + + @Mock + private Client client; + + @Mock + ActionListener actionListener; + @Mock + private ClusterService clusterService; + @Mock + private TransportService transportService; + @Mock + private ActionFilters actionFilters; + @Mock + private ScriptService scriptService; + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + private Metadata metaData; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLExecuteConnectorRequest request; + @Mock + private EncryptorImpl encryptor; + @Mock + private HttpConnector connector; + @Mock + private Task task; + @Mock + ThreadPool threadPool; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + ClusterState testState = new ClusterState( + new ClusterName("clusterName"), + 123l, + "111111", + metaData, + null, + null, + null, + Map.of(), + 0, + false + ); + when(clusterService.state()).thenReturn(testState); + + when(request.getConnectorId()).thenReturn("test_connector_id"); + when(request.getConnectorAction()).thenReturn("execute"); + + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + action = new ExecuteConnectorTransportAction( + transportService, + actionFilters, + client, + clusterService, + scriptService, + xContentRegistry, + connectorAccessControlHelper, + encryptor + ); + } + + public void testExecute_NoConnectorIndex() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assertTrue(argCaptor.getValue().getMessage().contains("Can't find connector test_connector_id")); + } + + public void testExecute_FailedToGetConnector() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(metaData.hasIndex(anyString())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("test failure")); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assertTrue(argCaptor.getValue().getMessage().contains("test failure")); + } + + public void testExecute_NullMLInput() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(metaData.hasIndex(anyString())).thenReturn(true); + when(connector.getProtocol()).thenReturn(ConnectorProtocols.HTTP); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assertTrue(argCaptor.getValue().getMessage().contains("\"mlInput\" is null")); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index f5da656d06..cf1f87e09e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -961,6 +961,13 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt assertTrue(taskDone.get()); } + public String registerConnector(String createConnectorInput) throws IOException, InterruptedException { + Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + return connectorId; + } + public String registerRemoteModel(String createConnectorInput, String modelName, boolean deploy) throws IOException, InterruptedException { Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java new file mode 100644 index 0000000000..bd678ae259 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; + +import org.apache.hc.core5.http.ParseException; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; + +public class RestConnectorToolIT extends RestBaseAgentToolsIT { + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private String bedrockClaudeConnectorId; + + @Before + public void setUp() throws Exception { + super.setUp(); + String bedrockClaudeConnectorEntity = "{\n" + + " \"name\": \"BedRock Claude instant-v1 Connector \",\n" + + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"execute\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman:${parameters.question}\\\\n\\\\nAssistant:\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + this.bedrockClaudeConnectorId = registerConnector(bedrockClaudeConnectorEntity); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + deleteExternalIndices(); + } + + public void testConnectorToolInFlowAgent_WrongAction() throws IOException, ParseException { + String registerAgentRequestBody = "{\n" + + " \"name\": \"Test agent with connector tool\",\n" + + " \"type\": \"flow\",\n" + + " \"description\": \"This is a demo agent for connector tool\",\n" + + " \"app_type\": \"test1\",\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"ConnectorTool\",\n" + + " \"name\": \"bedrock_model\",\n" + + " \"parameters\": {\n" + + " \"connector_id\": \"" + + bedrockClaudeConnectorId + + "\",\n" + + " \"connector_action\": \"predict\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); + MatcherAssert.assertThat(exception.getMessage(), containsString("no execute action found")); + } + + public void testConnectorToolInFlowAgent() throws IOException, ParseException { + String registerAgentRequestBody = "{\n" + + " \"name\": \"Test agent with connector tool\",\n" + + " \"type\": \"flow\",\n" + + " \"description\": \"This is a demo agent for connector tool\",\n" + + " \"app_type\": \"test1\",\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"ConnectorTool\",\n" + + " \"name\": \"bedrock_model\",\n" + + " \"parameters\": {\n" + + " \"connector_id\": \"" + + bedrockClaudeConnectorId + + "\",\n" + + " \"connector_action\": \"execute\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + String result = executeAgent(agentId, agentInput); + assertNotNull(result); + } + +}