From 8331fe6874a768116b3f219587ae31fec7a9d563 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 6 Jun 2024 10:05:23 -0700 Subject: [PATCH 01/12] fix error message with unwrapping the root cause (#2458) Signed-off-by: Jing Zhang --- .../opensearch/ml/utils/MLExceptionUtils.java | 4 + .../ml/utils/error/ErrorMessageFactory.java | 16 +--- .../ml/rest/RestMLExecuteActionTests.java | 75 +++++++++++++++++++ .../utils/error/ErrorMessageFactoryTests.java | 44 ++++++----- 4 files changed, 102 insertions(+), 37 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index 6838a9ff79..68fee24fba 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -59,4 +59,8 @@ public static void logException(String errorMessage, Exception e, Logger log) { log.error(errorMessage, e); } } + + public static Throwable getRootCause(Throwable t) { + return ExceptionUtils.getRootCause(t); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java index 30aace4be3..69a3c94abe 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/error/ErrorMessageFactory.java @@ -6,6 +6,7 @@ package org.opensearch.ml.utils.error; import org.opensearch.OpenSearchException; +import org.opensearch.ml.utils.MLExceptionUtils; import lombok.experimental.UtilityClass; @@ -23,22 +24,9 @@ public static ErrorMessage createErrorMessage(Throwable e, int status) { int st = status; if (t instanceof OpenSearchException) { st = ((OpenSearchException) t).status().getStatus(); - } else { - t = unwrapCause(e); } + t = MLExceptionUtils.getRootCause(t); return new ErrorMessage(t, st); } - - protected static Throwable unwrapCause(Throwable t) { - Throwable result = t; - if (result instanceof OpenSearchException) { - return result; - } - if (result.getCause() == null) { - return result; - } - result = unwrapCause(result.getCause()); - return result; - } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java index 597ae57a8a..ac570a6a4d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteActionTests.java @@ -30,6 +30,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -44,6 +45,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.RemoteTransportException; public class RestMLExecuteActionTests extends OpenSearchTestCase { @@ -206,4 +208,77 @@ public void testPrepareRequest_disabled() { when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); assertThrows(IllegalStateException.class, () -> restMLExecuteAction.handleRequest(request, channel, client)); } + + public void testPrepareRequestClientException() throws 
Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new IllegalArgumentException("Illegal Argument Exception")); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.BAD_REQUEST, response.status()); + String content = response.content().utf8ToString(); + String expectedError = + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; + assertEquals(expectedError, response.content().utf8ToString()); + } + + public void testPrepareRequestClientWrappedException() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener + .onFailure( + new RemoteTransportException("Remote Transport Exception", new IllegalArgumentException("Illegal Argument Exception")) + ); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.BAD_REQUEST, response.status()); + String expectedError = + "{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}"; + assertEquals(expectedError, response.content().utf8ToString()); + } + + public void testPrepareRequestSystemException() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("System Exception")); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + doNothing().when(channel).sendResponse(any()); + RestRequest request = getLocalSampleCalculatorRestRequest(); + restMLExecuteAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); + verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any()); + 
Input input = argumentCaptor.getValue().getInput(); + assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName()); + ArgumentCaptor restResponseArgumentCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(channel, times(1)).sendResponse(restResponseArgumentCaptor.capture()); + BytesRestResponse response = (BytesRestResponse) restResponseArgumentCaptor.getValue(); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR, response.status()); + String expectedError = + "{\"error\":{\"reason\":\"System Error\",\"details\":\"System Exception\",\"type\":\"RuntimeException\"},\"status\":500}"; + assertEquals(expectedError, response.content().utf8ToString()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java index 00f3da1b01..5acdb847be 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/error/ErrorMessageFactoryTests.java @@ -5,43 +5,41 @@ package org.opensearch.ml.utils.error; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import org.junit.Test; import org.opensearch.OpenSearchException; import org.opensearch.core.rest.RestStatus; +import org.opensearch.transport.RemoteTransportException; public class ErrorMessageFactoryTests { - private Throwable nonOpenSearchThrowable = new Throwable(); - private Throwable openSearchThrowable = new OpenSearchException(nonOpenSearchThrowable); - - @Test - public void openSearchExceptionShouldCreateEsErrorMessage() { - Exception exception = new OpenSearchException(nonOpenSearchThrowable); - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertTrue(msg.exception instanceof OpenSearchException); - } - @Test - public void nonOpenSearchExceptionShouldCreateGenericErrorMessage() { - Exception exception = new Exception(nonOpenSearchThrowable); - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertFalse(msg.exception instanceof OpenSearchException); + public void openSearchExceptionWithoutNestedException() { + Throwable openSearchThrowable = new OpenSearchException("OpenSearch Exception"); + ErrorMessage errorMessage = ErrorMessageFactory.createErrorMessage(openSearchThrowable, RestStatus.BAD_REQUEST.getStatus()); + assertTrue(errorMessage.exception instanceof OpenSearchException); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus()); } @Test - public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsErrorMessage() { - Exception exception = (Exception) openSearchThrowable; - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertTrue(msg.exception instanceof OpenSearchException); + public void openSearchExceptionWithNestedException() { + Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception"); + Throwable openSearchThrowable = new RemoteTransportException("Remote Transport Exception", nestedThrowable); + ErrorMessage errorMessage = ErrorMessageFactory + .createErrorMessage(openSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus()); + assertTrue(errorMessage.exception instanceof IllegalArgumentException); + assertEquals(RestStatus.BAD_REQUEST.getStatus(), 
errorMessage.getStatus()); } @Test - public void nonOpenSearchExceptionWithMultiLayerWrappedEsExceptionCauseShouldCreateEsErrorMessage() { - Exception exception = new Exception(new Throwable(new Throwable(openSearchThrowable))); - ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus()); - assertTrue(msg.exception instanceof OpenSearchException); + public void nonOpenSearchExceptionWithNestedException() { + Throwable nestedThrowable = new IllegalArgumentException("Illegal Argument Exception"); + Throwable nonOpenSearchThrowable = new Exception("Remote Transport Exception", nestedThrowable); + ErrorMessage errorMessage = ErrorMessageFactory + .createErrorMessage(nonOpenSearchThrowable, RestStatus.INTERNAL_SERVER_ERROR.getStatus()); + assertTrue(errorMessage.exception instanceof IllegalArgumentException); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), errorMessage.getStatus()); } } From a0272f2138254cf47a4e82c01173f8fb2ba0a7be Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 6 Jun 2024 15:01:34 -0700 Subject: [PATCH 02/12] Add connector tool (#2512) * expose connector action parameter Signed-off-by: Yaliang Wu * add connector tool Signed-off-by: Yaliang Wu * fix ut Signed-off-by: Yaliang Wu * fix it Signed-off-by: Yaliang Wu * fix flaky test Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../opensearch/ml/common/FunctionName.java | 3 +- .../common/connector/AbstractConnector.java | 14 +- .../ml/common/connector/Connector.java | 10 +- .../ml/common/connector/ConnectorAction.java | 3 +- .../ml/common/connector/HttpConnector.java | 22 +- .../connector/MLExecuteConnectorAction.java | 18 ++ .../connector/MLExecuteConnectorRequest.java | 90 +++++++++ .../ml/common/connector/AwsConnectorTest.java | 11 +- .../common/connector/HttpConnectorTest.java | 25 +-- .../MLExecuteConnectorRequestTests.java | 120 +++++++++++ .../remote/AwsConnectorExecutor.java | 21 +- .../algorithms/remote/ConnectorUtils.java | 31 +-- .../remote/HttpJsonConnectorExecutor.java | 27 ++- .../remote/MLSdkAsyncHttpResponseHandler.java | 10 +- .../remote/RemoteConnectorExecutor.java | 40 ++-- .../engine/algorithms/remote/RemoteModel.java | 6 +- .../ml/engine/tools/ConnectorTool.java | 148 ++++++++++++++ .../remote/AwsConnectorExecutorTest.java | 190 ++++++++++++------ .../algorithms/remote/ConnectorUtilsTest.java | 27 +-- .../remote/HttpJsonConnectorExecutorTest.java | 46 +++-- .../MLSdkAsyncHttpResponseHandlerTest.java | 24 ++- .../ml/engine/tools/ConnectorToolTests.java | 177 ++++++++++++++++ .../ExecuteConnectorTransportAction.java | 100 +++++++++ .../execute/TransportExecuteTaskAction.java | 6 +- .../TransportRegisterModelAction.java | 5 +- .../ml/plugin/MachineLearningPlugin.java | 6 + .../ExecuteConnectorTransportActionTests.java | 156 ++++++++++++++ .../TransportRegisterModelActionTests.java | 4 +- .../ml/rest/MLCommonsRestTestCase.java | 7 + .../ml/rest/RestConnectorToolIT.java | 136 +++++++++++++ 30 files changed, 1279 insertions(+), 204 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java create mode 100644 
ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 76dc55e7e3..cf308f1d8d 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -30,7 +30,8 @@ public enum FunctionName { SPARSE_TOKENIZE, TEXT_SIMILARITY, QUESTION_ANSWERING, - AGENT; + AGENT, + CONNECTOR; public static FunctionName from(String value) { try { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index fadab3ef9a..90837425c4 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -70,7 +70,7 @@ public abstract class AbstractConnector implements Connector { @Setter protected ConnectorClientConfig connectorClientConfig; - protected Map createPredictDecryptedHeaders(Map headers) { + protected Map createDecryptedHeaders(Map headers) { if (headers == null) { return null; } @@ -116,9 +116,9 @@ public void parseResponse(T response, List modelTensors, boolea } @Override - public Optional findPredictAction() { + public Optional findAction(String action) { if (actions != null) { - return actions.stream().filter(a -> a.getActionType() == ConnectorAction.ActionType.PREDICT).findFirst(); + return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); } return Optional.empty(); } @@ -131,12 +131,12 @@ public void removeCredential() { } @Override - public String getPredictEndpoint(Map parameters) { - Optional predictAction = findPredictAction(); - if (!predictAction.isPresent()) { + public String getActionEndpoint(String action, Map parameters) { + Optional actionEndpoint = findAction(action); + if (!actionEndpoint.isPresent()) { return null; } - String predictEndpoint = predictAction.get().getUrl(); + String predictEndpoint = actionEndpoint.get().getUrl(); if (parameters != null && parameters.size() > 0) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); predictEndpoint = substitutor.replace(predictEndpoint); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index e74a453dc9..12f8ca0eba 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -56,18 +56,18 @@ public interface Connector extends ToXContentObject, Writeable { ConnectorClientConfig getConnectorClientConfig(); - String getPredictEndpoint(Map parameters); + String getActionEndpoint(String action, Map parameters); - String getPredictHttpMethod(); + String getActionHttpMethod(String action); - T createPredictPayload(Map parameters); + T createPayload(String action, Map parameters); - void decrypt(Function function); + void decrypt(String action, Function function); void 
encrypt(Function function); Connector cloneConnector(); - Optional findPredictAction(); + Optional findAction(String action); void removeCredential(); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index ae43c10867..e424914b4f 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { } public enum ActionType { - PREDICT + PREDICT, + EXECUTE } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 5bb00560a2..fc01ffad38 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -30,6 +30,7 @@ import java.util.regex.Pattern; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; @@ -307,10 +308,10 @@ public void update(MLCreateConnectorInput updateContent, Function T createPredictPayload(Map parameters) { - Optional predictAction = findPredictAction(); - if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) { - String payload = predictAction.get().getRequestBody(); + public T createPayload(String action, Map parameters) { + Optional connectorAction = findAction(action); + if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) { + String payload = connectorAction.get().getRequestBody(); payload = fillNullParameters(parameters, payload); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); @@ -348,15 +349,15 @@ private List findStringParametersWithNullDefaultValue(String input) { } @Override - public void decrypt(Function function) { + public void decrypt(String action, Function function) { Map decrypted = new HashMap<>(); for (String key : credential.keySet()) { decrypted.put(key, function.apply(credential.get(key))); } this.decryptedCredential = decrypted; - Optional predictAction = findPredictAction(); - Map headers = predictAction.isPresent() ? predictAction.get().getHeaders() : null; - this.decryptedHeaders = createPredictDecryptedHeaders(headers); + Optional connectorAction = findAction(action); + Map headers = connectorAction.isPresent() ? 
connectorAction.get().getHeaders() : null; + this.decryptedHeaders = createDecryptedHeaders(headers); } @Override @@ -378,8 +379,9 @@ public void encrypt(Function function) { } } - public String getPredictHttpMethod() { - return findPredictAction().get().getMethod(); + @Override + public String getActionHttpMethod(String action) { + return findAction(action).get().getMethod(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java new file mode 100644 index 0000000000..02e1c59cb4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.opensearch.action.ActionType; +import org.opensearch.ml.common.transport.MLTaskResponse; + +public class MLExecuteConnectorAction extends ActionType { + public static final MLExecuteConnectorAction INSTANCE = new MLExecuteConnectorAction(); + public static final String NAME = "cluster:admin/opensearch/ml/connectors/execute"; + + private MLExecuteConnectorAction() { + super(NAME, MLTaskResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java new file mode 100644 index 0000000000..ab7ffa9c9f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.transport.MLTaskRequest; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(level = AccessLevel.PRIVATE) +@ToString +public class MLExecuteConnectorRequest extends MLTaskRequest { + + String connectorId; + MLInput mlInput; + + @Builder + public MLExecuteConnectorRequest(String connectorId, MLInput mlInput, boolean dispatchTask) { + super(dispatchTask); + this.mlInput = mlInput; + this.connectorId = connectorId; + } + + public MLExecuteConnectorRequest(String connectorId, MLInput mlInput) { + this(connectorId, mlInput, true); + } + + public MLExecuteConnectorRequest(StreamInput in) throws IOException { + super(in); + this.connectorId = in.readString(); + this.mlInput = new MLInput(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.connectorId); + this.mlInput.writeTo(out); + } + + @Override + 
public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.mlInput == null) { + exception = addValidationError("ML input can't be null", exception); + } else if (this.mlInput.getInputDataset() == null) { + exception = addValidationError("input data can't be null", exception); + } + + return exception; + } + + + public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLExecuteConnectorRequest) { + return (MLExecuteConnectorRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLExecuteConnectorRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLPredictionTaskRequest", e); + } + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index a242c213ea..36a964cef1 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -31,6 +31,7 @@ import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; @@ -110,7 +111,7 @@ public void constructor_NoPredictAction() { Assert.assertNotNull(connector); connector.encrypt(encryptFunction); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals(null, connector.getSessionToken()); @@ -149,13 +150,13 @@ public void constructor() { AwsConnector connector = createAwsConnector(parameters, credential, url); connector.encrypt(encryptFunction); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); Assert.assertEquals("test_service", connector.getServiceName()); Assert.assertEquals("us-west-2", connector.getRegion()); - Assert.assertEquals("https://test.com/model1", connector.getPredictEndpoint(parameters)); + Assert.assertEquals("https://test.com/model1", connector.getActionEndpoint(PREDICT.name(), parameters)); } @Test @@ -170,13 +171,13 @@ public void constructor_NoParameter() { String url = "https://test.com"; AwsConnector connector = createAwsConnector(null, credential, url); connector.encrypt(encryptFunction); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); 
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SERVICE", connector.getServiceName()); Assert.assertEquals("decrypted: ENCRYPTED: US-WEST-2", connector.getRegion()); - Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null)); + Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null)); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 4f1df76da2..c25f9653c3 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -30,9 +30,10 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Optional; import java.util.function.Function; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; + public class HttpConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -119,7 +120,7 @@ public void cloneConnector() { @Test public void decrypt() { HttpConnector connector = createHttpConnector(); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); Map decryptedCredential = connector.getDecryptedCredential(); Assert.assertEquals(1, decryptedCredential.size()); Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key")); @@ -148,42 +149,42 @@ public void encrypted() { } @Test - public void getPredictEndpoint() { + public void getActionEndpoint() { HttpConnector connector = createHttpConnector(); - Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null)); + Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null)); } @Test - public void getPredictHttpMethod() { + public void getActionHttpMethod() { HttpConnector connector = createHttpConnector(); - Assert.assertEquals("POST", connector.getPredictHttpMethod()); + Assert.assertEquals("POST", connector.getActionHttpMethod(PREDICT.name())); } @Test - public void createPredictPayload_Invalid() { + public void createPayload_Invalid() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Some parameter placeholder not filled in payload: input"); HttpConnector connector = createHttpConnector(); - String predictPayload = connector.createPredictPayload(null); + String predictPayload = connector.createPayload(PREDICT.name(), null); connector.validatePayload(predictPayload); } @Test - public void createPredictPayload_InvalidJson() { + public void createPayload_InvalidJson() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Invalid payload: {\"input\": ${parameters.input} }"); String requestBody = "{\"input\": ${parameters.input} }"; HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - String predictPayload = connector.createPredictPayload(null); + String predictPayload = connector.createPayload(PREDICT.name(), null); connector.validatePayload(predictPayload); } @Test - public void createPredictPayload() { + public void createPayload() { HttpConnector connector = createHttpConnector(); Map parameters = 
new HashMap<>(); parameters.put("input", "test input value"); - String predictPayload = connector.createPredictPayload(parameters); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); connector.validatePayload(predictPayload); Assert.assertEquals("{\"input\": \"test input value\"}", predictPayload); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java new file mode 100644 index 0000000000..f95b236259 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public class MLExecuteConnectorRequestTests { + private MLExecuteConnectorRequest mlExecuteConnectorRequest; + private MLInput mlInput; + private String connectorId; + + @Before + public void setUp(){ + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build(); + connectorId = "test_connector"; + mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build(); + mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).mlInput(mlInput).build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + mlExecuteConnectorRequest.writeTo(output); + MLExecuteConnectorRequest parsedRequest = new MLExecuteConnectorRequest(output.bytes().streamInput()); + assertEquals(mlExecuteConnectorRequest.getConnectorId(), parsedRequest.getConnectorId()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getAlgorithm(), parsedRequest.getMlInput().getAlgorithm()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), parsedRequest.getMlInput().getInputDataset().getInputDataType()); + assertEquals("hello", ((RemoteInferenceInputDataSet)parsedRequest.getMlInput().getInputDataset()).getParameters().get("input")); + } + + @Test + public void validateSuccess() { + assertNull(mlExecuteConnectorRequest.validate()); + } + + @Test + public void testConstructor() { + MLExecuteConnectorRequest executeConnectorRequest = new MLExecuteConnectorRequest(connectorId, mlInput); + assertTrue(executeConnectorRequest.isDispatchTask()); + } + + @Test + public void validateWithNullMLInputException() { + MLExecuteConnectorRequest 
executeConnectorRequest = MLExecuteConnectorRequest.builder() + .build(); + ActionRequestValidationException exception = executeConnectorRequest.validate(); + assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLInputDataSetException() { + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput()) + .build(); + ActionRequestValidationException exception = executeConnectorRequest.validate(); + assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequestWithMLExecuteConnectorRequestSuccess() { + assertSame(MLExecuteConnectorRequest.fromActionRequest(mlExecuteConnectorRequest), mlExecuteConnectorRequest); + } + + @Test + public void fromActionRequestWithNonMLExecuteConnectorRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlExecuteConnectorRequest.writeTo(out); + } + }; + MLExecuteConnectorRequest result = MLExecuteConnectorRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlExecuteConnectorRequest); + assertEquals(mlExecuteConnectorRequest.getConnectorId(), result.getConnectorId()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getFunctionName(), result.getMlInput().getFunctionName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLExecuteConnectorRequest.fromActionRequest(actionRequest); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 035b6a6d8d..2ebc7ce563 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -77,7 +77,8 @@ public Logger getLogger() { @SuppressWarnings("removal") @Override - public void invokeRemoteModel( + public void invokeRemoteService( + String action, MLInput mlInput, Map parameters, String payload, @@ -85,22 +86,30 @@ public void invokeRemoteModel( ActionListener> actionListener ) { try { - SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); + SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(signRequest(request)) .requestContentPublisher(new SimpleHttpContentPublisher(request)) .responseHandler( - new MLSdkAsyncHttpResponseHandler(executionContext, actionListener, parameters, connector, scriptService, mlGuard) + new MLSdkAsyncHttpResponseHandler( + executionContext, + actionListener, + parameters, + connector, + scriptService, + mlGuard, + action + ) ) .build(); AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); } catch (RuntimeException exception) { - log.error("Failed to execute predict in aws 
connector: " + exception.getMessage(), exception); + log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception); actionListener.onFailure(exception); } catch (Throwable e) { - log.error("Failed to execute predict in aws connector", e); - actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e)); + log.error("Failed to execute {} in aws connector", action, e); + actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e)); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index a6181e1b2f..cad0278a6d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -63,6 +63,7 @@ public class ConnectorUtils { } public static RemoteInferenceInputDataSet processInput( + String action, MLInput mlInput, Connector connector, Map parameters, @@ -71,22 +72,23 @@ public static RemoteInferenceInputDataSet processInput( if (mlInput == null) { throw new IllegalArgumentException("Input is null"); } - Optional predictAction = connector.findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); + Optional connectorAction = connector.findAction(action); + if (connectorAction.isEmpty()) { + throw new IllegalArgumentException("no " + action + " action found"); } - RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService); + RemoteInferenceInputDataSet inputData = processMLInput(action, mlInput, connector, parameters, scriptService); escapeRemoteInferenceInputData(inputData); return inputData; } private static RemoteInferenceInputDataSet processMLInput( + String action, MLInput mlInput, Connector connector, Map parameters, ScriptService scriptService ) { - String preProcessFunction = getPreprocessFunction(mlInput, connector); + String preProcessFunction = getPreprocessFunction(action, mlInput, connector); if (preProcessFunction == null) { if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); @@ -168,9 +170,9 @@ public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet in } } - private static String getPreprocessFunction(MLInput mlInput, Connector connector) { - Optional predictAction = connector.findPredictAction(); - String preProcessFunction = predictAction.get().getPreProcessFunction(); + private static String getPreprocessFunction(String action, MLInput mlInput, Connector connector) { + Optional connectorAction = connector.findAction(action); + String preProcessFunction = connectorAction.get().getPreProcessFunction(); if (preProcessFunction != null) { return preProcessFunction; } @@ -181,6 +183,7 @@ private static String getPreprocessFunction(MLInput mlInput, Connector connector } public static ModelTensors processOutput( + String action, String modelResponse, Connector connector, ScriptService scriptService, @@ -194,12 +197,11 @@ public static ModelTensors processOutput( throw new IllegalArgumentException("guardrails triggered for LLM output"); } List modelTensors = new ArrayList<>(); - Optional predictAction = connector.findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no 
predict action found"); + Optional connectorAction = connector.findAction(action); + if (connectorAction.isEmpty()) { + throw new IllegalArgumentException("no " + action + " action found"); } - ConnectorAction connectorAction = predictAction.get(); - String postProcessFunction = connectorAction.getPostProcessFunction(); + String postProcessFunction = connectorAction.get().getPostProcessFunction(); postProcessFunction = fillProcessFunctionParameter(parameters, postProcessFunction); String responseFilter = parameters.get(RESPONSE_FILTER_FIELD); @@ -263,6 +265,7 @@ public static SdkHttpFullRequest signRequest( } public static SdkHttpFullRequest buildSdkRequest( + String action, Connector connector, Map parameters, String payload, @@ -279,7 +282,7 @@ public static SdkHttpFullRequest buildSdkRequest( log.error("Content length is 0. Aborting request to remote model"); throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); } - String endpoint = connector.getPredictEndpoint(parameters); + String endpoint = connector.getActionEndpoint(action, parameters); SdkHttpFullRequest.Builder builder = SdkHttpFullRequest .builder() .method(method) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index be15740bdb..ee29f67a43 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -81,7 +81,8 @@ public Logger getLogger() { @SuppressWarnings("removal") @Override - public void invokeRemoteModel( + public void invokeRemoteService( + String action, MLInput mlInput, Map parameters, String payload, @@ -90,15 +91,15 @@ public void invokeRemoteModel( ) { try { SdkHttpFullRequest request; - switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { + switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) { case "POST": log.debug("original payload to remote model: " + payload); - validateHttpClientParameters(parameters); - request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); + validateHttpClientParameters(action, parameters); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - validateHttpClientParameters(parameters); - request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET); + validateHttpClientParameters(action, parameters); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); break; default: throw new IllegalArgumentException("unsupported http method"); @@ -108,7 +109,15 @@ public void invokeRemoteModel( .request(request) .requestContentPublisher(new SimpleHttpContentPublisher(request)) .responseHandler( - new MLSdkAsyncHttpResponseHandler(executionContext, actionListener, parameters, connector, scriptService, mlGuard) + new MLSdkAsyncHttpResponseHandler( + executionContext, + actionListener, + parameters, + connector, + scriptService, + mlGuard, + action + ) ) .build(); AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); @@ -121,8 +130,8 @@ public void invokeRemoteModel( } } - private void validateHttpClientParameters(Map parameters) throws Exception { - String endpoint = connector.getPredictEndpoint(parameters); + private void 
validateHttpClientParameters(String action, Map parameters) throws Exception { + String endpoint = connector.getActionEndpoint(action, parameters); URL url = new URL(endpoint); String protocol = url.getProtocol(); String host = url.getHost(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index b289a76157..6ea03058f0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -55,6 +55,8 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle private final Connector connector; + private final String action; + private final ScriptService scriptService; private final MLGuard mlGuard; @@ -68,7 +70,8 @@ public MLSdkAsyncHttpResponseHandler( Map parameters, Connector connector, ScriptService scriptService, - MLGuard mlGuard + MLGuard mlGuard, + String action ) { this.executionContext = executionContext; this.actionListener = actionListener; @@ -76,6 +79,7 @@ public MLSdkAsyncHttpResponseHandler( this.connector = connector; this.scriptService = scriptService; this.mlGuard = mlGuard; + this.action = action; } @Override @@ -184,12 +188,12 @@ private void response() { } try { - ModelTensors tensors = processOutput(body, connector, scriptService, parameters, mlGuard); + ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard); tensors.setStatusCode(statusCode); actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors)); } catch (Exception e) { log.error("Failed to process response body: {}", body, e); - actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e)); + actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e)); } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index f52532bcd1..11e43cef85 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -50,7 +50,7 @@ public interface RemoteConnectorExecutor { public String RETRY_EXECUTOR = "opensearch_ml_predict_remote"; - default void executePredict(MLInput mlInput, ActionListener actionListener) { + default void executeAction(String action, MLInput mlInput, ActionListener actionListener) { ActionListener>> tensorActionListener = ActionListener.wrap(r -> { // Only all sub-requests success will call logics here ModelTensors[] modelTensors = new ModelTensors[r.size()]; @@ -61,7 +61,7 @@ default void executePredict(MLInput mlInput, ActionListener acti try { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - Tuple calculatedChunkSize = calculateChunkSize(textDocsInputDataSet); + Tuple calculatedChunkSize = calculateChunkSize(action, textDocsInputDataSet); GroupedActionListener> groupedActionListener = new GroupedActionListener<>( tensorActionListener, calculatedChunkSize.v1() @@ -72,7 +72,8 @@ default void 
executePredict(MLInput mlInput, ActionListener acti List textDocs = textDocsInputDataSet .getDocs() .subList(processedDocs, Math.min(processedDocs + calculatedChunkSize.v2(), textDocsInputDataSet.getDocs().size())); - preparePayloadAndInvokeRemoteModel( + preparePayloadAndInvoke( + action, MLInput .builder() .algorithm(FunctionName.TEXT_EMBEDDING) @@ -83,7 +84,7 @@ default void executePredict(MLInput mlInput, ActionListener acti ); } } else { - preparePayloadAndInvokeRemoteModel(mlInput, new ExecutionContext(0), new GroupedActionListener<>(tensorActionListener, 1)); + preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), new GroupedActionListener<>(tensorActionListener, 1)); } } catch (Exception e) { actionListener.onFailure(e); @@ -95,12 +96,12 @@ default void executePredict(MLInput mlInput, ActionListener acti * @param textDocsInputDataSet * @return Tuple of chunk size and step size. */ - private Tuple calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) { + private Tuple calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) { int textDocsLength = textDocsInputDataSet.getDocs().size(); Map parameters = getConnector().getParameters(); if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) { int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size")); - // We need to check the parameter on runtime as parameter can be passed into predict request + // We need to check the parameter on runtime as parameter can be passed into action request if (stepSize <= 0) { throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer."); } else { @@ -111,11 +112,11 @@ private Tuple calculateChunkSize(TextDocsInputDataSet textDocs return Tuple.tuple(textDocsLength / stepSize + 1, stepSize); } } else { - Optional predictAction = getConnector().findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); + Optional connectorAction = getConnector().findAction(action); + if (connectorAction.isEmpty()) { + throw new IllegalArgumentException("no " + action + " action found"); } - String preProcessFunction = predictAction.get().getPreProcessFunction(); + String preProcessFunction = connectorAction.get().getPreProcessFunction(); if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) { // user defined preprocess script, this case, the chunk size is always equals to text docs length. 
return Tuple.tuple(textDocsLength, 1); @@ -155,7 +156,8 @@ default void setUserRateLimiterMap(Map userRateLimiterMap) default void setMlGuard(MLGuard mlGuard) {} - default void preparePayloadAndInvokeRemoteModel( + default void preparePayloadAndInvoke( + String action, MLInput mlInput, ExecutionContext executionContext, ActionListener> actionListener @@ -173,13 +175,13 @@ default void preparePayloadAndInvokeRemoteModel( inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); } parameters.putAll(inputParameters); - RemoteInferenceInputDataSet inputData = processInput(mlInput, connector, parameters, getScriptService()); + RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService()); if (inputData.getParameters() != null) { parameters.putAll(inputData.getParameters()); } // override again to always prioritize the input parameter parameters.putAll(inputParameters); - String payload = connector.createPredictPayload(parameters); + String payload = connector.createPayload(action, parameters); connector.validatePayload(payload); String userStr = getClient() .threadPool() @@ -201,9 +203,9 @@ && getUserRateLimiterMap().get(user.getName()) != null throw new IllegalArgumentException("guardrails triggered for user input"); } if (getConnectorClientConfig().getMaxRetryTimes() != 0) { - invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + invokeRemoteServiceWithRetry(action, mlInput, parameters, payload, executionContext, actionListener); } else { - invokeRemoteModel(mlInput, parameters, payload, executionContext, actionListener); + invokeRemoteService(action, mlInput, parameters, payload, executionContext, actionListener); } } } @@ -230,7 +232,8 @@ default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClien } } - default void invokeRemoteModelWithRetry( + default void invokeRemoteServiceWithRetry( + String action, MLInput mlInput, Map parameters, String payload, @@ -252,7 +255,7 @@ default void invokeRemoteModelWithRetry( public void tryAction(ActionListener> listener) { // the listener here is RetryingListener // If the request success, or can not retry, will call delegate listener - invokeRemoteModel(mlInput, parameters, payload, executionContext, listener); + invokeRemoteService(action, mlInput, parameters, payload, executionContext, listener); } @Override @@ -272,7 +275,8 @@ public boolean shouldRetry(Exception e) { invokeRemoteModelAction.run(); }; - void invokeRemoteModel( + void invokeRemoteService( + String action, MLInput mlInput, Map parameters, String payload, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 5828395641..c8685c010e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -5,6 +5,8 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; + import java.util.Map; import org.opensearch.client.Client; @@ -66,7 +68,7 @@ public void asyncPredict(MLInput mlInput, ActionListener actionL return; } try { - connectorExecutor.executePredict(mlInput, actionListener); + connectorExecutor.executeAction(PREDICT.name(), mlInput, actionListener); } catch (RuntimeException e) { 
log.error("Failed to call remote model.", e); actionListener.onFailure(e); @@ -90,7 +92,7 @@ public boolean isModelReady() { public void initModel(MLModel model, Map params, Encryptor encryptor) { try { Connector connector = model.getConnector().cloneConnector(); - connector.decrypt((credential) -> encryptor.decrypt(credential)); + connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential)); this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java new file mode 100644 index 0000000000..cb8b231ebf --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports running connector. 
+ */ +@Log4j2 +@ToolAnnotation(ConnectorTool.TYPE) +public class ConnectorTool implements Tool { + public static final String TYPE = "ConnectorTool"; + public static final String CONNECTOR_ID = "connector_id"; + public static final String CONNECTOR_ACTION = "connector_action"; + + @Setter + @Getter + private String name = ConnectorTool.TYPE; + @Getter + @Setter + private String description = Factory.DEFAULT_DESCRIPTION; + @Getter + private String version; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + private Client client; + private String connectorId; + + public ConnectorTool(Client client, String connectorId) { + this.client = client; + if (connectorId == null) { + throw new IllegalArgumentException("connector_id can't be null"); + } + this.connectorId = connectorId; + + outputParser = new Parser() { + @Override + public Object parse(Object o) { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build(); + ActionRequest request = new MLExecuteConnectorRequest(connectorId, mlInput); + + client.execute(MLExecuteConnectorAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); + modelTensorOutput.getMlModelOutputs(); + if (outputParser == null) { + listener.onResponse((T) modelTensorOutput.getMlModelOutputs()); + } else { + listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + } + }, e -> { + log.error("Failed to run model " + connectorId, e); + listener.onFailure(e); + })); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + return true; + } + + public static class Factory implements Tool.Factory { + public static final String TYPE = "ConnectorTool"; + public static final String DEFAULT_DESCRIPTION = "This tool will invoke external service."; + private Client client; + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (ConnectorTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public ConnectorTool create(Map map) { + return new ConnectorTool(client, (String) map.get(CONNECTOR_ID)); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index bf13c9f68c..cb192e83f9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -7,6 +7,7 
@@ import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -15,6 +16,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; @@ -114,7 +116,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { exceptionRule.expectMessage("Missing credential"); ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -132,7 +134,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { public void executePredict_RemoteInferenceInput_EmptyIpAddress() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http:///mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -150,7 +152,7 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -159,7 +161,12 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + executor + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); assert exceptionCaptor.getValue() instanceof NullPointerException; @@ -170,7 +177,7 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { public void executePredict_TextDocsInferenceInput() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -188,7 +195,7 @@ public void executePredict_TextDocsInferenceInput() { .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = 
Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -198,14 +205,18 @@ public void executePredict_TextDocsInferenceInput() { MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test public void executePredict_TextDocsInferenceInput_withStepSize() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -225,7 +236,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -235,21 +246,29 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); MLInputDataset inputDataSet1 = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), + actionListener + ); Mockito.verify(actionListener, times(0)).onFailure(any()); - Mockito.verify(executor, times(3)).preparePayloadAndInvokeRemoteModel(any(), any(), any()); + Mockito.verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any()); } @Test public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResults() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -269,7 +288,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -277,17 +296,21 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu when(client.threadPool()).thenReturn(threadPool); 
when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); Integer idx = Integer.parseInt(doc.substring(doc.length() - 1)); actionListener.onResponse(new Tuple<>(3 - idx, new ModelTensors(modelTensors.subList(3 - idx, 4 - idx)))); return null; - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); Mockito.verify(actionListener, times(1)).onResponse(responseCaptor.capture()); @@ -305,7 +328,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_thenFail() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -325,7 +348,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -333,8 +356,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); if (doc.endsWith("1")) { actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST)); @@ -342,11 +365,15 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t actionListener.onResponse(new Tuple<>(0, new ModelTensors(modelTensors.subList(0, 1)))); } return null; - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + 
.executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); @@ -358,7 +385,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleFailures() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -378,7 +405,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -386,8 +413,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); if (!doc.endsWith("1")) { actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST)); @@ -395,11 +422,15 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF actionListener.onResponse(new Tuple<>(0, new ModelTensors(modelTensors.subList(0, 1)))); } return null; - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); @@ -414,7 +445,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() throws NoSuchFieldException, IllegalAccessException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -431,7 +462,7 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + 
connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor0 = new AwsConnectorExecutor(connector); Field httpClientField = AwsConnectorExecutor.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); @@ -444,7 +475,12 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + executor + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); assert exceptionCaptor.getValue() instanceof NullPointerException; @@ -454,7 +490,7 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArgumentException() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -472,7 +508,7 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -482,7 +518,11 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); assert exceptionCaptor.getValue() instanceof IllegalArgumentException; @@ -492,7 +532,7 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictionAction() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -509,7 +549,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio .parameters(parameters) .credential(credential) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ 
-519,18 +559,22 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture()); assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException; - assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage()); + assert "no PREDICT action found".equals(exceptionArgumentCaptor.getValue().getMessage()); } @Test public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPreProcessFunction() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -550,7 +594,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -561,14 +605,18 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test - public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { + public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -590,7 +638,7 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(connectorClientConfig) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -603,10 +651,14 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + 
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); - Mockito.verify(executor, times(0)).invokeRemoteModelWithRetry(any(), any(), any(), any(), any()); - Mockito.verify(executor, times(1)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(0)).invokeRemoteServiceWithRetry(any(), any(), any(), any(), any(), any()); + Mockito.verify(executor, times(1)).invokeRemoteService(any(), any(), any(), any(), any(), any()); // execute with retry enabled ConnectorClientConfig connectorClientConfig2 = new ConnectorClientConfig(10, 10, 10, 1, 1, 1, RetryBackoffPolicy.CONSTANT); @@ -620,12 +672,16 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(connectorClientConfig2) .build(); - connector2.decrypt((c) -> encryptor.decrypt(c)); + connector2.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); executor.initialize(connector2); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); - Mockito.verify(executor, times(1)).invokeRemoteModelWithRetry(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(1)).invokeRemoteServiceWithRetry(any(), any(), any(), any(), any(), any()); Mockito.verify(actionListener, times(0)).onFailure(any()); } @@ -659,7 +715,7 @@ public void testGetRetryBackoffPolicy() { } @Test - public void invokeRemoteModelWithRetry_whenRetryableException_thenRetryUntilSuccess() { + public void invokeRemoteServiceWithRetry_whenRetryableException_thenRetryUntilSuccess() { MLInput mlInput = mock(MLInput.class); Map parameters = Map.of(); String payload = ""; @@ -674,7 +730,7 @@ public void invokeRemoteModelWithRetry_whenRetryableException_thenRetryUntilSucc @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 10 invocation, then success if (countOfInvocation++ < 10) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -683,7 +739,7 @@ public Void answer(InvocationOnMock invocation) { } return null; } - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); when(executor.getConnectorClientConfig()).thenReturn(connectorClientConfig); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); @@ -699,14 +755,14 @@ public Void answer(InvocationOnMock invocation) { return null; }).when(executorService).execute(any()); - executor.invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + executor.invokeRemoteServiceWithRetry(PREDICT.name(), mlInput, parameters, payload, executionContext, actionListener); Mockito.verify(actionListener, times(0)).onFailure(any()); Mockito.verify(actionListener, times(1)).onResponse(any()); - Mockito.verify(executor, times(11)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(11)).invokeRemoteService(any(), any(), any(), any(), any(), any()); } @Test - public void 
invokeRemoteModelWithRetry_whenRetryExceedMaxRetryTimes_thenCallOnFailure() { + public void invokeRemoteServiceWithRetry_whenRetryExceedMaxRetryTimes_thenCallOnFailure() { MLInput mlInput = mock(MLInput.class); Map parameters = Map.of(); String payload = ""; @@ -721,7 +777,7 @@ public void invokeRemoteModelWithRetry_whenRetryExceedMaxRetryTimes_thenCallOnFa @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 10 invocation, then success if (countOfInvocation++ < 10) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -730,7 +786,7 @@ public Void answer(InvocationOnMock invocation) { } return null; } - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); when(executor.getConnectorClientConfig()).thenReturn(connectorClientConfig); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); @@ -746,14 +802,14 @@ public Void answer(InvocationOnMock invocation) { return null; }).when(executorService).execute(any()); - executor.invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + executor.invokeRemoteServiceWithRetry(PREDICT.name(), mlInput, parameters, payload, executionContext, actionListener); Mockito.verify(actionListener, times(1)).onFailure(any()); Mockito.verify(actionListener, times(0)).onResponse(any()); - Mockito.verify(executor, times(6)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(6)).invokeRemoteService(any(), any(), any(), any(), any(), any()); } @Test - public void invokeRemoteModelWithRetry_whenNonRetryableException_thenCallOnFailure() { + public void invokeRemoteServiceWithRetry_whenNonRetryableException_thenCallOnFailure() { MLInput mlInput = mock(MLInput.class); Map parameters = Map.of(); String payload = ""; @@ -768,7 +824,7 @@ public void invokeRemoteModelWithRetry_whenNonRetryableException_thenCallOnFailu @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 2 invocation with retryable exception, then fail with non-retryable exception if (countOfInvocation++ < 2) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -777,7 +833,7 @@ public Void answer(InvocationOnMock invocation) { } return null; } - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); when(executor.getConnectorClientConfig()).thenReturn(connectorClientConfig); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); @@ -795,10 +851,10 @@ public Void answer(InvocationOnMock invocation) { ArgumentCaptor exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class); - executor.invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + executor.invokeRemoteServiceWithRetry(PREDICT.name(), mlInput, parameters, payload, executionContext, actionListener); Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture()); Mockito.verify(actionListener, 
times(0)).onResponse(any()); - Mockito.verify(executor, times(3)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(3)).invokeRemoteService(any(), any(), any(), any(), any(), any()); assert exceptionArgumentCaptor.getValue() instanceof OpenSearchStatusException; assertEquals("test failure", exceptionArgumentCaptor.getValue().getMessage()); assertEquals("test failure retryable", exceptionArgumentCaptor.getValue().getSuppressed()[0].getMessage()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index cb7f8a4fe8..31b0f5e420 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -56,7 +57,7 @@ public void setUp() { public void processInput_NullInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Input is null"); - ConnectorUtils.processInput(null, null, new HashMap<>(), null); + ConnectorUtils.processInput(PREDICT.name(), null, null, new HashMap<>(), null); } @Test @@ -66,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -78,7 +79,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); } @Test @@ -120,7 +121,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -132,7 +133,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); } @@ -168,14 +169,14 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() public void processOutput_NullResponse() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model response is null"); - ConnectorUtils.processOutput(null, null, null, null, null); + ConnectorUtils.processOutput(PREDICT.name(), null, null, null, null, null); } @Test public void processOutput_NoPostprocessFunction_jsonResponse() throws IOException { ConnectorAction predictAction = 
ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -192,7 +193,8 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); + ModelTensors tensors = ConnectorUtils + .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size()); @@ -206,7 +208,7 @@ public void processOutput_PostprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -224,7 +226,8 @@ public void processOutput_PostprocessFunction() throws IOException { .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); + ModelTensors tensors = ConnectorUtils + .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); @@ -246,7 +249,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction( ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody(requestBody) @@ -263,7 +266,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction( .actions(Arrays.asList(predictAction)) .build(); RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils - .processInput(mlInput, connector, new HashMap<>(), scriptService); + .processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index f0efd2efc2..8f920ffeba 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ 
b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import java.lang.reflect.Field; import java.util.Arrays; @@ -47,10 +48,10 @@ public void setUp() { } @Test - public void invokeRemoteModel_WrongHttpMethod() { + public void invokeRemoteService_WrongHttpMethod() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("wrong_method") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -63,17 +64,17 @@ public void invokeRemoteModel_WrongHttpMethod() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(null, null, null, null, actionListener); + executor.invokeRemoteService(PREDICT.name(), null, null, null, null, actionListener); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); assertEquals("unsupported http method", captor.getValue().getMessage()); } @Test - public void invokeRemoteModel_invalidIpAddress() { + public void invokeRemoteService_invalidIpAddress() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://127.0.0.1/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -87,7 +88,14 @@ public void invokeRemoteModel_invalidIpAddress() { .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor - .invokeRemoteModel(createMLInput(), new HashMap<>(), "{\"input\": \"hello world\"}", new ExecutionContext(0), actionListener); + .invokeRemoteService( + PREDICT.name(), + createMLInput(), + new HashMap<>(), + "{\"input\": \"hello world\"}", + new ExecutionContext(0), + actionListener + ); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof IllegalArgumentException; @@ -95,10 +103,10 @@ public void invokeRemoteModel_invalidIpAddress() { } @Test - public void invokeRemoteModel_Empty_payload() { + public void invokeRemoteService_Empty_payload() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("") @@ -111,7 +119,7 @@ public void invokeRemoteModel_Empty_payload() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); + executor.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof IllegalArgumentException; @@ -119,10 +127,10 @@ public void invokeRemoteModel_Empty_payload() { } @Test - public void invokeRemoteModel_get_request() { + public 
void invokeRemoteService_get_request() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("GET") .url("http://openai.com/mock") .requestBody("") @@ -135,14 +143,14 @@ public void invokeRemoteModel_get_request() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); + executor.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); } @Test - public void invokeRemoteModel_post_request() { + public void invokeRemoteService_post_request() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("hello world") @@ -155,14 +163,15 @@ public void invokeRemoteModel_post_request() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); + executor + .invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); } @Test - public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException { + public void invokeRemoteService_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("hello world") @@ -178,7 +187,8 @@ public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFie Field httpClientField = HttpJsonConnectorExecutor.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); httpClientField.set(executor, null); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); + executor + .invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(argumentCaptor.capture()); assert argumentCaptor.getValue() instanceof NullPointerException; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index 11990e36d7..f6c9b76071 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.engine.algorithms.remote.MLSdkAsyncHttpResponseHandler.AMZ_ERROR_HEADER; import java.nio.ByteBuffer; @@ -59,6 +60,7 @@ 
public class MLSdkAsyncHttpResponseHandlerTest { private SdkHttpFullResponse sdkHttpResponse; @Mock private ScriptService scriptService; + private String action; private MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler; @@ -70,7 +72,7 @@ public void setup() { when(sdkHttpResponse.statusCode()).thenReturn(HttpStatusCode.OK); ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING) .url("http://test.com/mock") @@ -86,7 +88,7 @@ public void setup() { ConnectorAction noProcessFunctionPredictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -98,13 +100,15 @@ public void setup() { .protocol("http") .actions(Arrays.asList(noProcessFunctionPredictAction)) .build(); + action = PREDICT.name(); mlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler( executionContext, actionListener, parameters, connector, scriptService, - null + null, + action ); responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber(); headersMap = Map.of(AMZ_ERROR_HEADER, Arrays.asList("ThrottlingException:request throttled!")); @@ -171,7 +175,8 @@ public void test_OnStream_without_postProcessFunction() { parameters, noProcessFunctionConnector, scriptService, - null + null, + action ); noProcessFunctionMlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); noProcessFunctionMlSdkAsyncHttpResponseHandler.onStream(stream); @@ -261,7 +266,8 @@ public void test_onComplete_failed() { parameters, connector, scriptService, - null + null, + action ); SdkHttpFullResponse sdkHttpResponse = mock(SdkHttpFullResponse.class); @@ -357,7 +363,8 @@ public void test_onComplete_throttle_exception_onFailure() { parameters, connector, scriptService, - null + null, + action ); SdkHttpFullResponse sdkHttpResponse = mock(SdkHttpFullResponse.class); @@ -397,7 +404,8 @@ public void test_onComplete_processOutputFail_onFailure() { parameters, testConnector, scriptService, - null + null, + action ); mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); @@ -414,6 +422,6 @@ public void test_onComplete_processOutputFail_onFailure() { ArgumentCaptor captor = ArgumentCaptor.forClass(MLException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("Fail to execute predict in aws connector"); + assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector"); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java new file mode 100644 index 0000000000..05b3426f4d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static 
org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +public class ConnectorToolTests { + + @Mock + private Client client; + private Map otherParams; + + @Mock + private Parser mockOutputParser; + + @Mock + private ActionListener listener; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + ConnectorTool.Factory.getInstance().init(client); + + otherParams = Map.of("other", "[\"bar\"]"); + } + + @Test + public void testConnectorTool_NullConnectorId() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> ConnectorTool.Factory.getInstance().create(Map.of("test1", "value1")) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("connector_id can't be null")); + } + + @Test + public void testConnectorTool_DefaultOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test_connector")); + tool.run(null, ActionListener.wrap(r -> { assertEquals("response 1", r); }, e -> { throw new RuntimeException("Test failed"); })); + } + + @Test + public void testConnectorTool_NullOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+        doAnswer(invocation -> {
+            ActionListener actionListener = invocation.getArgument(2);
+            actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
+            return null;
+        }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any());
+
+        Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test_connector"));
+        tool.setOutputParser(null);
+
+        tool.run(null, ActionListener.wrap(r -> {
+            List response = (List) r;
+            assertEquals(1, response.size());
+            assertEquals(1, ((ModelTensors) response.get(0)).getMlModelTensors().size());
+            ModelTensor modelTensor1 = ((ModelTensors) response.get(0)).getMlModelTensors().get(0);
+            assertEquals(2, modelTensor1.getDataAsMap().size());
+            assertEquals("response 1", modelTensor1.getDataAsMap().get("response"));
+            assertEquals("action1", modelTensor1.getDataAsMap().get("action"));
+        }, e -> { throw new RuntimeException("Test failed"); }));
+    }
+
+    @Test
+    public void testConnectorTool_NotNullParameters() {
+        Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1"));
+        assertTrue(tool.validate(Map.of("key1", "value1")));
+    }
+
+    @Test
+    public void testConnectorTool_EmptyParameters() {
+        Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1"));
+        assertFalse(tool.validate(Map.of()));
+    }
+
+    @Test
+    public void testConnectorTool_NullParameters() {
+        Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1"));
+        assertFalse(tool.validate(null));
+    }
+
+    @Test
+    public void testConnectorTool_GetType() {
+        ConnectorTool.Factory.getInstance().init(client);
+        Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1"));
+        assertEquals("ConnectorTool", tool.getType());
+    }
+
+    @Test
+    public void testRunWithError() {
+        // Mocking the client.execute to simulate an error
+        String errorMessage = "Test Exception";
+        doAnswer(invocation -> {
+            ActionListener actionListener = invocation.getArgument(2);
+            actionListener.onFailure(new RuntimeException(errorMessage));
+            return null;
+        }).when(client).execute(any(), any(), any());
+
+        // Running the test
+        Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1"));
+        tool.setOutputParser(mockOutputParser);
+        tool.run(otherParams, listener);
+
+        // Verifying that onFailure was called
+        ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(listener).onFailure(argumentCaptor.capture());
+        assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
+    }
+
+    @Test
+    public void testTool() {
+        Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1"));
+        assertEquals(ConnectorTool.TYPE, tool.getName());
+        assertEquals(ConnectorTool.TYPE, tool.getType());
+        assertNull(tool.getVersion());
+        assertTrue(tool.validate(otherParams));
+        assertEquals(ConnectorTool.Factory.DEFAULT_DESCRIPTION, tool.getDescription());
+        assertEquals(ConnectorTool.Factory.DEFAULT_DESCRIPTION, ConnectorTool.Factory.getInstance().getDefaultDescription());
+        assertEquals(ConnectorTool.TYPE, ConnectorTool.Factory.getInstance().getDefaultType());
+        assertNull(ConnectorTool.Factory.getInstance().getDefaultVersion());
+    }
+
+}
diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java
new file mode 100644
index 0000000000..497e6768d8
--- /dev/null
+++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.action.connector;
+
+import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
+
+import org.opensearch.ResourceNotFoundException;
+import org.opensearch.action.ActionRequest;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.action.support.HandledTransportAction;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.common.util.concurrent.ThreadContext;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.ml.common.connector.Connector;
+import org.opensearch.ml.common.connector.ConnectorAction;
+import org.opensearch.ml.common.transport.MLTaskResponse;
+import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
+import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction;
+import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest;
+import org.opensearch.ml.engine.MLEngineClassLoader;
+import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
+import org.opensearch.ml.engine.encryptor.EncryptorImpl;
+import org.opensearch.ml.helper.ConnectorAccessControlHelper;
+import org.opensearch.script.ScriptService;
+import org.opensearch.tasks.Task;
+import org.opensearch.transport.TransportService;
+
+import lombok.extern.log4j.Log4j2;
+
+@Log4j2
+public class ExecuteConnectorTransportAction extends HandledTransportAction {
+
+    Client client;
+    ClusterService clusterService;
+    ScriptService scriptService;
+    NamedXContentRegistry xContentRegistry;
+
+    ConnectorAccessControlHelper connectorAccessControlHelper;
+    EncryptorImpl encryptor;
+
+    @Inject
+    public ExecuteConnectorTransportAction(
+        TransportService transportService,
+        ActionFilters actionFilters,
+        Client client,
+        ClusterService clusterService,
+        ScriptService scriptService,
+        NamedXContentRegistry xContentRegistry,
+        ConnectorAccessControlHelper connectorAccessControlHelper,
+        EncryptorImpl encryptor
+    ) {
+        super(MLExecuteConnectorAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new);
+        this.client = client;
+        this.clusterService = clusterService;
+        this.scriptService = scriptService;
+        this.xContentRegistry = xContentRegistry;
+        this.connectorAccessControlHelper = connectorAccessControlHelper;
+        this.encryptor = encryptor;
+    }
+
+    @Override
+    protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) {
+        MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request);
+        String connectorId = executeConnectorRequest.getConnectorId();
+        String connectorAction = ConnectorAction.ActionType.EXECUTE.name();
+
+        if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
+            ActionListener listener = ActionListener.wrap(connector -> {
+                if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
+                    connector.decrypt(connectorAction, (credential) -> encryptor.decrypt(credential));
+                    RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
+                        .initInstance(connector.getProtocol(), connector, Connector.class);
+                    connectorExecutor.setScriptService(scriptService);
+                    connectorExecutor.setClusterService(clusterService);
+                    connectorExecutor.setClient(client);
+                    connectorExecutor.setXContentRegistry(xContentRegistry);
+                    connectorExecutor
+                        .executeAction(connectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> {
+                            actionListener.onResponse(taskResponse);
+                        }, e -> { actionListener.onFailure(e); }));
+                }
+            }, e -> {
+                log.error("Failed to get connector " + connectorId, e);
+                actionListener.onFailure(e);
+            });
+            try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
+                connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.runBefore(listener, threadContext::restore));
+            }
+        } else {
+            actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + connectorId));
+        }
+    }
+
+}
diff --git a/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java
index b50f935774..9337653fc4 100644
--- a/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java
@@ -42,8 +42,8 @@ public TransportExecuteTaskAction(
 
     @Override
     protected void doExecute(Task task, ActionRequest request, ActionListener listener) {
-        MLExecuteTaskRequest mlPredictionTaskRequest = MLExecuteTaskRequest.fromActionRequest(request);
-        FunctionName functionName = mlPredictionTaskRequest.getFunctionName();
-        mlExecuteTaskRunner.run(functionName, mlPredictionTaskRequest, transportService, listener);
+        MLExecuteTaskRequest mlExecuteTaskRequest = MLExecuteTaskRequest.fromActionRequest(request);
+        FunctionName functionName = mlExecuteTaskRequest.getFunctionName();
+        mlExecuteTaskRunner.run(functionName, mlExecuteTaskRequest, transportService, listener);
     }
 }
diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java
index 6c119d46d2..bde53795a3 100644
--- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java
@@ -7,6 +7,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD;
 import static org.opensearch.ml.common.MLTaskState.FAILED;
+import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
@@ -300,7 +301,9 @@ private void validateInternalConnector(MLRegisterModelInput registerModelInput)
             log.error("You must provide connector content when creating a remote model without providing connector id!");
             throw new IllegalArgumentException("You must provide connector content when creating a remote model without connector id!");
         }
-        if (registerModelInput.getConnector().getPredictEndpoint(registerModelInput.getConnector().getParameters()) == null) {
+        if (registerModelInput
+            .getConnector()
+            .getActionEndpoint(PREDICT.name(), registerModelInput.getConnector().getParameters())
== null) { log.error("Connector endpoint is required when creating a remote model without connector id!"); throw new IllegalArgumentException("Connector endpoint is required when creating a remote model without connector id!"); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 89b812b613..e9a79236b1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -44,6 +44,7 @@ import org.opensearch.ml.action.agents.TransportSearchAgentAction; import org.opensearch.ml.action.config.GetConfigTransportAction; import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; +import org.opensearch.ml.action.connector.ExecuteConnectorTransportAction; import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; @@ -118,6 +119,7 @@ import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerGetAction; @@ -168,6 +170,7 @@ import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.engine.tools.AgentTool; import org.opensearch.ml.engine.tools.CatIndexTool; +import org.opensearch.ml.engine.tools.ConnectorTool; import org.opensearch.ml.engine.tools.IndexMappingTool; import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.engine.tools.SearchIndexTool; @@ -398,6 +401,7 @@ public MachineLearningPlugin(Settings settings) { new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class), new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class), new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class), + new ActionHandler<>(MLExecuteConnectorAction.INSTANCE, ExecuteConnectorTransportAction.class), new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class), new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class), new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class), @@ -579,6 +583,7 @@ public Collection createComponents( IndexMappingTool.Factory.getInstance().init(client); SearchIndexTool.Factory.getInstance().init(client, xContentRegistry); VisualizationsTool.Factory.getInstance().init(client); + ConnectorTool.Factory.getInstance().init(client); toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance()); toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance()); @@ -586,6 +591,7 @@ public Collection createComponents( toolFactories.put(IndexMappingTool.TYPE, IndexMappingTool.Factory.getInstance()); toolFactories.put(SearchIndexTool.TYPE, SearchIndexTool.Factory.getInstance()); toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance()); + toolFactories.put(ConnectorTool.TYPE, 
ConnectorTool.Factory.getInstance()); if (externalToolFactories != null) { toolFactories.putAll(externalToolFactories); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java new file mode 100644 index 0000000000..719f168ead --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.action.connector; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.script.ScriptService; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class ExecuteConnectorTransportActionTests extends OpenSearchTestCase { + + private ExecuteConnectorTransportAction action; + + @Mock + private Client client; + + @Mock + ActionListener actionListener; + @Mock + private ClusterService clusterService; + @Mock + private TransportService transportService; + @Mock + private ActionFilters actionFilters; + @Mock + private ScriptService scriptService; + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + private Metadata metaData; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLExecuteConnectorRequest request; + @Mock + private EncryptorImpl encryptor; + @Mock + private HttpConnector connector; + @Mock + private Task task; + @Mock + ThreadPool threadPool; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + ClusterState testState = new ClusterState( + new ClusterName("clusterName"), + 123l, + "111111", + metaData, + null, + null, + null, + Map.of(), + 0, + false + ); + when(clusterService.state()).thenReturn(testState); + + when(request.getConnectorId()).thenReturn("test_connector_id"); + + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + 
when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + action = new ExecuteConnectorTransportAction( + transportService, + actionFilters, + client, + clusterService, + scriptService, + xContentRegistry, + connectorAccessControlHelper, + encryptor + ); + } + + public void testExecute_NoConnectorIndex() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assertTrue(argCaptor.getValue().getMessage().contains("Can't find connector test_connector_id")); + } + + public void testExecute_FailedToGetConnector() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(metaData.hasIndex(anyString())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("test failure")); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assertTrue(argCaptor.getValue().getMessage().contains("test failure")); + } + + public void testExecute_NullMLInput() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(metaData.hasIndex(anyString())).thenReturn(true); + when(connector.getProtocol()).thenReturn(ConnectorProtocols.HTTP); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index d30ef15a5a..d936f199a2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -522,7 +522,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() { when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); Connector connector = mock(Connector.class); when(input.getConnector()).thenReturn(connector); - when(connector.getPredictEndpoint(any(Map.class))).thenReturn("https://api.openai.com"); + when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn("https://api.openai.com"); MLCreateConnectorResponse mlCreateConnectorResponse = mock(MLCreateConnectorResponse.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -556,7 +556,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi when(request.getRegisterModelInput()).thenReturn(input); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); Connector connector = mock(Connector.class); - 
when(connector.getPredictEndpoint(any(Map.class))).thenReturn(null); + when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn(null); when(input.getConnector()).thenReturn(connector); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index f5da656d06..cf1f87e09e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -961,6 +961,13 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt assertTrue(taskDone.get()); } + public String registerConnector(String createConnectorInput) throws IOException, InterruptedException { + Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + return connectorId; + } + public String registerRemoteModel(String createConnectorInput, String modelName, boolean deploy) throws IOException, InterruptedException { Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java new file mode 100644 index 0000000000..4ae9653d60 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; + +import org.apache.hc.core5.http.ParseException; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; + +public class RestConnectorToolIT extends RestBaseAgentToolsIT { + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private String bedrockClaudeConnectorId; + private String bedrockClaudeConnectorIdForPredict; + + @Before + public void setUp() throws Exception { + super.setUp(); + Thread.sleep(20000); + this.bedrockClaudeConnectorId = createBedrockClaudeConnector("execute"); + this.bedrockClaudeConnectorIdForPredict = createBedrockClaudeConnector("predict"); + } + + private String createBedrockClaudeConnector(String action) throws IOException, InterruptedException { + String bedrockClaudeConnectorEntity = "{\n" + + " \"name\": \"BedRock Claude instant-v1 Connector \",\n" + + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\"\n" + + " },\n" + + " \"credential\": {\n" + + " 
\"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"" + + action + + "\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman:${parameters.question}\\\\n\\\\nAssistant:\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + return registerConnector(bedrockClaudeConnectorEntity); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + deleteExternalIndices(); + } + + public void testConnectorToolInFlowAgent_WrongAction() throws IOException, ParseException { + String registerAgentRequestBody = "{\n" + + " \"name\": \"Test agent with connector tool\",\n" + + " \"type\": \"flow\",\n" + + " \"description\": \"This is a demo agent for connector tool\",\n" + + " \"app_type\": \"test1\",\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"ConnectorTool\",\n" + + " \"name\": \"bedrock_model\",\n" + + " \"parameters\": {\n" + + " \"connector_id\": \"" + + bedrockClaudeConnectorIdForPredict + + "\",\n" + + " \"connector_action\": \"predict\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); + MatcherAssert.assertThat(exception.getMessage(), containsString("no EXECUTE action found")); + } + + public void testConnectorToolInFlowAgent() throws IOException, ParseException { + String registerAgentRequestBody = "{\n" + + " \"name\": \"Test agent with connector tool\",\n" + + " \"type\": \"flow\",\n" + + " \"description\": \"This is a demo agent for connector tool\",\n" + + " \"app_type\": \"test1\",\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"ConnectorTool\",\n" + + " \"name\": \"bedrock_model\",\n" + + " \"parameters\": {\n" + + " \"connector_id\": \"" + + bedrockClaudeConnectorId + + "\",\n" + + " \"connector_action\": \"execute\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + String result = executeAgent(agentId, agentInput); + assertNotNull(result); + } + +} From 2b98d203ad7f1e65dbc83f9f935a3b3b75579f06 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Mon, 10 Jun 2024 15:12:14 -0500 Subject: [PATCH 03/12] adding immediate refresh to delete model group request (#2514) * adding immediate refresh to delete model group request Signed-off-by: Bhavana Ramaram * fix format violations Signed-off-by: Bhavana Ramaram * add IT tests Signed-off-by: Bhavana Ramaram * remove thread sleep Signed-off-by: Bhavana Ramaram --------- Signed-off-by: Bhavana Ramaram --- .../DeleteModelGroupTransportAction.java | 4 +- .../rest/RestMLDeleteModelGroupActionIT.java | 51 +++++++++++++------ 2 files changed, 39 insertions(+), 
16 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index 63e43f8cb2..d75a5668dc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -15,6 +15,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -68,7 +69,8 @@ public DeleteModelGroupTransportAction( protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLModelGroupDeleteRequest mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.fromActionRequest(request); String modelGroupId = mlModelGroupDeleteRequest.getModelGroupId(); - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId); + DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionIT.java index 0f658358aa..f5a665cf0c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionIT.java @@ -6,32 +6,53 @@ package org.opensearch.ml.rest; import java.io.IOException; -import java.util.Map; -import org.apache.hc.core5.http.HttpEntity; -import org.junit.Ignore; +import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.utils.TestHelper; public class RestMLDeleteModelGroupActionIT extends MLCommonsRestTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); + private MLRegisterModelGroupInput mlRegisterModelGroupInput; + private String modelGroupId; + + @Before + public void setup() throws IOException { + mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder().name("testGroupID").description("This is test Group").build(); + registerModelGroup(client(), TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { + this.modelGroupId = (String) registerModelGroupResult.get("model_group_id"); + }); + } - @Ignore public void testDeleteModelGroupAPI_Success() throws IOException { - Response trainModelGroupResponse = ingestModelData(); - HttpEntity entity = trainModelGroupResponse.getEntity(); - assertNotNull(trainModelGroupResponse); - String entityString = TestHelper.httpEntityToString(entity); - Map map = gson.fromJson(entityString, Map.class); - String model_group_id = (String) map.get("model_group_id"); - - Response 
deleteModelResponse = TestHelper - .makeRequest(client(), "DELETE", "/_plugins/_ml/model_groups/" + model_group_id, null, "", null); - assertNotNull(deleteModelResponse); - assertEquals(RestStatus.OK, TestHelper.restStatus(deleteModelResponse)); + + Response deleteModelGroupResponse = TestHelper + .makeRequest(client(), "DELETE", "/_plugins/_ml/model_groups/" + modelGroupId, null, "", null); + assertNotNull(deleteModelGroupResponse); + assertEquals(RestStatus.OK, TestHelper.restStatus(deleteModelGroupResponse)); + } + + public void testDeleteAndRegisterModelGroup_Success() throws IOException { + + Response deleteModelGroupResponse = TestHelper + .makeRequest(client(), "DELETE", "/_plugins/_ml/model_groups/" + modelGroupId, null, "", null); + + if (TestHelper.restStatus(deleteModelGroupResponse).equals(RestStatus.OK)) { + MLRegisterModelGroupInput newMlRegisterModelGroupInput = MLRegisterModelGroupInput + .builder() + .name("testGroupID") + .description("This is a new test Group") + .build(); + + registerModelGroup(client(), TestHelper.toJsonString(newMlRegisterModelGroupInput), registerModelGroupResponse -> { + assertNotNull(registerModelGroupResponse); + assertEquals("CREATED", registerModelGroupResponse.get("status")); + }); + } } } From 954e8b3b6f5df733b61414139cc44ce12e871268 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Mon, 10 Jun 2024 14:09:56 -0700 Subject: [PATCH 04/12] [Backport main] Remove strict version dependency to compile minimum compatible version (#2486) * [Feature/multi_tenancy] Remove strict version dependency to compile minimum compatible version (#2485) * Remove strict version dependency to compile minimum compatible version Signed-off-by: Daniel Widdis * Only declare version constants once Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis Signed-off-by: Sicheng Song * Remove unnecessary dependency Signed-off-by: Sicheng Song * remove space Signed-off-by: Sicheng Song --------- Signed-off-by: Daniel Widdis Signed-off-by: Sicheng Song Co-authored-by: Daniel Widdis --- .../main/java/org/opensearch/ml/common/CommonValue.java | 6 ++++++ .../java/org/opensearch/ml/common/agent/MLAgent.java | 3 ++- .../ml/common/dataset/TextDocsInputDataSet.java | 3 ++- .../org/opensearch/ml/common/model/MLDeploySetting.java | 3 ++- .../transport/connector/MLCreateConnectorInput.java | 3 ++- .../common/transport/register/MLRegisterModelInput.java | 9 +++++---- .../java/org/opensearch/ml/common/agent/MLAgentTest.java | 7 ++++--- .../transport/connector/MLCreateConnectorInputTests.java | 3 ++- plugin/build.gradle | 2 +- 9 files changed, 26 insertions(+), 13 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 2f22dac12b..305cbaa6f7 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -5,6 +5,7 @@ package org.opensearch.ml.common; +import org.opensearch.Version; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.controller.MLController; @@ -527,4 +528,9 @@ public class CommonValue { + "\": {\"type\": \"long\"}\n" + " }\n" + "}"; + // Calculate Versions independently of OpenSearch core version + public static final Version VERSION_2_11_0 = Version.fromString("2.11.0"); + public static final Version VERSION_2_12_0 = Version.fromString("2.12.0"); + public static final Version 
VERSION_2_13_0 = Version.fromString("2.13.0"); + public static final Version VERSION_2_14_0 = Version.fromString("2.14.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index f068863b1d..a7a67d2e00 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -15,6 +15,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.MLModel; @@ -47,7 +48,7 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String APP_TYPE_FIELD = "app_type"; public static final String IS_HIDDEN_FIELD = "is_hidden"; - private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = Version.V_2_13_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = CommonValue.VERSION_2_13_0; private String name; private String type; diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java index 37b58b84ff..34cc561ace 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java @@ -12,6 +12,7 @@ import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -29,7 +30,7 @@ public class TextDocsInputDataSet extends MLInputDataset{ private List docs; - private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL = Version.V_2_11_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL = CommonValue.VERSION_2_11_0; @Builder(toBuilder = true) public TextDocsInputDataSet(List docs, ModelResultFilter resultFilter) { diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java index 4bd864b237..23c81d9ead 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.CommonValue; import java.io.IOException; @@ -27,7 +28,7 @@ public class MLDeploySetting implements ToXContentObject, Writeable { public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled"; public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes"; private static final long DEFAULT_TTL_MINUTES = -1; - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = Version.V_2_14_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = CommonValue.VERSION_2_14_0; private Boolean isAutoDeployEnabled; private Long modelTTLInMinutes; // in minutes diff --git 
a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 78b1ed4af2..007b65e286 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorClientConfig; @@ -46,7 +47,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { public static final String ACCESS_MODE_FIELD = "access_mode"; public static final String DRY_RUN_FIELD = "dry_run"; - private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = Version.V_2_13_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_13_0; public static final String DRY_RUN_CONNECTOR_NAME = "dryRunConnector"; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index bffa04328b..2db0d47a8c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -15,6 +15,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; @@ -68,10 +69,10 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; public static final String GUARDRAILS_FIELD = "guardrails"; - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0; - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = Version.V_2_12_0; - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS_AND_AUTO_DEPLOY = Version.V_2_13_0; - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE = Version.V_2_14_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = CommonValue.VERSION_2_11_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = CommonValue.VERSION_2_12_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS_AND_AUTO_DEPLOY = CommonValue.VERSION_2_13_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE = CommonValue.VERSION_2_14_0; private FunctionName functionName; private String modelName; diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index 34e03c8419..8b1a96e07b 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ 
b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -18,6 +18,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; @@ -196,7 +197,7 @@ public void constructor_NonConversationalNoLLM() { public void writeTo_ReadFrom_HiddenFlag_VersionCompatibility() throws IOException { MLAgent agent = new MLAgent("test", "FLOW", "test", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", true); BytesStreamOutput output = new BytesStreamOutput(); - Version oldVersion = Version.fromString("2.12.0"); + Version oldVersion = CommonValue.VERSION_2_12_0; output.setVersion(oldVersion); // Version before MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT agent.writeTo(output); @@ -206,10 +207,10 @@ public void writeTo_ReadFrom_HiddenFlag_VersionCompatibility() throws IOExceptio assertNull(agentOldVersion.getIsHidden()); // Hidden should be null for old versions output = new BytesStreamOutput(); - output.setVersion(Version.V_2_13_0); // Version at or after MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT + output.setVersion(CommonValue.VERSION_2_13_0); // Version at or after MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT agent.writeTo(output); StreamInput streamInput1 = output.bytes().streamInput(); - streamInput1.setVersion(Version.V_2_13_0); + streamInput1.setVersion(CommonValue.VERSION_2_13_0); MLAgent agentNewVersion = new MLAgent(output.bytes().streamInput()); assertEquals(Boolean.TRUE, agentNewVersion.getIsHidden()); // Hidden should be true for new versions } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index f42d3afc7a..8eb885ade2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -21,6 +21,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorClientConfig; import org.opensearch.ml.common.connector.MLPostProcessFunction; @@ -256,7 +257,7 @@ public void testWriteToVersionCompatibility() throws IOException { MLCreateConnectorInput input = mlCreateConnectorInput; // Assuming mlCreateConnectorInput is already initialized // Simulate an older version of OpenSearch that does not support connectorClientConfig - Version oldVersion = Version.fromString("2.12.0"); // Change this as per your old version + Version oldVersion = CommonValue.VERSION_2_12_0; // Change this as per your old version BytesStreamOutput output = new BytesStreamOutput(); output.setVersion(oldVersion); diff --git a/plugin/build.gradle b/plugin/build.gradle index 6d2986358e..4921d73acc 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -285,7 +285,7 @@ List jacocoExclusions = [ 'org.opensearch.ml.profile.MLPredictRequestStats', 'org.opensearch.ml.action.deploy.TransportDeployModelAction', 'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction', - 
'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction', + 'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction', 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction', 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction.1', 'org.opensearch.ml.action.tasks.GetTaskTransportAction', From a43ba9594cbb0466d215f3206704f59c408fb678 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 10 Jun 2024 15:57:31 -0700 Subject: [PATCH 05/12] guardrails model support (#2491) * guardrails model support Signed-off-by: Jing Zhang * add IT for remote guardrails model Signed-off-by: Jing Zhang * address comments Signed-off-by: Jing Zhang * address more comments 1 Signed-off-by: Jing Zhang --------- Signed-off-by: Jing Zhang --- .../org/opensearch/ml/common/CommonValue.java | 4 + .../opensearch/ml/common/model/Guardrail.java | 96 +----- .../ml/common/model/Guardrails.java | 61 +++- .../ml/common/model/LocalRegexGuardrail.java | 277 +++++++++++++++++ .../opensearch/ml/common/model/MLGuard.java | 163 +--------- .../ml/common/model/ModelGuardrail.java | 208 +++++++++++++ .../opensearch/ml/common/model/StopWords.java | 9 + .../ml/common/model/GuardrailTests.java | 69 ----- .../ml/common/model/GuardrailsTests.java | 24 +- .../model/LocalRegexGuardrailTests.java | 292 ++++++++++++++++++ .../ml/common/model/MLGuardTests.java | 210 +------------ .../ml/common/model/ModelGuardrailTests.java | 101 ++++++ .../algorithms/remote/ConnectorUtils.java | 8 +- .../remote/RemoteConnectorExecutor.java | 2 +- .../ml/rest/RestMLGuardrailsIT.java | 281 ++++++++++++++++- 15 files changed, 1261 insertions(+), 544 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java create mode 100644 common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java delete mode 100644 common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 305cbaa6f7..422467241b 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common; +import com.google.common.collect.ImmutableSet; import org.opensearch.Version; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.controller.MLController; +import java.util.Set; + import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD; @@ -71,6 +74,7 @@ public class CommonValue { public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; + public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; public 
static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java index d690fdce7f..598121b8ed 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java @@ -5,101 +5,19 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.client.Client; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.Map; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +public abstract class Guardrail implements ToXContentObject { -@EqualsAndHashCode -@Getter -public class Guardrail implements ToXContentObject { - public static final String STOP_WORDS_FIELD = "stop_words"; - public static final String REGEX_FIELD = "regex"; + public abstract void writeTo(StreamOutput out) throws IOException; - private List stopWords; - private String[] regex; + public abstract Boolean validate(String input, Map parameters); - @Builder(toBuilder = true) - public Guardrail(List stopWords, String[] regex) { - this.stopWords = stopWords; - this.regex = regex; - } - - public Guardrail(StreamInput input) throws IOException { - if (input.readBoolean()) { - stopWords = new ArrayList<>(); - int size = input.readInt(); - for (int i=0; i 0) { - out.writeBoolean(true); - out.writeInt(stopWords.size()); - for (StopWords e : stopWords) { - e.writeTo(out); - } - } else { - out.writeBoolean(false); - } - out.writeStringArray(regex); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (stopWords != null && stopWords.size() > 0) { - builder.field(STOP_WORDS_FIELD, stopWords); - } - if (regex != null) { - builder.field(REGEX_FIELD, regex); - } - builder.endObject(); - return builder; - } - - public static Guardrail parse(XContentParser parser) throws IOException { - List stopWords = null; - String[] regex = null; - - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - - switch (fieldName) { - case STOP_WORDS_FIELD: - stopWords = new ArrayList<>(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - stopWords.add(StopWords.parse(parser)); - } - break; - case REGEX_FIELD: - regex = parser.list().toArray(new String[0]); - break; - default: - parser.skipChildren(); - break; - } - } - return Guardrail.builder() - .stopWords(stopWords) - .regex(regex) - .build(); - } + public abstract void init(NamedXContentRegistry xContentRegistry, Client client); } diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java index 1153262935..db7558b7cc 100644 --- 
a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java @@ -15,6 +15,8 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.util.Map; +import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -24,6 +26,7 @@ public class Guardrails implements ToXContentObject { public static final String TYPE_FIELD = "type"; public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail"; public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail"; + public static final Set types = Set.of("local_regex", "model"); private String type; private Guardrail inputGuardrail; @@ -39,10 +42,26 @@ public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardra public Guardrails(StreamInput input) throws IOException { type = input.readString(); if (input.readBoolean()) { - inputGuardrail = new Guardrail(input); + switch (type) { + case "local_regex": + inputGuardrail = new LocalRegexGuardrail(input); + break; + case "model": + break; + default: + throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type)); + } } if (input.readBoolean()) { - outputGuardrail = new Guardrail(input); + switch (type) { + case "local_regex": + outputGuardrail = new LocalRegexGuardrail(input); + break; + case "model": + break; + default: + throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type)); + } } } @@ -80,8 +99,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static Guardrails parse(XContentParser parser) throws IOException { String type = null; - Guardrail inputGuardrail = null; - Guardrail outputGuardrail = null; + Map inputGuardrailMap = null; + Map outputGuardrailMap = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -93,20 +112,46 @@ public static Guardrails parse(XContentParser parser) throws IOException { type = parser.text(); break; case INPUT_GUARDRAIL_FIELD: - inputGuardrail = Guardrail.parse(parser); + inputGuardrailMap = parser.map(); break; case OUTPUT_GUARDRAIL_FIELD: - outputGuardrail = Guardrail.parse(parser); + outputGuardrailMap = parser.map(); break; default: parser.skipChildren(); break; } } + if (!validateType(type)) { + throw new IllegalArgumentException("The type of guardrails is required, can not be null."); + } + return Guardrails.builder() .type(type) - .inputGuardrail(inputGuardrail) - .outputGuardrail(outputGuardrail) + .inputGuardrail(createGuardrail(type, inputGuardrailMap)) + .outputGuardrail(createGuardrail(type, outputGuardrailMap)) .build(); } + + private static Boolean validateType(String type) { + if (types.contains(type)) { + return true; + } + return false; + } + + private static Guardrail createGuardrail(String type, Map params) { + if (params == null || params.isEmpty()) { + return null; + } + + switch (type) { + case "local_regex": + return new LocalRegexGuardrail(params); + case "model": + return new ModelGuardrail(params); + default: + throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type)); + } + } } diff --git a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java new file mode 100644 index 0000000000..0f142bde3b --- /dev/null 
+++ b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java @@ -0,0 +1,277 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.stopWordsIndices; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@Log4j2 +@EqualsAndHashCode +@Getter +public class LocalRegexGuardrail extends Guardrail { + public static final String STOP_WORDS_FIELD = "stop_words"; + public static final String REGEX_FIELD = "regex"; + + private List stopWords; + private String[] regex; + private List regexPattern; + private Map> stopWordsIndicesInput; + private NamedXContentRegistry xContentRegistry; + private Client client; + + @Builder(toBuilder = true) + public LocalRegexGuardrail(List stopWords, String[] regex) { + this.stopWords = stopWords; + this.regex = regex; + } + public LocalRegexGuardrail(@NonNull Map params) { + List words = (List) params.get(STOP_WORDS_FIELD); + stopWords = new ArrayList<>(); + if (words != null && !words.isEmpty()) { + for (Map e : words) { + stopWords.add(new StopWords(e)); + } + } + List regexes = (List) params.get(REGEX_FIELD); + if (regexes != null && !regexes.isEmpty()) { + this.regex = regexes.toArray(new String[0]); + } + } + + public LocalRegexGuardrail(StreamInput input) throws IOException { + if (input.readBoolean()) { + stopWords = new ArrayList<>(); + int size = input.readInt(); + for (int i=0; i 0) { + out.writeBoolean(true); + out.writeInt(stopWords.size()); + for (StopWords e : stopWords) { + e.writeTo(out); + } + } else { + out.writeBoolean(false); + } + out.writeOptionalStringArray(regex); + } + + @Override + public Boolean validate(String input, Map parameters) { + return validateRegexList(input, regexPattern) && validateStopWords(input, stopWordsIndicesInput); + } + + @Override + public void init(NamedXContentRegistry xContentRegistry, Client client) { + this.xContentRegistry = xContentRegistry; + this.client = client; + init(); + } 
+ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (stopWords != null && stopWords.size() > 0) { + builder.field(STOP_WORDS_FIELD, stopWords); + } + if (regex != null) { + builder.field(REGEX_FIELD, regex); + } + builder.endObject(); + return builder; + } + + public static LocalRegexGuardrail parse(XContentParser parser) throws IOException { + List stopWords = null; + String[] regex = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case STOP_WORDS_FIELD: + stopWords = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + stopWords.add(StopWords.parse(parser)); + } + break; + case REGEX_FIELD: + regex = parser.list().toArray(new String[0]); + break; + default: + parser.skipChildren(); + break; + } + } + return LocalRegexGuardrail.builder() + .stopWords(stopWords) + .regex(regex) + .build(); + } + + private void init() { + stopWordsIndicesInput = stopWordsToMap(); + List regexList = regex == null ? new ArrayList<>() : Arrays.asList(regex); + regexPattern = regexList.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); + } + + private Map> stopWordsToMap() { + Map> map = new HashMap<>(); + if (stopWords != null && !stopWords.isEmpty()) { + for (StopWords e : stopWords) { + if (e.getIndex() != null && e.getSourceFields() != null) { + map.put(e.getIndex(), Arrays.asList(e.getSourceFields())); + } + } + } + return map; + } + + public Boolean validateRegexList(String input, List regexPatterns) { + if (regexPatterns == null || regexPatterns.isEmpty()) { + return true; + } + for (Pattern pattern : regexPatterns) { + if (!validateRegex(input, pattern)) { + return false; + } + } + return true; + } + + public Boolean validateRegex(String input, Pattern pattern) { + Matcher matcher = pattern.matcher(input); + return !matcher.matches(); + } + + public Boolean validateStopWords(String input, Map> stopWordsIndices) { + if (stopWordsIndices == null || stopWordsIndices.isEmpty()) { + return true; + } + for (Map.Entry entry : stopWordsIndices.entrySet()) { + if (!validateStopWordsSingleIndex(input, (String) entry.getKey(), (List) entry.getValue())) { + return false; + } + } + return true; + } + + /** + * Validate the input string against stop words + * @param input the string to validate against stop words + * @param indexName the index containing stop words + * @param fieldNames a list of field names containing stop words + * @return true if no stop words matching, otherwise false. 
+ */ + public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) { + SearchRequest searchRequest; + AtomicBoolean hitStopWords = new AtomicBoolean(false); + String queryBody; + Map documentMap = new HashMap<>(); + for (String field : fieldNames) { + documentMap.put(field, input); + } + Map queryBodyMap = Map + .of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); + CountDownLatch latch = new CountDownLatch(1); + ThreadContext.StoredContext context = null; + + try { + queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); + searchSourceBuilder.parseXContent(queryParser); + searchSourceBuilder.size(1); //Only need 1 doc returned, if hit. + searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); + if (isStopWordsSystemIndex(indexName)) { + context = client.threadPool().getThreadContext().stashContext(); + ThreadContext.StoredContext finalContext = context; + client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + hitStopWords.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); + hitStopWords.set(true); + }), latch), () -> finalContext.restore())); + } else { + client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + hitStopWords.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); + hitStopWords.set(true); + }), latch)); + } + } catch (Exception e) { + log.error("[validateStopWords] Searching stop words index failed.", e); + latch.countDown(); + hitStopWords.set(true); + } finally { + if (context != null) { + context.close(); + } + } + + try { + latch.await(5, SECONDS); + } catch (InterruptedException e) { + log.error("[validateStopWords] Searching stop words index was timeout.", e); + throw new IllegalStateException(e); + } + return hitStopWords.get(); + } + + private boolean isStopWordsSystemIndex(String index) { + return stopWordsIndices.contains(index); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java index dcb0e65ad7..3aa8d060b1 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java @@ -5,188 +5,43 @@ package org.opensearch.ml.common.model; -import com.google.common.collect.ImmutableSet; import lombok.Getter; -import lombok.NonNull; import lombok.extern.log4j.Log4j2; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import 
org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.builder.SearchSourceBuilder; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; - -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.opensearch.ml.common.utils.StringUtils.gson; @Log4j2 @Getter public class MLGuard { - private Map> stopWordsIndicesInput = new HashMap<>(); - private Map> stopWordsIndicesOutput = new HashMap<>(); - private List inputRegex; - private List outputRegex; - private List inputRegexPattern; - private List outputRegexPattern; private NamedXContentRegistry xContentRegistry; private Client client; - private Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); + private Guardrails guardrails; public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) { this.xContentRegistry = xContentRegistry; this.client = client; - if (guardrails == null) { - return; - } - Guardrail inputGuardrail = guardrails.getInputGuardrail(); - Guardrail outputGuardrail = guardrails.getOutputGuardrail(); - if (inputGuardrail != null) { - fillStopWordsToMap(inputGuardrail, stopWordsIndicesInput); - inputRegex = inputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(inputGuardrail.getRegex()); - inputRegexPattern = inputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); + this.guardrails = guardrails; + if (this.guardrails != null && this.guardrails.getInputGuardrail() != null) { + this.guardrails.getInputGuardrail().init(xContentRegistry, client); } - if (outputGuardrail != null) { - fillStopWordsToMap(outputGuardrail, stopWordsIndicesOutput); - outputRegex = outputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(outputGuardrail.getRegex()); - outputRegexPattern = outputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); + if (this.guardrails != null && this.guardrails.getOutputGuardrail() != null) { + this.guardrails.getOutputGuardrail().init(xContentRegistry, client); } } - private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map> map) { - List stopWords = guardrail.getStopWords(); - if (stopWords == null || stopWords.isEmpty()) { - return; - } - for (StopWords e : stopWords) { - if (e.getIndex() != null && e.getSourceFields() != null) { - map.put(e.getIndex(), Arrays.asList(e.getSourceFields())); - } - } - } - - public Boolean validate(String input, Type type) { + public Boolean validate(String input, Type type, Map parameters) { switch (type) { case INPUT: // validate input - return validateRegexList(input, inputRegexPattern) && validateStopWords(input, stopWordsIndicesInput); + return guardrails.getInputGuardrail() == null ? true : guardrails.getInputGuardrail().validate(input, parameters); case OUTPUT: // validate output - return validateRegexList(input, outputRegexPattern) && validateStopWords(input, stopWordsIndicesOutput); + return guardrails.getOutputGuardrail() == null ? 
true : guardrails.getOutputGuardrail().validate(input, parameters); default: throw new IllegalArgumentException("Unsupported type to validate for guardrails."); } } - public Boolean validateRegexList(String input, List regexPatterns) { - if (regexPatterns == null || regexPatterns.isEmpty()) { - return true; - } - for (Pattern pattern : regexPatterns) { - if (!validateRegex(input, pattern)) { - return false; - } - } - return true; - } - - public Boolean validateRegex(String input, Pattern pattern) { - Matcher matcher = pattern.matcher(input); - return !matcher.matches(); - } - - public Boolean validateStopWords(String input, Map> stopWordsIndices) { - if (stopWordsIndices == null || stopWordsIndices.isEmpty()) { - return true; - } - for (Map.Entry entry : stopWordsIndices.entrySet()) { - if (!validateStopWordsSingleIndex(input, (String) entry.getKey(), (List) entry.getValue())) { - return false; - } - } - return true; - } - - public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) { - SearchRequest searchRequest; - AtomicBoolean hitStopWords = new AtomicBoolean(false); - String queryBody; - Map documentMap = new HashMap<>(); - for (String field : fieldNames) { - documentMap.put(field, input); - } - Map queryBodyMap = Map - .of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); - CountDownLatch latch = new CountDownLatch(1); - ThreadContext.StoredContext context = null; - - try { - queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); - searchSourceBuilder.parseXContent(queryParser); - searchSourceBuilder.size(1); //Only need 1 doc returned, if hit. 
- searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); - if (isStopWordsSystemIndex(indexName)) { - context = client.threadPool().getThreadContext().stashContext(); - ThreadContext.StoredContext finalContext = context; - client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { - if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { - hitStopWords.set(true); - } - }, e -> { - log.error("Failed to search stop words index {}", indexName, e); - hitStopWords.set(true); - }), latch), () -> finalContext.restore())); - } else { - client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> { - if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { - hitStopWords.set(true); - } - }, e -> { - log.error("Failed to search stop words index {}", indexName, e); - hitStopWords.set(true); - }), latch)); - } - } catch (Exception e) { - log.error("[validateStopWords] Searching stop words index failed.", e); - latch.countDown(); - hitStopWords.set(true); - } finally { - if (context != null) { - context.close(); - } - } - - try { - latch.await(5, SECONDS); - } catch (InterruptedException e) { - log.error("[validateStopWords] Searching stop words index was timeout.", e); - throw new IllegalStateException(e); - } - return hitStopWords.get(); - } - - private boolean isStopWordsSystemIndex(String index) { - return stopWordsIndices.contains(index); - } - public enum Type { INPUT, OUTPUT diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java new file mode 100644 index 0000000000..07d75a32ce --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -0,0 +1,208 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import 
java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@Log4j2 +@EqualsAndHashCode +@Getter +public class ModelGuardrail extends Guardrail { + public static final String MODEL_ID_FIELD = "model_id"; + public static final String RESPONSE_FILTER_FIELD = "response_filter"; + public static final String RESPONSE_ACCEPT_FIELD = "response_accept"; + + private String modelId; + private String responseFilter; + private String responseAccept; + private NamedXContentRegistry xContentRegistry; + private Client client; + private Pattern regexAcceptPattern; + + @Builder(toBuilder = true) + public ModelGuardrail(String modelId, String responseFilter, String responseAccept) { + this.modelId = modelId; + this.responseFilter = responseFilter; + this.responseAccept = responseAccept; + } + public ModelGuardrail(@NonNull Map params) { + this((String) params.get(MODEL_ID_FIELD), (String) params.get(RESPONSE_FILTER_FIELD), (String) params.get(RESPONSE_ACCEPT_FIELD)); + } + + public ModelGuardrail(StreamInput input) throws IOException { + modelId = input.readString(); + responseFilter = input.readString(); + responseAccept = input.readString(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(responseFilter); + out.writeString(responseAccept); + } + + private Boolean validateAcceptRegex(String input) { + Matcher matcher = regexAcceptPattern.matcher(input); + return matcher.matches(); + } + + @Override + public Boolean validate(String in, Map parameters) { + String input = parameters == null ? 
null : parameters.get("question"); + if (input == null || input.isEmpty()) { + log.info("Guardrail request is empty."); + return true; + } + log.info("Guardrail request: {}", input); + AtomicBoolean isAccepted = new AtomicBoolean(true); + ActionListener internalListener = ActionListener.wrap(predictionResponse -> { + ModelTensorOutput output = (ModelTensorOutput) predictionResponse.getOutput(); + ModelTensor tensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0); + String guardrailResponse = AccessController + .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(tensor.getDataAsMap().get("response"))); + log.info("Guardrail response: {}", guardrailResponse); + if (!validateAcceptRegex(guardrailResponse)) { + isAccepted.set(false); + } + }, e -> {log.error("[ModelGuardrail] Failed to get prediction response.", e);}); + ActionListener actionListener = wrapActionListener(internalListener, res -> { + MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res); + return predictionResponse; + }); + CountDownLatch latch = new CountDownLatch(1); + Map guardrailModelParams = new HashMap<>(); + guardrailModelParams.put("question", input); + if (responseFilter != null && !responseFilter.isEmpty()) { + guardrailModelParams.put("response_filter", responseFilter); + } + log.info("Guardrail resFilter: {}", responseFilter); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()) + .build() + ); + client + .execute( + MLPredictionTaskAction.INSTANCE, + request, + new LatchedActionListener(actionListener, latch) + ); + try { + latch.await(5, SECONDS); + } catch (InterruptedException e) { + log.error("[ModelGuardrail] Validation was timeout.", e); + } + + return isAccepted.get(); + } + + @Override + public void init(NamedXContentRegistry xContentRegistry, Client client) { + this.xContentRegistry = xContentRegistry; + this.client = client; + regexAcceptPattern = Pattern.compile(responseAccept); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } + if (responseFilter != null) { + builder.field(RESPONSE_FILTER_FIELD, responseFilter); + } + if (responseAccept != null) { + builder.field(RESPONSE_ACCEPT_FIELD, responseAccept); + } + builder.endObject(); + return builder; + } + + public static ModelGuardrail parse(XContentParser parser) throws IOException { + String modelId = null; + String responseFilter = null; + String responseAccept = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case RESPONSE_FILTER_FIELD: + responseFilter = parser.text(); + break; + case RESPONSE_ACCEPT_FIELD: + responseAccept = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return ModelGuardrail.builder() + .modelId(modelId) + .responseFilter(responseFilter) + .responseAccept(responseAccept) + .build(); + } + + private ActionListener wrapActionListener( + final ActionListener listener, + final Function recreate + ) { + ActionListener actionListener = ActionListener.wrap(r -> { + 
listener.onResponse(recreate.apply(r)); + ; + }, e -> { listener.onFailure(e); }); + return actionListener; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/StopWords.java b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java index 19307b398d..648f465891 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/StopWords.java +++ b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; @@ -15,6 +16,8 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.util.List; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -33,6 +36,12 @@ public StopWords(String index, String[] sourceFields) { this.sourceFields = sourceFields; } + public StopWords(@NonNull Map params) { + List fields = (List) params.get(SOURCE_FIELDS_FIELD); + this.index = (String) params.get(INDEX_NAME_FIELD); + this.sourceFields = fields == null ? null : fields.toArray(new String[0]); + } + public StopWords(StreamInput input) throws IOException { index = input.readString(); sourceFields = input.readStringArray(); diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java deleted file mode 100644 index b6b140d119..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.model; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.TestHelper; -import org.opensearch.search.SearchModule; - -import java.io.IOException; -import java.util.Collections; -import java.util.List; - -import static org.junit.Assert.*; - -public class GuardrailTests { - StopWords stopWords; - String[] regex; - - @Before - public void setUp() { - stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); - regex = List.of("regex1").toArray(new String[0]); - } - - @Test - public void writeTo() throws IOException { - Guardrail guardrail = new Guardrail(List.of(stopWords), regex); - BytesStreamOutput output = new BytesStreamOutput(); - guardrail.writeTo(output); - Guardrail guardrail1 = new Guardrail(output.bytes().streamInput()); - - Assert.assertArrayEquals(guardrail.getStopWords().toArray(), guardrail1.getStopWords().toArray()); - Assert.assertArrayEquals(guardrail.getRegex(), guardrail1.getRegex()); - } - - @Test - public void toXContent() throws IOException { - Guardrail guardrail = new Guardrail(List.of(stopWords), regex); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - guardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); - String content = 
TestHelper.xContentBuilderToString(builder); - - Assert.assertEquals("{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}", content); - } - - @Test - public void parse() throws IOException { - String jsonStr = "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); - parser.nextToken(); - Guardrail guardrail = Guardrail.parse(parser); - - Assert.assertArrayEquals(guardrail.getStopWords().toArray(), List.of(stopWords).toArray()); - Assert.assertArrayEquals(guardrail.getRegex(), regex); - } -} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java index dc0c3d116c..a1b589d07c 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java @@ -22,25 +22,23 @@ import java.util.Collections; import java.util.List; -import static org.junit.Assert.*; - public class GuardrailsTests { StopWords stopWords; String[] regex; - Guardrail inputGuardrail; - Guardrail outputGuardrail; + LocalRegexGuardrail inputLocalRegexGuardrail; + LocalRegexGuardrail outputLocalRegexGuardrail; @Before public void setUp() { stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); regex = List.of("regex1").toArray(new String[0]); - inputGuardrail = new Guardrail(List.of(stopWords), regex); - outputGuardrail = new Guardrail(List.of(stopWords), regex); + inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); } @Test public void writeTo() throws IOException { - Guardrails guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + Guardrails guardrails = new Guardrails("local_regex", inputLocalRegexGuardrail, outputLocalRegexGuardrail); BytesStreamOutput output = new BytesStreamOutput(); guardrails.writeTo(output); Guardrails guardrails1 = new Guardrails(output.bytes().streamInput()); @@ -52,12 +50,12 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { - Guardrails guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + Guardrails guardrails = new Guardrails("local_regex", inputLocalRegexGuardrail, outputLocalRegexGuardrail); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); guardrails.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"type\":\"test_type\"," + + Assert.assertEquals("{\"type\":\"local_regex\"," + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}", content); @@ -65,7 +63,7 @@ public void toXContent() throws IOException { @Test public void parse() throws IOException { - String jsonStr = "{\"type\":\"test_type\"," + + String jsonStr = "{\"type\":\"local_regex\"," + 
"\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}"; XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, @@ -73,8 +71,8 @@ public void parse() throws IOException { parser.nextToken(); Guardrails guardrails = Guardrails.parse(parser); - Assert.assertEquals(guardrails.getType(), "test_type"); - Assert.assertEquals(guardrails.getInputGuardrail(), inputGuardrail); - Assert.assertEquals(guardrails.getOutputGuardrail(), outputGuardrail); + Assert.assertEquals(guardrails.getType(), "local_regex"); + Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail); + Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail); } } \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java new file mode 100644 index 0000000000..6c0cdfb1ef --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java @@ -0,0 +1,292 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.apache.lucene.search.TotalHits; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchModule; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.suggest.Suggest; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class LocalRegexGuardrailTests { + NamedXContentRegistry xContentRegistry; + @Mock + Client client; + @Mock + ThreadPool threadPool; + ThreadContext threadContext; + + StopWords stopWords; + String[] regex; + List regexPatterns; + LocalRegexGuardrail localRegexGuardrail; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + xContentRegistry = new 
NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + Settings settings = Settings.builder().build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + + stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); + regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); + localRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + } + + @Test + public void readWriteEmptyContent() throws IOException { + LocalRegexGuardrail localRegexGuardrail = new LocalRegexGuardrail(Collections.emptyList(), new String[0]); + BytesStreamOutput output = new BytesStreamOutput(); + localRegexGuardrail.writeTo(output); + LocalRegexGuardrail localRegexGuardrail1 = new LocalRegexGuardrail(output.bytes().streamInput()); + + Assert.assertNull(localRegexGuardrail1.getStopWords()); + Assert.assertArrayEquals(Collections.emptyList().toArray(), localRegexGuardrail1.getRegex()); + } + + @Test + public void readWriteNullContent() throws IOException { + LocalRegexGuardrail localRegexGuardrail = new LocalRegexGuardrail(null, null); + BytesStreamOutput output = new BytesStreamOutput(); + localRegexGuardrail.writeTo(output); + LocalRegexGuardrail localRegexGuardrail1 = new LocalRegexGuardrail(output.bytes().streamInput()); + + Assert.assertNull(localRegexGuardrail1.getStopWords()); + Assert.assertNull(localRegexGuardrail1.getRegex()); + } + + @Test + public void writeTo() throws IOException { + LocalRegexGuardrail localRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + BytesStreamOutput output = new BytesStreamOutput(); + localRegexGuardrail.writeTo(output); + LocalRegexGuardrail localRegexGuardrail1 = new LocalRegexGuardrail(output.bytes().streamInput()); + + Assert.assertArrayEquals(localRegexGuardrail.getStopWords().toArray(), localRegexGuardrail1.getStopWords().toArray()); + Assert.assertArrayEquals(localRegexGuardrail.getRegex(), localRegexGuardrail1.getRegex()); + } + + @Test + public void toXContent() throws IOException { + LocalRegexGuardrail localRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + localRegexGuardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + LocalRegexGuardrail localRegexGuardrail = LocalRegexGuardrail.parse(parser); + + Assert.assertArrayEquals(localRegexGuardrail.getStopWords().toArray(), List.of(stopWords).toArray()); + Assert.assertArrayEquals(localRegexGuardrail.getRegex(), regex); + } + + @Test + public void validateRegexListSuccess() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = 
localRegexGuardrail.validateRegexList(input, regexPatterns); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexListFailed() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegexList(input, regexPatterns); + + Assert.assertFalse(res); + } + + @Test + public void validateRegexListNull() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegexList(input, null); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexListEmpty() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegexList(input, List.of()); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexSuccess() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegex(input, regexPatterns.get(0)); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexFailed() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegex(input, regexPatterns.get(0)); + + Assert.assertFalse(res); + } + + @Test + public void validateStopWords() throws IOException { + Map> stopWordsIndices = Map.of("test_index", List.of("test_field")); + SearchResponse searchResponse = createSearchResponse(1); + ActionFuture future = createSearchResponseFuture(searchResponse); + when(this.client.search(any())).thenReturn(future); + + Boolean res = localRegexGuardrail.validateStopWords("hello world", stopWordsIndices); + Assert.assertTrue(res); + } + + @Test + public void validateStopWordsNull() { + Boolean res = localRegexGuardrail.validateStopWords("hello world", null); + Assert.assertTrue(res); + } + + @Test + public void validateStopWordsEmpty() { + Boolean res = localRegexGuardrail.validateStopWords("hello world", Map.of()); + Assert.assertTrue(res); + } + + @Test + public void validateStopWordsSingleIndex() throws IOException { + SearchResponse searchResponse = createSearchResponse(1); + ActionFuture future = createSearchResponseFuture(searchResponse); + when(this.client.search(any())).thenReturn(future); + + Boolean res = localRegexGuardrail.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field")); + Assert.assertTrue(res); + } + + private SearchResponse createSearchResponse(int size) throws IOException { + XContentBuilder content = localRegexGuardrail.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + SearchHit[] hits = new SearchHit[size]; + if (size > 0) { + hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); + } + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + + private ActionFuture createSearchResponseFuture(SearchResponse searchResponse) { + return new ActionFuture<>() { + @Override + public SearchResponse actionGet() { + return searchResponse; + } + + @Override + public SearchResponse actionGet(String timeout) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(long timeoutMillis) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(long timeout, TimeUnit unit) { + 
return searchResponse; + } + + @Override + public SearchResponse actionGet(TimeValue timeout) { + return searchResponse; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public SearchResponse get() { + return searchResponse; + } + + @Override + public SearchResponse get(long timeout, TimeUnit unit) { + return searchResponse; + } + }; + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java index 4af3072c8a..b2a29ba7c7 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java @@ -5,43 +5,22 @@ package org.opensearch.ml.common.model; -import org.apache.lucene.search.TotalHits; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.Client; -import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; -import org.opensearch.search.aggregations.InternalAggregations; -import org.opensearch.search.internal.InternalSearchResponse; -import org.opensearch.search.profile.SearchProfileShardResults; -import org.opensearch.search.suggest.Suggest; import org.opensearch.threadpool.ThreadPool; -import java.io.IOException; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.regex.Pattern; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; public class MLGuardTests { @@ -56,8 +35,8 @@ public class MLGuardTests { StopWords stopWords; String[] regex; List regexPatterns; - Guardrail inputGuardrail; - Guardrail outputGuardrail; + LocalRegexGuardrail inputLocalRegexGuardrail; + LocalRegexGuardrail outputLocalRegexGuardrail; Guardrails guardrails; MLGuard mlGuard; @@ -73,16 +52,16 @@ public void setUp() { stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); - inputGuardrail = new Guardrail(List.of(stopWords), regex); - outputGuardrail = new Guardrail(List.of(stopWords), regex); - guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + guardrails = new Guardrails("test_type", 
inputLocalRegexGuardrail, outputLocalRegexGuardrail); mlGuard = new MLGuard(guardrails, xContentRegistry, client); } @Test public void validateInput() { String input = "\n\nHuman:hello stop words.\n\nAssistant:"; - Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT); + Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT, Collections.emptyMap()); Assert.assertFalse(res); } @@ -92,182 +71,13 @@ public void validateInitializedStopWordsEmpty() { stopWords = new StopWords(null, null); regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); - inputGuardrail = new Guardrail(List.of(stopWords), regex); - outputGuardrail = new Guardrail(List.of(stopWords), regex); - guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + guardrails = new Guardrails("test_type", inputLocalRegexGuardrail, outputLocalRegexGuardrail); mlGuard = new MLGuard(guardrails, xContentRegistry, client); String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT); + Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT, Collections.emptyMap()); Assert.assertTrue(res); } - - @Test - public void validateOutput() { - String input = "\n\nHuman:hello stop words.\n\nAssistant:"; - Boolean res = mlGuard.validate(input, MLGuard.Type.OUTPUT); - - Assert.assertFalse(res); - } - - @Test - public void validateRegexListSuccess() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, regexPatterns); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexListFailed() { - String input = "\n\nHuman:hello stop words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, regexPatterns); - - Assert.assertFalse(res); - } - - @Test - public void validateRegexListNull() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, null); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexListEmpty() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, List.of()); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexSuccess() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0)); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexFailed() { - String input = "\n\nHuman:hello stop words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0)); - - Assert.assertFalse(res); - } - - @Test - public void validateStopWords() throws IOException { - Map> stopWordsIndices = Map.of("test_index", List.of("test_field")); - SearchResponse searchResponse = createSearchResponse(1); - ActionFuture future = createSearchResponseFuture(searchResponse); - when(this.client.search(any())).thenReturn(future); - - Boolean res = mlGuard.validateStopWords("hello world", stopWordsIndices); - Assert.assertTrue(res); - } - - @Test - public void validateStopWordsNull() { - Boolean res = mlGuard.validateStopWords("hello world", null); - Assert.assertTrue(res); - } - - @Test - public void validateStopWordsEmpty() { - Boolean res = mlGuard.validateStopWords("hello world", Map.of()); - Assert.assertTrue(res); - 
} - - @Test - public void validateStopWordsSingleIndex() throws IOException { - SearchResponse searchResponse = createSearchResponse(1); - ActionFuture future = createSearchResponseFuture(searchResponse); - when(this.client.search(any())).thenReturn(future); - - Boolean res = mlGuard.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field")); - Assert.assertTrue(res); - } - - private SearchResponse createSearchResponse(int size) throws IOException { - XContentBuilder content = guardrails.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - SearchHit[] hits = new SearchHit[size]; - if (size > 0) { - hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); - } - return new SearchResponse( - new InternalSearchResponse( - new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f), - InternalAggregations.EMPTY, - new Suggest(Collections.emptyList()), - new SearchProfileShardResults(Collections.emptyMap()), - false, - false, - 1 - ), - "", - 5, - 5, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - } - - private ActionFuture createSearchResponseFuture(SearchResponse searchResponse) { - return new ActionFuture<>() { - @Override - public SearchResponse actionGet() { - return searchResponse; - } - - @Override - public SearchResponse actionGet(String timeout) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(long timeoutMillis) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(long timeout, TimeUnit unit) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(TimeValue timeout) { - return searchResponse; - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public SearchResponse get() throws InterruptedException, ExecutionException { - return searchResponse; - } - - @Override - public SearchResponse get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - return searchResponse; - } - }; - } } \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java new file mode 100644 index 0000000000..9c30dc06e1 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java @@ -0,0 +1,101 @@ +package org.opensearch.ml.common.model; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +import static org.mockito.ArgumentMatchers.any; +import static 
org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.when; + +public class ModelGuardrailTests { + NamedXContentRegistry xContentRegistry; + @Mock + Client client; + + Pattern regexPattern; + ModelGuardrail modelGuardrail; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + doNothing().when(this.client).execute(any(), any(), any()); + modelGuardrail = new ModelGuardrail("test_model_id", "$.test", "^accept$"); + regexPattern = Pattern.compile("^accept$"); + } + + @Test + public void writeTo() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + modelGuardrail.writeTo(output); + ModelGuardrail modelGuardrail1 = new ModelGuardrail(output.bytes().streamInput()); + + Assert.assertEquals(modelGuardrail.getModelId(), modelGuardrail1.getModelId()); + Assert.assertEquals(modelGuardrail.getResponseFilter(), modelGuardrail1.getResponseFilter()); + Assert.assertEquals(modelGuardrail.getResponseAccept(), modelGuardrail1.getResponseAccept()); + } + + @Test + public void validateParametersNull() { + Assert.assertTrue(modelGuardrail.validate("test", null)); + } + + @Test + public void validateParametersEmpty() { + Assert.assertTrue(modelGuardrail.validate("test", Collections.emptyMap())); + } + + @Test + public void validateParametersEmpty1() { + Assert.assertTrue(modelGuardrail.validate("test", Map.of("question", ""))); + } + + @Test + public void init() { + Assert.assertNull(modelGuardrail.getRegexAcceptPattern()); + modelGuardrail.init(xContentRegistry, client); + Assert.assertEquals(regexPattern.toString(), modelGuardrail.getRegexAcceptPattern().toString()); + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + modelGuardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_accept\":\"^accept$\"}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_accept\":\"^accept$\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + ModelGuardrail modelGuardrail1 = ModelGuardrail.parse(parser); + + Assert.assertEquals(modelGuardrail1.getModelId(), modelGuardrail.getModelId()); + Assert.assertEquals(modelGuardrail1.getResponseFilter(), modelGuardrail.getResponseFilter()); + Assert.assertEquals(modelGuardrail1.getResponseAccept(), modelGuardrail.getResponseAccept()); + } +} \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index cad0278a6d..0adfb99663 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -193,7 +193,13 @@ public static ModelTensors processOutput( if (modelResponse == null) { throw new IllegalArgumentException("model response is null"); } - if (mlGuard != null && 
!mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) { + if (mlGuard != null + && !mlGuard + .validate( + modelResponse, + MLGuard.Type.OUTPUT, + Map.of("question", org.opensearch.ml.common.utils.StringUtils.processTextDoc(modelResponse)) + )) { throw new IllegalArgumentException("guardrails triggered for LLM output"); } List modelTensors = new ArrayList<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 11e43cef85..e786122cbe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -199,7 +199,7 @@ && getUserRateLimiterMap().get(user.getName()) != null RestStatus.TOO_MANY_REQUESTS ); } else { - if (getMlGuard() != null && !getMlGuard().validate(payload, MLGuard.Type.INPUT)) { + if (getMlGuard() != null && !getMlGuard().validate(payload, MLGuard.Type.INPUT, parameters)) { throw new IllegalArgumentException("guardrails triggered for user input"); } if (getConnectorClientConfig().getMaxRetryTimes() != 0) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index c275d2263e..0ca15a6220 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -8,9 +8,10 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -21,13 +22,13 @@ import org.opensearch.ml.utils.TestData; import org.opensearch.ml.utils.TestHelper; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; public class RestMLGuardrailsIT extends MLCommonsRestTestCase { final String OPENAI_KEY = System.getenv("OPENAI_KEY"); - final String COHERE_KEY = System.getenv("COHERE_KEY"); + final String acceptRegex = "^\\s*[Aa]ccept\\s*$"; + final String rejectRegex = "^\\s*[Rr]eject\\s*$"; final String completionModelConnectorEntity = "{\n" + "\"name\": \"OpenAI Connector\",\n" @@ -65,6 +66,44 @@ public class RestMLGuardrailsIT extends MLCommonsRestTestCase { + " ]\n" + "}"; + final String completionModelConnectorEntityWithGuardrail = "{\n" + + "\"name\": \"OpenAI Connector\",\n" + + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + "\"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + + "\"protocol\": \"http\",\n" + + "\"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"gpt-3.5-turbo-instruct\",\n" + + " \"prompt\": \"You are a helpful assistant and an expert judge of content quality. Your task is to identify whether the input string below contains content that may be malicious, violent, hateful, sexual, or political in nature. 
Your answer should consist of a single word, either reject or accept. If the input belongs to any of these categories, please write reject. Otherwise, write accept. \\\\n\\\\nHere is the input: ${parameters.question}. \\\\n\\\\nYour answer: \",\n" + + " \"response_filter\": \"$.choices[0].text\"" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -85,7 +124,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModelGuardrails("openAI-GPT-3.5 completions", connectorId); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); @@ -123,7 +162,7 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModelGuardrails("openAI-GPT-3.5 completions", connectorId); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); @@ -135,8 +174,134 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; + predictRemoteModel(modelId, predictInput); + } + + public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + // Create guardrails model. 
+ Response response = createConnector(completionModelConnectorEntityWithGuardrail); + Map responseMap = parseResponseToMap(response); + String guardrailConnectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "accept". + String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + response = predictRemoteModel(guardrailModelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + String validationResult = (String) responseMap.get("response"); + Assert.assertTrue(validateRegex(validationResult, acceptRegex)); + // Create predict model. + response = createConnector(completionModelConnectorEntity); + responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + // Predict. + predictInput = "{\n" + + " \"parameters\": {\n" + + " \"prompt\": \"${parameters.question}\",\n" + + " \"question\": \"hello\"\n" + + " }\n" + + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); + responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + if (responseList == null) { + assertTrue(checkThrottlingOpenAI(responseMap)); + return; + } + responseMap = (Map) responseList.get(0); + assertFalse(((String) responseMap.get("text")).isEmpty()); + } + + public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + exceptionRule.expect(ResponseException.class); + exceptionRule.expectMessage("guardrails triggered for user input"); + // Create guardrails model. 
+ Response response = createConnector(completionModelConnectorEntityWithGuardrail); + Map responseMap = parseResponseToMap(response); + String guardrailConnectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "reject". + String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"I will be executed or tortured.\"\n" + " }\n" + "}"; + response = predictRemoteModel(guardrailModelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + String validationResult = (String) responseMap.get("response"); + Assert.assertTrue(validateRegex(validationResult, rejectRegex)); + // Create predict model. + response = createConnector(completionModelConnectorEntity); + responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + // Predict with throwing guardrail exception. 
+ predictInput = "{\n" + + " \"parameters\": {\n" + + " \"prompt\": \"${parameters.question}\",\n" + + " \"question\": \"I will be executed or tortured.\"\n" + + " }\n" + + "}"; + predictRemoteModel(modelId, predictInput); } protected void createStopWordsIndex() throws IOException { @@ -158,7 +323,7 @@ protected void createStopWordsIndex() throws IOException { "_bulk?refresh=true", null, TestHelper.toHttpEntity(TestData.STOP_WORDS_DATA.replaceAll("stop_words", indexName)), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + null ); Response statsResponse = TestHelper.makeRequest(client(), "GET", indexName, ImmutableMap.of(), "", null); @@ -169,7 +334,45 @@ protected Response createConnector(String input) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); } - protected Response registerRemoteModelGuardrails(String name, String connectorId) throws IOException { + protected Response registerRemoteModel(String modelGroupName, String name, String connectorId) throws IOException { + String registerModelGroupEntity = "{\n" + + " \"name\": \"" + + modelGroupName + + "\",\n" + + " \"description\": \"This is an example description\"\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + TestHelper.toHttpEntity(registerModelGroupEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((String) responseMap.get("status"), "CREATED"); + String modelGroupId = (String) responseMap.get("model_group_id"); + + String registerModelEntity = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"version\": \"1.0.0\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\"\n" + + "}"; + return TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); + } + + protected Response registerRemoteModelWithLocalRegexGuardrails(String name, String connectorId) throws IOException { String registerModelGroupEntity = "{\n" + " \"name\": \"remote_model_group\",\n" + " \"description\": \"This is an example description\"\n" @@ -201,6 +404,7 @@ protected Response registerRemoteModelGuardrails(String name, String connectorId + connectorId + "\",\n" + " \"guardrails\": {\n" + + " \"type\": \"local_regex\",\n" + " \"input_guardrail\": {\n" + " \"stop_words\": [\n" + " {" @@ -225,6 +429,58 @@ protected Response registerRemoteModelGuardrails(String name, String connectorId .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } + protected Response registerRemoteModelWithModelGuardrails(String name, String connectorId, String guardrailModelId) throws IOException { + + String registerModelGroupEntity = "{\n" + + " \"name\": \"remote_model_group\",\n" + + " \"description\": \"This is an example description\"\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + TestHelper.toHttpEntity(registerModelGroupEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((String) responseMap.get("status"), "CREATED"); + String modelGroupId = (String) responseMap.get("model_group_id"); + + String registerModelEntity = "{\n" + + " 
\"name\": \"" + + name + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"version\": \"1.0.0\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\",\n" + + " \"guardrails\": {\n" + + " \"type\": \"model\",\n" + + " \"input_guardrail\": {\n" + + " \"model_id\": \"" + + guardrailModelId + + "\",\n" + + " \"response_accept\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\"" + + " },\n" + + " \"output_guardrail\": {\n" + + " \"model_id\": \"" + + guardrailModelId + + "\",\n" + + " \"response_accept\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\"" + + " }\n" + + " }\n" + + "}"; + return TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); + } + protected Response deployRemoteModel(String modelId) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null); } @@ -247,7 +503,7 @@ protected void disableClusterConnectorAccessControl() throws IOException { "_cluster/settings", null, "{\"persistent\":{\"plugins.ml_commons.connector_access_control_enabled\":false, \"plugins.ml_commons.sync_up_job_interval_in_seconds\":3}}", - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + null ); assertEquals(200, response.getStatusLine().getStatusCode()); } @@ -255,4 +511,11 @@ protected void disableClusterConnectorAccessControl() throws IOException { protected Response getTask(String taskId) throws IOException { return TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); } + + private Boolean validateRegex(String input, String regex) { + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(input); + return matcher.matches(); + + } } From b051160a102b83c3680c2d390fa3d3f00fc69806 Mon Sep 17 00:00:00 2001 From: "Mate, Kim" <18133668+mateon01@users.noreply.github.com> Date: Tue, 11 Jun 2024 08:33:18 +0900 Subject: [PATCH 06/12] Titan Embedding Connector Blueprint content referenced by users of OpenSearch 2.11 version #2517 (#2519) * change - use built-in functions based on the OpenSearch version of the titan connect blueprint * change - use built-in functions based on the OpenSearch version of the titan connect blueprint * chage - Refined version support information for Connector Blueprint * chage - Refined version support information for Connector Blueprint * Apply suggestions from code review Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu Co-authored-by: Kim, Sewoong Co-authored-by: Yaliang Wu --- ...ock_connector_titan_embedding_blueprint.md | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md b/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md index 3e01a12d1e..73bcb1dfab 100644 --- a/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md +++ b/docs/remote_inference_blueprints/bedrock_connector_titan_embedding_blueprint.md @@ -89,6 +89,42 @@ POST /_plugins/_ml/connectors/_create } ``` +As of version 2.12 of the OpenSearch Service, we support the connector.pre_process.bedrock.embedding and connector.post_process.bedrock.embedding embedding functions. +However, If you are using AWS OpenSearch Service version 2.11, there are no built-in functions for pre_process_function and post_process_function. 
+So, you need to add the script as shown below. + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "", + "service_name": "bedrock", + "model": "amazon.titan-embed-text-v1" + }, + "credential": { + "roleArn": "" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "\n StringBuilder builder = new StringBuilder();\n builder.append(\"\\\"\");\n String first = params.text_docs[0];\n builder.append(first);\n builder.append(\"\\\"\");\n def parameters = \"{\" +\"\\\"inputText\\\":\" + builder + \"}\";\n return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";", + "post_process_function": "\n def name = \"sentence_embedding\";\n def dataType = \"FLOAT32\";\n if (params.embedding == null || params.embedding.length == 0) {\n return params.message;\n }\n def shape = [params.embedding.length];\n def json = \"{\" +\n \"\\\"name\\\":\\\"\" + name + \"\\\",\" +\n \"\\\"data_type\\\":\\\"\" + dataType + \"\\\",\" +\n \"\\\"shape\\\":\" + shape + \",\" +\n \"\\\"data\\\":\" + params.embedding +\n \"}\";\n return json;\n " + } + ] +} +``` + Sample response: ```json { From 22b558d1f985d6de2c7b4713a6928fde73ea0030 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Tue, 11 Jun 2024 09:34:12 -0700 Subject: [PATCH 07/12] Fix model still deployed after calling undeploy API (#2510) * Fix model still deployed after calling undeploy API Signed-off-by: Sicheng Song * Add UT coverage Signed-off-by: Sicheng Song * Fix style Signed-off-by: Sicheng Song * Add UT coverage Signed-off-by: Sicheng Song * Add UT coverage Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song --- .../TransportUndeployModelAction.java | 201 +++++----- .../TransportUndeployModelActionTests.java | 357 ++++++++++++++++-- .../ml/tools/ToolIntegrationWithLLMTest.java | 1 - 3 files changed, 437 insertions(+), 122 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index 662971b2c7..6456039774 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -29,7 +29,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -42,10 +41,10 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; -import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; import 
org.opensearch.ml.stats.MLStats; +import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -59,11 +58,8 @@ public class TransportUndeployModelAction extends private final MLModelManager mlModelManager; private final ClusterService clusterService; private final Client client; - private DiscoveryNodeHelper nodeFilter; + private final DiscoveryNodeHelper nodeFilter; private final MLStats mlStats; - private NamedXContentRegistry xContentRegistry; - - private ModelAccessControlHelper modelAccessControlHelper; @Inject public TransportUndeployModelAction( @@ -74,9 +70,7 @@ public TransportUndeployModelAction( ThreadPool threadPool, Client client, DiscoveryNodeHelper nodeFilter, - MLStats mlStats, - NamedXContentRegistry xContentRegistry, - ModelAccessControlHelper modelAccessControlHelper + MLStats mlStats ) { super( MLUndeployModelAction.NAME, @@ -90,107 +84,128 @@ public TransportUndeployModelAction( MLUndeployModelNodeResponse.class ); this.mlModelManager = mlModelManager; + this.clusterService = clusterService; this.client = client; this.nodeFilter = nodeFilter; this.mlStats = mlStats; - this.xContentRegistry = xContentRegistry; - this.modelAccessControlHelper = modelAccessControlHelper; } @Override - protected MLUndeployModelNodesResponse newResponse( - MLUndeployModelNodesRequest nodesRequest, - List responses, - List failures + protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener listener) { + ActionListener wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> { + processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener); + }, listener::onFailure); + super.doExecute(task, request, wrappedListener); + } + + void processUndeployModelResponseAndUpdate( + MLUndeployModelNodesResponse undeployModelNodesResponse, + ActionListener listener ) { - if (responses != null) { - Map> actualRemovedNodesMap = new HashMap<>(); - Map modelWorkNodesBeforeRemoval = new HashMap<>(); - responses.forEach(r -> { - Map nodeCounts = r.getModelWorkerNodeBeforeRemoval(); - - if (nodeCounts != null) { - for (Map.Entry entry : nodeCounts.entrySet()) { - // when undeploy a undeployed model, the entry.getvalue() is null - if (entry.getValue() != null - && (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey()) - || modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) { - modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue()); - } + List responses = undeployModelNodesResponse.getNodes(); + if (responses == null || responses.isEmpty()) { + listener.onResponse(undeployModelNodesResponse); + return; + } + + Map> actualRemovedNodesMap = new HashMap<>(); + Map modelWorkNodesBeforeRemoval = new HashMap<>(); + responses.forEach(r -> { + Map nodeCounts = r.getModelWorkerNodeBeforeRemoval(); + + if (nodeCounts != null) { + for (Map.Entry entry : nodeCounts.entrySet()) { + // when undeploy an undeployed model, the entry.getvalue() is null + if (entry.getValue() != null + && (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey()) + || modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) { + modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue()); } } + } - Map modelUndeployStatus = r.getModelUndeployStatus(); - for (Map.Entry entry : modelUndeployStatus.entrySet()) { - String status = entry.getValue(); - if (UNDEPLOYED.equals(status)) { - String modelId = entry.getKey(); - if 
(!actualRemovedNodesMap.containsKey(modelId)) { - actualRemovedNodesMap.put(modelId, new ArrayList<>()); - } - actualRemovedNodesMap.get(modelId).add(r.getNode().getId()); + Map modelUndeployStatus = r.getModelUndeployStatus(); + for (Map.Entry entry : modelUndeployStatus.entrySet()) { + String status = entry.getValue(); + if (UNDEPLOYED.equals(status)) { + String modelId = entry.getKey(); + if (!actualRemovedNodesMap.containsKey(modelId)) { + actualRemovedNodesMap.put(modelId, new ArrayList<>()); } + actualRemovedNodesMap.get(modelId).add(r.getNode().getId()); } - }); - - MLSyncUpInput syncUpInput = MLSyncUpInput - .builder() - .removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap)) - .build(); - - MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (actualRemovedNodesMap.size() > 0) { - BulkRequest bulkRequest = new BulkRequest(); - Map deployToAllNodes = new HashMap<>(); - for (String modelId : actualRemovedNodesMap.keySet()) { - UpdateRequest updateRequest = new UpdateRequest(); - List removedNodes = actualRemovedNodesMap.get(modelId); - int removedNodeCount = removedNodes.size(); - /** - * If allow custom deploy is false, user can only undeploy all nodes and status is undeployed. - * If allow custom deploy is true, user can undeploy all nodes and status is undeployed, - * or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and - * we need to update both planning worker nodes (count) and current worker nodes (count) - * and deployToAllNodes value in model index. - */ - Map updateDocument = new HashMap<>(); - if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes. - updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of()); - updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0); - updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0); - updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED); - } else { // undeploy partial nodes. - // TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed, - // and the user could be undeploying not running model nodes, and we should update model status to deployed. 
- updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false); - List newPlanningWorkerNodes = Arrays - .stream(modelWorkNodesBeforeRemoval.get(modelId)) - .filter(x -> !removedNodes.contains(x)) - .collect(Collectors.toList()); - updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes); - updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); - updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); - deployToAllNodes.put(modelId, false); - } - updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument); - bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + } + }); + + MLSyncUpInput syncUpInput = MLSyncUpInput + .builder() + .removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap)) + .build(); + + MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (actualRemovedNodesMap.size() > 0) { + BulkRequest bulkRequest = new BulkRequest(); + Map deployToAllNodes = new HashMap<>(); + for (String modelId : actualRemovedNodesMap.keySet()) { + UpdateRequest updateRequest = new UpdateRequest(); + List removedNodes = actualRemovedNodesMap.get(modelId); + int removedNodeCount = removedNodes.size(); + /** + * If allow custom deploy is false, user can only undeploy all nodes and status is undeployed. + * If allow custom deploy is true, user can undeploy all nodes and status is undeployed, + * or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and + * we need to update both planning worker nodes (count) and current worker nodes (count) + * and deployToAllNodes value in model index. + */ + Map updateDocument = new HashMap<>(); + if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes. + updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of()); + updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0); + updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0); + updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED); + } else { // undeploy partial nodes. + // TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed, + // and the user could be undeploying not running model nodes, and we should update model status to deployed. 
+ updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false); + List newPlanningWorkerNodes = Arrays + .stream(modelWorkNodesBeforeRemoval.get(modelId)) + .filter(x -> !removedNodes.contains(x)) + .collect(Collectors.toList()); + updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes); + updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); + updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); + deployToAllNodes.put(modelId, false); } - syncUpInput.setDeployToAllNodes(deployToAllNodes); - ActionListener actionListener = ActionListener.wrap(r -> { - log - .debug( - "updated model state as undeployed for : {}", - Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0])) - ); - }, e -> { log.error("Failed to update model state as undeployed", e); }); - client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { syncUpUndeployedModels(syncUpRequest); })); - } else { - syncUpUndeployedModels(syncUpRequest); + updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument); + bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); } + syncUpInput.setDeployToAllNodes(deployToAllNodes); + ActionListener actionListener = ActionListener.wrap(r -> { + log + .debug( + "updated model state as undeployed for : {}", + Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0])) + ); + }, e -> { log.error("Failed to update model state as undeployed", e); }); + client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { + syncUpUndeployedModels(syncUpRequest); + listener.onResponse(undeployModelNodesResponse); + })); + } else { + syncUpUndeployedModels(syncUpRequest); + listener.onResponse(undeployModelNodesResponse); } } + } + + @Override + protected MLUndeployModelNodesResponse newResponse( + MLUndeployModelNodesRequest nodesRequest, + List responses, + List failures + ) { return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java index 8de5cf5ec8..87d42f3847 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java @@ -9,11 +9,10 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.io.IOException; import java.net.InetAddress; @@ -25,34 +24,41 @@ import java.util.concurrent.ExecutorService; import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.Spy; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; -import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; 
-import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; -import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -77,6 +83,18 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { @Mock private Client client; + @Mock + ClusterState clusterState; + + @Mock + Task task; + + @Spy + ActionListener actionListener; + + @Mock + MLSyncUpNodeResponse syncUpNodeResponse; + @Mock private DiscoveryNodeHelper nodeFilter; @@ -93,10 +111,22 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { private TransportUndeployModelAction action; - private DiscoveryNode localNode; + DiscoveryNode localNode; + + private DiscoveryNode node1; + + private DiscoveryNode node2; + + DiscoveryNode[] nodesArray; + + @Mock + private MLUndeployModelNodesResponse undeployModelNodesResponse; @Mock - private ModelAccessControlHelper modelAccessControlHelper; + private TransportNodesAction transportNodesAction; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setup() throws IOException { @@ -105,24 +135,26 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.generic()).thenReturn(executorService); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); return null; }).when(executorService).execute(any(Runnable.class)); - action = new TransportUndeployModelAction( - transportService, - actionFilters, - mlModelManager, - clusterService, - null, - client, - nodeFilter, - mlStats, - xContentRegistry, - modelAccessControlHelper + action = spy( + new TransportUndeployModelAction( + transportService, + actionFilters, + mlModelManager, + 
clusterService, + threadPool, + client, + nodeFilter, + mlStats + ) ); + localNode = new DiscoveryNode( "foo0", "foo0", @@ -131,8 +163,34 @@ public void setup() throws IOException { Collections.singleton(CLUSTER_MANAGER_ROLE), Version.CURRENT ); + + InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); + InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); + + DiscoveryNode node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(inetAddress1, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNode node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(inetAddress2, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); when(clusterService.localNode()).thenReturn(localNode); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.nodes()).thenReturn(nodes); } public void testConstructor() { @@ -171,7 +229,23 @@ public void testNodeOperation() { assertNotNull(response); } - public void testNewResponseWithUndeployedModelStatus() { + public void testDoExecuteTransportUndeployedModelAction() { + MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + + action.doExecute(task, nodesRequest, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(MLUndeployModelNodesResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + } + + public void testProcessUndeployModelResponseAndUpdateNullResponse() { + when(undeployModelNodesResponse.getNodes()).thenReturn(null); + action.processUndeployModelResponseAndUpdate(undeployModelNodesResponse, actionListener); + } + + public void testProcessUndeployModelResponseAndUpdateResponse() { final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( new String[] { "nodeId1", "nodeId2" }, new String[] { "modelId1", "modelId2" } @@ -187,13 +261,240 @@ public void testNewResponseWithUndeployedModelStatus() { responses.add(response2); final List failures = new ArrayList<>(); final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - assertNotNull(response); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(1)).bulk(argumentCaptor.capture(), any()); - UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0); - assertEquals(ML_MODEL_INDEX, updateRequest.index()); - Map updateContent = updateRequest.doc().sourceAsMap(); - assertEquals(MLModelState.UNDEPLOYED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD)); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + 
action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateBulkException() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Bulk request failed")); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateSyncUpException() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("SyncUp request failed")); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseDeployStatusWrong() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + 
); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "wrong_status"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployPartialNodes() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus1 = new HashMap<>(); + modelToDeployStatus1.put("modelId1", "undeployed"); + Map modelToDeployStatus2 = new HashMap<>(); + modelToDeployStatus2.put("modelId1", "deployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] { "foo0", "foo0" }); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus1, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus2, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployEmptyNodes() { + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map 
modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", new String[] {}); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + BulkResponse bulkResponse = mock(BulkResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + + MLSyncUpNodesResponse syncUpNodesResponse = mock(MLSyncUpNodesResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(syncUpNodesResponse); + return null; + }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + verify(actionListener).onResponse(response); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployNodeEntrySetNull() { + exceptionRule.expect(NullPointerException.class); + + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", null); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + action.processUndeployModelResponseAndUpdate(response, actionListener); + } + + public void testProcessUndeployModelResponseAndUpdateResponseUndeployModelWorkerNodeBeforeRemovalNull() { + exceptionRule.expect(NullPointerException.class); + + final MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + Map modelToDeployStatus = new HashMap<>(); + modelToDeployStatus.put("modelId1", "undeployed"); + MLUndeployModelNodeResponse response1 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, null); + MLUndeployModelNodeResponse response2 = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, null); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + + action.processUndeployModelResponseAndUpdate(response, actionListener); } public void testNewResponseWithNotFoundModelStatus() { diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index fe033645d6..3e7c2e64f4 
100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -72,7 +72,6 @@ public void stopMockLLM() { @After public void deleteModel() throws IOException { undeployModel(modelId); - waitModelUndeployed(modelId); deleteModel(client(), modelId, null); } From 7cd52915d04d8ac7ddb6e37a74a256603587ce69 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 11 Jun 2024 15:11:48 -0500 Subject: [PATCH 08/12] ml inference ingest processor support for local models (#2508) * ml inference ingest processor support for local models Signed-off-by: Bhavana Ramaram --- .../ml/plugin/MachineLearningPlugin.java | 5 +- .../processor/MLInferenceIngestProcessor.java | 201 +++-- .../ml/processor/ModelExecutor.java | 84 +- ...LInferenceIngestProcessorFactoryTests.java | 36 +- .../MLInferenceIngestProcessorTests.java | 849 ++++++++++++++++-- .../RestMLInferenceIngestProcessorIT.java | 218 ++++- 6 files changed, 1265 insertions(+), 128 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e9a79236b1..6d808c64bb 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -1006,7 +1006,10 @@ public void loadExtensions(ExtensionLoader loader) { public Map getProcessors(org.opensearch.ingest.Processor.Parameters parameters) { Map processors = new HashMap<>(); processors - .put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client)); + .put( + MLInferenceIngestProcessor.TYPE, + new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry) + ); return Collections.unmodifiableMap(processors); } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index c06f32803c..b19853e02c 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -6,25 +6,31 @@ import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.GroupedActionListener; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; import org.opensearch.ingest.ValueSource; -import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import 
org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.utils.StringUtils; @@ -42,10 +48,16 @@ */ public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor { + private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class); + public static final String DOT_SYMBOL = "."; private final InferenceProcessorAttributes inferenceProcessorAttributes; private final boolean ignoreMissing; + private final String functionName; + private final boolean fullResponsePath; private final boolean ignoreFailure; + private final boolean override; + private final String modelInput; private final ScriptService scriptService; private static Client client; public static final String TYPE = "ml_inference"; @@ -53,9 +65,14 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod // allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the // prediction outcomes, return the whole prediction outcome by skipping filtering public static final String IGNORE_MISSING = "ignore_missing"; + public static final String OVERRIDE = "override"; + public static final String FUNCTION_NAME = "function_name"; + public static final String FULL_RESPONSE_PATH = "full_response_path"; + public static final String MODEL_INPUT = "model_input"; // At default, ml inference processor allows maximum 10 prediction tasks running in parallel // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + private final NamedXContentRegistry xContentRegistry; private Configuration suppressExceptionConfiguration = Configuration .builder() @@ -71,9 +88,14 @@ protected MLInferenceIngestProcessor( String tag, String description, boolean ignoreMissing, + String functionName, + boolean fullResponsePath, boolean ignoreFailure, + boolean override, + String modelInput, ScriptService scriptService, - Client client + Client client, + NamedXContentRegistry xContentRegistry ) { super(tag, description); this.inferenceProcessorAttributes = new InferenceProcessorAttributes( @@ -84,9 +106,14 @@ protected MLInferenceIngestProcessor( maxPredictionTask ); this.ignoreMissing = ignoreMissing; + this.functionName = functionName; + this.fullResponsePath = fullResponsePath; this.ignoreFailure = ignoreFailure; + this.override = override; + this.modelInput = modelInput; this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -162,10 +189,48 @@ private void processPredictions( List> processOutputMap, int inputMapIndex, int inputMapSize - ) { + ) throws IOException { Map modelParameters = new HashMap<>(); + Map modelConfigs = new HashMap<>(); + if (inferenceProcessorAttributes.getModelConfigMaps() != null) { modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + } + + Map ingestDocumentSourceAndMetaData = new HashMap<>(); + ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); + ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); + + Map> newOutputMapping = new HashMap<>(); + if (processOutputMap != null) { + + Map outputMapping = processOutputMap.get(inputMapIndex); + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPathsInArray = 
writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + newOutputMapping.put(newDocumentFieldName, dotPathsInArray); + } + + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPaths = newOutputMapping.get(newDocumentFieldName); + + int existingFields = 0; + for (String path : dotPaths) { + if (ingestDocument.hasField(path)) { + existingFields++; + } + } + if (!override && existingFields == dotPaths.size()) { + logger.debug("{} already exists in the ingest document. Removing it from output mapping", newDocumentFieldName); + newOutputMapping.remove(newDocumentFieldName); + } + } + if (newOutputMapping.size() == 0) { + batchPredictionListener.onResponse(null); + return; + } } // when no input mapping is provided, default to read all fields from documents as model input if (inputMapSize == 0) { @@ -184,15 +249,30 @@ private void processPredictions( } } - ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId()); + Set inputMapKeys = new HashSet<>(modelParameters.keySet()); + inputMapKeys.removeAll(modelConfigs.keySet()); + + Map inputMappings = new HashMap<>(); + for (String k : inputMapKeys) { + inputMappings.put(k, modelParameters.get(k)); + } + ActionRequest request = getMLModelInferenceRequest( + xContentRegistry, + modelParameters, + modelConfigs, + inputMappings, + inferenceProcessorAttributes.getModelId(), + functionName, + modelInput + ); client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { @Override public void onResponse(MLTaskResponse mlTaskResponse) { - ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); + MLOutput mlOutput = mlTaskResponse.getOutput(); if (processOutputMap == null || processOutputMap.isEmpty()) { - appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument); + appendFieldValue(mlOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument); } else { // outMapping serves as a filter to modelTensorOutput, the fields that are not specified // in the outputMapping will not write to document @@ -202,14 +282,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) { // document field as key, model field as value String newDocumentFieldName = entry.getKey(); String modelOutputFieldName = entry.getValue(); - if (ingestDocument.hasField(newDocumentFieldName)) { - throw new IllegalArgumentException( - "document already has field name " - + newDocumentFieldName - + ". Not allow to overwrite the same field name, please check output_map." - ); + if (!newOutputMapping.containsKey(newDocumentFieldName)) { + continue; } - appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument); + appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument); } } batchPredictionListener.onResponse(null); @@ -305,63 +381,61 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN /** * Appends the model output value to the specified field in the IngestDocument without modifying the source. 
* - * @param modelTensorOutput the ModelTensorOutput containing the model output + * @param mlOutput the MLOutput containing the model output * @param modelOutputFieldName the name of the field in the model output * @param newDocumentFieldName the name of the field in the IngestDocument to append the value to * @param ingestDocument the IngestDocument to append the value to */ private void appendFieldValue( - ModelTensorOutput modelTensorOutput, + MLOutput mlOutput, String modelOutputFieldName, String newDocumentFieldName, IngestDocument ingestDocument ) { - Object modelOutputValue = null; - if (modelTensorOutput.getMlModelOutputs() != null && modelTensorOutput.getMlModelOutputs().size() > 0) { + if (mlOutput == null) { + throw new RuntimeException("model inference output is null"); + } - modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing); + Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); - Map ingestDocumentSourceAndMetaData = new HashMap<>(); - ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); - ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); - List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + Map ingestDocumentSourceAndMetaData = new HashMap<>(); + ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); + ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); - if (dotPathsInArray.size() == 1) { - ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); + if (dotPathsInArray.size() == 1) { + ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService); + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } else { + if (!(modelOutputValue instanceof List)) { + throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents."); + } + List modelOutputValueArray = (List) modelOutputValue; + // check length of the prediction array to be the same of the document array + if (dotPathsInArray.size() != modelOutputValueArray.size()) { + throw new RuntimeException( + "the prediction field: " + + modelOutputFieldName + + " is an array in size of " + + modelOutputValueArray.size() + + " but the document field array from field " + + newDocumentFieldName + + " is in size of " + + dotPathsInArray.size() + ); + } + // Iterate over dotPathInArray + for (int i = 0; i < dotPathsInArray.size(); i++) { + String dotPathInArray = dotPathsInArray.get(i); + Object modelOutputValueInArray = modelOutputValueArray.get(i); + ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService); + .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); - } else { - if (!(modelOutputValue instanceof List)) { - throw new IllegalArgumentException("Model output is not an array, cannot assign to array in 
documents."); - } - List modelOutputValueArray = (List) modelOutputValue; - // check length of the prediction array to be the same of the document array - if (dotPathsInArray.size() != modelOutputValueArray.size()) { - throw new RuntimeException( - "the prediction field: " - + modelOutputFieldName - + " is an array in size of " - + modelOutputValueArray.size() - + " but the document field array from field " - + newDocumentFieldName - + " is in size of " - + dotPathsInArray.size() - ); - } - // Iterate over dotPathInArray - for (int i = 0; i < dotPathsInArray.size(); i++) { - String dotPathInArray = dotPathsInArray.get(i); - Object modelOutputValueInArray = modelOutputValueArray.get(i); - ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); - TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); - ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); - } } - } else { - throw new RuntimeException("model inference output cannot be null"); } } @@ -374,6 +448,7 @@ public static class Factory implements Processor.Factory { private final ScriptService scriptService; private final Client client; + private final NamedXContentRegistry xContentRegistry; /** * Constructs a new instance of the Factory class. @@ -381,9 +456,10 @@ public static class Factory implements Processor.Factory { * @param scriptService the ScriptService instance to be used by the Factory * @param client the Client instance to be used by the Factory */ - public Factory(ScriptService scriptService, Client client) { + public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) { this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -410,6 +486,14 @@ public MLInferenceIngestProcessor create( int maxPredictionTask = ConfigurationUtils .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); + boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false); + String functionName = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name()); + String modelInput = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }"); + boolean defaultValue = !functionName.equalsIgnoreCase("remote"); + boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue); + boolean ignoreFailure = ConfigurationUtils .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false); // convert model config user input data structure to Map @@ -440,9 +524,14 @@ public MLInferenceIngestProcessor create( processorTag, description, ignoreMissing, + functionName, + fullResponsePath, ignoreFailure, + override, + modelInput, scriptService, - client + client, + xContentRegistry ); } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index 1abc770d07..ff46c13f62 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -5,17 +5,29 @@ package 
org.opensearch.ml.processor; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -45,17 +57,47 @@ public interface ModelExecutor { * @return an ActionRequest instance for remote model inference * @throws IllegalArgumentException if the input parameters are null */ - default ActionRequest getRemoteModelInferenceRequest(Map parameters, String modelId) { + default ActionRequest getMLModelInferenceRequest( + NamedXContentRegistry xContentRegistry, + Map parameters, + Map modelConfigs, + Map inputMappings, + String modelId, + String functionNameStr, + String modelInput + ) throws IOException { if (parameters == null) { throw new IllegalArgumentException("wrong input. 
The model input cannot be empty."); } - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + FunctionName functionName = FunctionName.REMOTE; + if (functionNameStr != null) { + functionName = FunctionName.from(functionNameStr); + } + + Map inputParams = new HashMap<>(); + if (FunctionName.REMOTE == functionName) { + inputParams.put("parameters", StringUtils.toJson(parameters)); + } else { + inputParams.putAll(parameters); + } - MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + String payload = modelInput; + StringSubstitutor modelConfigSubstitutor = new StringSubstitutor(modelConfigs, "${model_config.", "}"); + payload = modelConfigSubstitutor.replace(payload); + StringSubstitutor inputMapSubstitutor = new StringSubstitutor(inputMappings, "${input_map.", "}"); + payload = inputMapSubstitutor.replace(payload); + StringSubstitutor parametersSubstitutor = new StringSubstitutor(inputParams, "${ml_inference.", "}"); + payload = parametersSubstitutor.replace(payload); - ActionRequest request = new MLPredictionTaskRequest(modelId, mlInput, null); + if (!isJson(payload)) { + throw new IllegalArgumentException("Invalid payload: " + payload); + } + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, payload); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLInput mlInput = MLInput.parse(parser, functionName.name()); - return request; + return new MLPredictionTaskRequest(modelId, mlInput); } @@ -74,7 +116,9 @@ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String m try { // getMlModelOutputs() returns a list or collection. // Adding null check for modelTensorOutput - if (modelTensorOutput != null && !modelTensorOutput.getMlModelOutputs().isEmpty()) { + if (modelTensorOutput != null + && modelTensorOutput.getMlModelOutputs() != null + && !modelTensorOutput.getMlModelOutputs().isEmpty()) { // getMlModelOutputs() returns a list of ModelTensors // accessing the first element. 
// TODO currently remote model only return single tensor, might need to processor multiple tensors later @@ -130,11 +174,35 @@ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String m throw new RuntimeException("Model outputs are null or empty."); } } catch (Exception e) { - throw new RuntimeException("An unexpected error occurred: " + e.getMessage()); + throw new RuntimeException(e.getMessage()); } return modelOutputValue; } + default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldName, boolean ignoreMissing, boolean fullResponsePath) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + String modelOutputJsonStr = mlOutput.toXContent(builder, ToXContent.EMPTY_PARAMS).toString(); + Map modelTensorOutputMap = gson.fromJson(modelOutputJsonStr, Map.class); + if (!fullResponsePath && mlOutput instanceof ModelTensorOutput) { + return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing); + } else if (modelOutputFieldName == null || modelTensorOutputMap == null) { + return modelTensorOutputMap; + } else { + try { + return JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName); + } catch (Exception e) { + if (ignoreMissing) { + return modelTensorOutputMap; + } else { + throw new IllegalArgumentException("model inference output cannot find such json path: " + modelOutputFieldName, e); + } + } + } + } catch (Exception e) { + throw new RuntimeException("An unexpected error occurred: " + e.getMessage()); + } + } + /** * Parses the data from the given ModelTensor and returns it as an Object. * The method handles different data types (integer, floating-point, string, and boolean) diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java index 577e8b8693..7ca077a82f 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java @@ -5,6 +5,9 @@ package org.opensearch.ml.processor; import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.FULL_RESPONSE_PATH; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.FUNCTION_NAME; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.MODEL_INPUT; import java.util.ArrayList; import java.util.HashMap; @@ -15,6 +18,7 @@ import org.mockito.Mock; import org.opensearch.OpenSearchParseException; import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.Processor; import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; @@ -25,10 +29,12 @@ public class MLInferenceIngestProcessorFactoryTests extends OpenSearchTestCase { private Client client; @Mock private ScriptService scriptService; + @Mock + private NamedXContentRegistry xContentRegistry; @Before public void init() { - factory = new MLInferenceIngestProcessor.Factory(scriptService, client); + factory = new MLInferenceIngestProcessor.Factory(scriptService, client, xContentRegistry); } public void testCreateRequiredFields() throws Exception { @@ -42,6 +48,34 @@ public void testCreateRequiredFields() throws Exception { assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE); } + public void 
testCreateLocalModelProcessor() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }"); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config); + assertNotNull(mLInferenceIngestProcessor); + assertEquals(mLInferenceIngestProcessor.getTag(), processorTag); + assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE); + } + public void testCreateNoFieldPresent() throws Exception { Map config = new HashMap<>(); try { diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index d11cc213de..203392eb75 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -9,10 +9,12 @@ import static org.mockito.Mockito.*; import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME; +import java.io.IOException; import java.nio.ByteBuffer; import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -26,6 +28,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.IngestDocument; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.output.model.MLResultDataType; @@ -52,6 +55,9 @@ public class MLInferenceIngestProcessorTests extends OpenSearchTestCase { private ScriptService scriptService; @Mock private BiConsumer handler; + + @Mock + NamedXContentRegistry xContentRegistry; private static final String PROCESSOR_TAG = "inference"; private static final String DESCRIPTION = "inference_test"; private IngestDocument ingestDocument; @@ -74,30 +80,53 @@ public void setup() { } private MLInferenceIngestProcessor createMLInferenceProcessor( - String model_id, - Map model_config, - List> input_map, - List> output_map, + String modelId, + List> inputMaps, + List> outputMaps, + Map modelConfigMaps, boolean ignoreMissing, - boolean ignoreFailure + String functionName, + boolean fullResponsePath, + boolean ignoreFailure, + boolean override, + String modelInput ) { + functionName = functionName != null ? functionName : "remote"; + modelInput = modelInput != null ? 
modelInput : "{ \"parameters\": ${ml_inference.parameters} }"; + return new MLInferenceIngestProcessor( - model_id, - input_map, - output_map, - model_config, + modelId, + inputMaps, + outputMaps, + modelConfigMaps, RANDOM_MULTIPLIER, PROCESSOR_TAG, DESCRIPTION, ignoreMissing, + functionName, + fullResponsePath, ignoreFailure, + override, + modelInput, scriptService, - client + client, + xContentRegistry ); } public void testExecute_Exception() throws Exception { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); try { IngestDocument document = processor.execute(ingestDocument); } catch (UnsupportedOperationException e) { @@ -111,9 +140,20 @@ public void testExecute_Exception() throws Exception { */ public void testExecute_nestedObjectStringDocumentSuccess() { - List> inputMap = getInputMapsForNestedObjectChunks("chunks.chunk"); - - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -137,10 +177,21 @@ public void testExecute_nestedObjectStringDocumentSuccess() { * test nested object document with array of Map, * the value Object is a Map */ - public void testExecute_nestedObjectMapDocumentSuccess() { + public void testExecute_nestedObjectMapDocumentSuccess() throws IOException { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -181,9 +232,10 @@ public void testExecute_nestedObjectMapDocumentSuccess() { embedding_text.add("this is first"); embedding_text.add("this is second"); inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(inputParameters, "model1"); + .getMLModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest @@ -224,10 +276,21 @@ public void 
testExecute_jsonPathWithMissingLeaves() { * test nested object document with array of Map, * the value Object is a also a nested object, */ - public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() { + public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() throws IOException { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -254,9 +317,10 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSucce embedding_text.add("this is third"); embedding_text.add("this is fourth"); inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(inputParameters, "model1"); + .getMLModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest @@ -278,7 +342,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess( List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -311,7 +386,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingL List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -345,7 +431,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingL } public void testExecute_InferenceException() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); when(client.execute(any(), any())).thenThrow(new 
RuntimeException("Executing Model failed with exception")); try { processor.execute(ingestDocument, handler); @@ -355,7 +452,18 @@ public void testExecute_InferenceException() { } public void testExecute_InferenceOnFailure() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); RuntimeException inferenceFailure = new RuntimeException("Executing Model failed with exception"); doAnswer(invocation -> { @@ -375,7 +483,18 @@ public void testExecute_AppendFieldValueExceptionOnResponse() throws Exception { String originalOutPutFieldName = "response1"; output.put("text_embedding", originalOutPutFieldName); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); @@ -409,7 +528,18 @@ public void testExecute_whenInputFieldNotFound_ExceptionWithIgnoreMissingFalse() Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + null + ); try { processor.execute(ingestDocument, handler); @@ -429,7 +559,42 @@ public void testExecute_whenInputFieldNotFound_SuccessWithIgnoreMissingTrue() { Map output = new HashMap<>(); output.put("text_embedding", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); + + processor.execute(ingestDocument, handler); + } + + public void testExecute_localModelInputFieldNotFound_SuccessWithIgnoreMissingTrue() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + Map model_config = new HashMap<>(); + model_config.put("return_number", "true"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + true, + "text_embedding", + true, + false, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); processor.execute(ingestDocument, handler); } @@ -447,7 +612,18 @@ public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingFalse() { Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + 
null + ); try { processor.execute(ingestDocument, handler); @@ -469,7 +645,18 @@ public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingTrue() { Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + true, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); @@ -491,7 +678,18 @@ public void testExecute_IOExceptionWithIgnoreMissingFalse() throws JsonProcessin ObjectMapper mapper = mock(ObjectMapper.class); when(mapper.readValue(Mockito.anyString(), eq(Object.class))).thenThrow(JsonProcessingException.class); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + null + ); try { processor.execute(ingestDocument, handler); @@ -501,8 +699,56 @@ public void testExecute_IOExceptionWithIgnoreMissingFalse() throws JsonProcessin } public void testExecute_NoModelInput_Exception() { - MLInferenceIngestProcessor processorIgnoreMissingTrue = createMLInferenceProcessor("model1", null, null, null, true, false); - MLInferenceIngestProcessor processorIgnoreMissingFalse = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processorIgnoreMissingTrue = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); + MLInferenceIngestProcessor processorIgnoreMissingFalse = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); + + MLInferenceIngestProcessor localModelProcessorIgnoreMissingFalse = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "text_embedding", + false, + false, + false, + null + ); + + MLInferenceIngestProcessor localModelProcessorIgnoreMissingTrue = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "text_embedding", + false, + false, + false, + null + ); Map sourceAndMetadata = new HashMap<>(); IngestDocument emptyIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -517,10 +763,32 @@ public void testExecute_NoModelInput_Exception() { assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); } + try { + localModelProcessorIgnoreMissingTrue.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); + } + try { + localModelProcessorIgnoreMissingFalse.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. 
The model input cannot be empty.", e.getMessage()); + } + } public void testExecute_AppendModelOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -543,7 +811,18 @@ public void testExecute_AppendModelOutputSuccess() { } public void testExecute_SingleTensorInDataOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); Float[] value = new Float[] { 1.0f, 2.0f, 3.0f }; List outputs = new ArrayList<>(); @@ -578,7 +857,18 @@ public void testExecute_SingleTensorInDataOutputSuccess() { } public void testExecute_MultipleTensorInDataOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); List outputs = new ArrayList<>(); Float[] value = new Float[] { 1.0f }; @@ -640,7 +930,18 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() { output.put("classification", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -671,7 +972,18 @@ public void testExecute_getModelOutputFieldWithDotPathSuccess() { output.put("language_identification", "response.language"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", List.of("en", "en"), "score", "0.9876"))) @@ -703,7 +1015,18 @@ public void testExecute_getModelOutputFieldWithInvalidDotPathSuccess() { output.put("language_identification", "response.lan"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -733,7 +1056,18 @@ public void 
testExecute_getModelOutputFieldWithInvalidDotPathException() { output.put("response.lan", "language_identification"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -768,7 +1102,18 @@ public void testExecute_getModelOutputFieldInNestedWithInvalidDotPathException() output.put("chunks.*.chunk.text.*.context_embedding", "response.language1"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -803,7 +1148,18 @@ public void testExecute_getModelOutputFieldWithExistedFieldNameException() { output.put("key1", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + true, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -818,17 +1174,13 @@ public void testExecute_getModelOutputFieldWithExistedFieldNameException() { }).when(client).execute(any(), any(), any()); processor.execute(ingestDocument, handler); - verify(handler) - .accept( - eq(null), - argThat( - exception -> exception - .getMessage() - .equals( - "document already has field name key1. Not allow to overwrite the same field name, please check output_map." 
- ) - ) - ); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", ImmutableMap.of("language", "en", "score", "0.9876")); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); } public void testExecute_documentNotExistedFieldNameException() { @@ -842,7 +1194,18 @@ public void testExecute_documentNotExistedFieldNameException() { output.put("classification", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); verify(handler) @@ -852,7 +1215,18 @@ public void testExecute_documentNotExistedFieldNameException() { public void testExecute_nestedDocumentNotExistedFieldNameException() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context1"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + false, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); verify(handler) @@ -871,7 +1245,18 @@ public void testExecute_getModelOutputFieldDifferentLengthException() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -907,7 +1292,18 @@ public void testExecute_getModelOutputFieldDifferentLengthIgnoreFailureSuccess() List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -937,7 +1333,18 @@ public void testExecute_getMlModelTensorsIsNull() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); doAnswer(invocation -> { @@ -959,12 +1366,74 @@ public void testExecute_getMlModelTensorsIsNull() { } + public void testExecute_localMLModelTensorsIsNull() { + List> inputMap = new ArrayList<>(); + Map 
input = new HashMap<>(); + input.put("text_docs", "chunks.*.chunk.text.*.context"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("chunks.*.chunk.text.*.context_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "text_embedding", + true, + false, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + + verify(handler) + .accept( + eq(null), + argThat( + exception -> exception + .getMessage() + .equals( + "An unexpected error occurred: model inference output " + + "cannot find such json path: $.inference_results[0].output[0].data" + ) + ) + ); + + } + public void testExecute_getMlModelTensorsIsNullIgnoreFailure() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); doAnswer(invocation -> { @@ -985,7 +1454,18 @@ public void testExecute_modelTensorOutputIsNull() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -997,7 +1477,11 @@ public void testExecute_modelTensorOutputIsNull() { IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); processor.execute(nestedObjectIngestDocument, handler); - verify(handler).accept(eq(null), argThat(exception -> exception.getMessage().equals("model inference output cannot be null"))); + verify(handler) + .accept( + eq(null), + argThat(exception -> exception.getMessage().equals("An unexpected error occurred: Model outputs are null or empty.")) + ); } @@ -1006,7 +1490,18 @@ public void testExecute_modelTensorOutputIsNullIgnoreFailureSuccess() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", 
null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -1021,6 +1516,238 @@ public void testExecute_modelTensorOutputIsNullIgnoreFailureSuccess() { verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); } + /** + * Test processor configuration with nested object document + * and array of Map, where the value Object is a List + */ + public void testExecute_localModelSuccess() { + + // Processor configuration + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "_ingest._value.title"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("_ingest._value.title_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model_1", + inputMap, + outputMap, + null, + true, + "text_embedding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + + // Mocking the model output + List modelPredictionOutput = Arrays.asList(1, 2, 3, 4); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "inference_results", + Arrays + .asList( + ImmutableMap + .of( + "output", + Arrays.asList(ImmutableMap.of("name", "sentence_embedding", "data", modelPredictionOutput)) + ) + ) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + // Setting up the ingest document + Map sourceAndMetadata = new HashMap<>(); + List> books = new ArrayList<>(); + Map book1 = new HashMap<>(); + book1.put("title", Arrays.asList("first book")); + book1.put("description", "This is first book"); + Map book2 = new HashMap<>(); + book2.put("title", Arrays.asList("second book")); + book2.put("description", "This is second book"); + books.add(book1); + books.add(book2); + sourceAndMetadata.put("books", books); + + Map ingestMetadata = new HashMap<>(); + ingestMetadata.put("pipeline", "test_pipeline"); + ingestMetadata.put("timestamp", ZonedDateTime.now()); + Map ingestValue = new HashMap<>(); + ingestValue.put("title", Arrays.asList("first book")); + ingestValue.put("description", "This is first book"); + ingestMetadata.put("_value", ingestValue); + sourceAndMetadata.put("_ingest", ingestMetadata); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // Validate the document + List> updatedBooks = new ArrayList<>(); + Map updatedBook1 = new HashMap<>(); + updatedBook1.put("title", Arrays.asList("first book")); + updatedBook1.put("description", "This is first book"); + updatedBook1.put("title_embedding", modelPredictionOutput); + Map updatedBook2 = new HashMap<>(); + updatedBook2.put("title", 
Arrays.asList("second book")); + updatedBook2.put("description", "This is second book"); + updatedBook2.put("title_embedding", modelPredictionOutput); + updatedBooks.add(updatedBook1); + updatedBooks.add(updatedBook2); + sourceAndMetadata.put("books", updatedBooks); + + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + } + + public void testExecute_localSparseEncodingModelMultipleModelTensors() { + + // Processor configuration + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "chunks.*.chunk.text.*.context"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("chunks.*.chunk.text.*.context_embedding", "$.inference_results.*.output.*.dataAsMap.response"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model_1", + inputMap, + outputMap, + null, + true, + "sparse_encoding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + + // Mocking the model output with simple values + List> modelEmbeddings = new ArrayList<>(); + Map embedding = ImmutableMap.of("response", Arrays.asList(1.0, 2.0, 3.0, 4.0)); + for (int i = 1; i <= 4; i++) { + modelEmbeddings.add(embedding); + } + + List modelTensors = new ArrayList<>(); + for (Map embeddings : modelEmbeddings) { + modelTensors.add(ModelTensor.builder().dataAsMap(embeddings).build()); + } + + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(Collections.singletonList(ModelTensors.builder().mlModelTensors(modelTensors).build())) + .build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + IngestDocument ingestDocument = new IngestDocument(getNestedObjectWithAnotherNestedObjectSource(), new HashMap<>()); + processor.execute(ingestDocument, handler); + verify(handler).accept(eq(ingestDocument), isNull()); + + List> chunks = (List>) ingestDocument.getFieldValue("chunks", List.class); + + List> firstChunkTexts = (List>) ((Map) chunks.get(0).get("chunk")) + .get("text"); + Assert.assertEquals(modelEmbeddings.get(0).get("response"), firstChunkTexts.get(0).get("context_embedding")); + Assert.assertEquals(modelEmbeddings.get(1).get("response"), firstChunkTexts.get(1).get("context_embedding")); + + List> secondChunkTexts = (List>) ((Map) chunks.get(1).get("chunk")) + .get("text"); + Assert.assertEquals(modelEmbeddings.get(2).get("response"), secondChunkTexts.get(0).get("context_embedding")); + Assert.assertEquals(modelEmbeddings.get(3).get("response"), secondChunkTexts.get(1).get("context_embedding")); + + } + + public void testExecute_localModelOutputIsNullIgnoreFailureSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "text_embedding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); + doAnswer(invocation -> { + ActionListener 
actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); + } + + public void testExecute_localModelTensorsIsNullIgnoreFailure() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); + } + public void testParseGetDataInTensor_IntegerDataType() { ModelTensor mockTensor = mock(ModelTensor.class); when(mockTensor.getDataType()).thenReturn(MLResultDataType.INT8); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java index 1937b8c496..f8d623fc74 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java @@ -5,6 +5,8 @@ package org.opensearch.ml.rest; +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestHelper.makeRequest; import java.io.IOException; @@ -17,6 +19,12 @@ import org.junit.Before; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.utils.TestHelper; import com.google.common.collect.ImmutableList; @@ -26,6 +34,8 @@ public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase { private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); private String openAIChatModelId; private String bedrockEmbeddingModelId; + + private String localModelId; private final String completionModelConnectorEntity = "{\n" + " \"name\": \"OpenAI text embedding model Connector\",\n" + " \"description\": \"The connector to public OpenAI text embedding 
model service\",\n" @@ -350,6 +360,192 @@ public void testMLInferenceProcessorWithForEachProcessor() throws Exception { Assert.assertEquals(1536, embedding2.size()); } + public void testMLInferenceProcessorLocalModelObjectField() throws Exception { + + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String createPipelineRequestBody = "{\n" + + " \"description\": \"test ml model ingest processor\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"model_input\": \"{ \\\"text_docs\\\": ${ml_inference.text_docs} }\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"diary\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"diary_embedding\": \"$.inference_results.*.output.*.data\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"diary_embedding_pipeline\"\n" + + " }\n" + + " }\n" + + " }"; + String uploadDocumentRequestBody = "{\n" + + " \"id\": 1,\n" + + " \"diary\": [\"happy\",\"first day at school\"],\n" + + " \"weather\": \"rainy\"\n" + + " }"; + String index_name = "daily_index"; + createPipelineProcessor(createPipelineRequestBody, "diary_embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); + + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); + List embeddingList = JsonPath.parse(document).read("_source.diary_embedding"); + Assert.assertEquals(2, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.diary_embedding[0]"); + Assert.assertEquals(768, embedding1.size()); + Assert.assertEquals(0.42101282, (Double) embedding1.get(0), 0.005); + + List embedding2 = JsonPath.parse(document).read("_source.diary_embedding[1]"); + Assert.assertEquals(768, embedding2.size()); + Assert.assertEquals(0.49191704, (Double) embedding2.get(0), 0.005); + } + + // TODO: add tests for other local model types such as sparse/cross encoders + public void testMLInferenceProcessorLocalModelNestedField() throws Exception { + + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String createPipelineRequestBody = "{\n" + + " 
\"description\": \"ingest reviews and generate embedding\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"model_input\": \"{ \\\"text_docs\\\": ${ml_inference.text_docs} }\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"book.*.chunk.text.*.context\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"book.*.chunk.text.*.context_embedding\": \"$.inference_results.*.output.*.data\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": true,\n" + + " \"ignore_failure\": true\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"embedding_pipeline\"\n" + + " }\n" + + " }\n" + + " }"; + String uploadDocumentRequestBody = "{\n" + + " \"book\": [\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " \"context\": \"this is the first part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " \"context\": \"this is the second part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the third part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the fourth part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String index_name = "book_index"; + createPipelineProcessor(createPipelineRequestBody, "embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); + + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); + + List embeddingList = JsonPath.parse(document).read("_source.book[*].chunk.text[*].context_embedding"); + Assert.assertEquals(4, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.book[0].chunk.text[0].context_embedding"); + Assert.assertEquals(768, embedding1.size()); + Assert.assertEquals(0.48988956, (Double) embedding1.get(0), 0.005); + + List embedding2 = JsonPath.parse(document).read("_source.book[0].chunk.text[1].context_embedding"); + Assert.assertEquals(768, embedding2.size()); + Assert.assertEquals(0.49552172, (Double) embedding2.get(0), 0.005); + + List embedding3 = JsonPath.parse(document).read("_source.book[1].chunk.text[0].context_embedding"); + Assert.assertEquals(768, embedding3.size()); + Assert.assertEquals(0.5004309, (Double) embedding3.get(0), 0.005); + + List embedding4 = JsonPath.parse(document).read("_source.book[1].chunk.text[1].context_embedding"); + Assert.assertEquals(768, embedding4.size()); + Assert.assertEquals(0.47907734, (Double) embedding4.get(0), 0.005); + } + protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception { Response pipelineCreateResponse = TestHelper .makeRequest( @@ -378,7 +574,6 @@ protected void createIndex(String indexName, String requestBody) throws Exceptio protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException { Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true"); - request.setJsonEntity(jsonBody); client().performRequest(request); } @@ -390,4 +585,25 @@ protected Map getDocument(final String index, final 
String docId) throws Excepti return parseResponseToMap(docResponse); } + protected MLRegisterModelInput registerModelInput() throws IOException, InterruptedException { + + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(768) + .build(); + return MLRegisterModelInput + .builder() + .modelName("test_model_name") + .version("1.0.0") + .functionName(FunctionName.TEXT_EMBEDDING) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .url(SENTENCE_TRANSFORMER_MODEL_URL) + .deployModel(false) + .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021") + .build(); + } + } From b57fc59d20b5e27f35ba52a98d4a036b433de351 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jun 2024 13:16:22 -0700 Subject: [PATCH 09/12] fix flaky IT (#2530) Signed-off-by: Yaliang Wu --- .../test/java/org/opensearch/ml/rest/RestConnectorToolIT.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java index 4ae9653d60..76a1c20e61 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.disableClusterConnectorAccessControl; import java.io.IOException; @@ -27,6 +28,7 @@ public class RestConnectorToolIT extends RestBaseAgentToolsIT { @Before public void setUp() throws Exception { super.setUp(); + disableClusterConnectorAccessControl(); Thread.sleep(20000); this.bedrockClaudeConnectorId = createBedrockClaudeConnector("execute"); this.bedrockClaudeConnectorIdForPredict = createBedrockClaudeConnector("predict"); From 9fa49f42f0f43ccfcf78cbc6b110f6f7cd6b684f Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Wed, 12 Jun 2024 04:30:06 +0800 Subject: [PATCH 10/12] Add claud3 blueprint (#2464) * Add claud3 blueprint Signed-off-by: Hailong Cui * Apply suggestions from code review Co-authored-by: Yaliang Wu Signed-off-by: Hailong Cui --------- Signed-off-by: Hailong Cui Co-authored-by: Yaliang Wu --- ...k_connector_anthropic_claude3_blueprint.md | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md diff --git a/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md new file mode 100644 index 0000000000..bad60a89dd --- /dev/null +++ b/docs/remote_inference_blueprints/bedrock_connector_anthropic_claude3_blueprint.md @@ -0,0 +1,167 @@ +# Bedrock connector blueprint example for Claude V3 model + +## 1. Add connector endpoint to trusted URLs: + +Note: no need to do this after 2.11.0 + +```json +PUT /_cluster/settings +{ + "persistent": { + "plugins.ml_commons.trusted_connector_endpoints_regex": [ + "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" + ] + } +} +``` + +## 2. 
Create connector for Amazon Bedrock: + +If you are using self-managed Opensearch, you should supply AWS credentials: + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock claude v3", + "description": "Test connector for Amazon Bedrock claude v3", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "auth": "Sig_V4", + "response_filter": "$.content[0].text", + "max_tokens_to_sample": "8000", + "anthropic_version": "bedrock-2023-05-31", + "model": "anthropic.claude-3-sonnet-20240229-v1:0" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "request_body": "{\"messages\":[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"${parameters.inputs}\"}]}],\"anthropic_version\":\"${parameters.anthropic_version}\",\"max_tokens\":${parameters.max_tokens_to_sample}}" + } + ] +} +``` + +If using the AWS Opensearch Service, you can provide an IAM role arn that allows access to the bedrock service. +Refer to this [AWS doc](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/ml-amazon-connector.html) + +```json +POST /_plugins/_ml/connectors/_create +{ + "name": "Amazon Bedrock", + "description": "Test connector for Amazon Bedrock", + "version": 1, + "protocol": "aws_sigv4", + "credential": { + "roleArn": "" + }, + "parameters": { + "region": "", + "service_name": "bedrock", + "auth": "Sig_V4", + "response_filter": "$.content[0].text", + "max_tokens_to_sample": "8000", + "anthropic_version": "bedrock-2023-05-31" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "headers": { + "content-type": "application/json" + }, + "url": "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke", + "request_body": "{\"messages\":[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"${parameters.prompt}\"}]}],\"anthropic_version\":\"${parameters.anthropic_version}\",\"max_tokens\":${parameters.max_tokens_to_sample}}" + } + ] +} +``` + +Sample response: +```json +{ + "connector_id": "nMopmY8B8aiZvtEZLu9B" +} +``` + +## 3. Create model group: + +```json +POST /_plugins/_ml/model_groups/_register +{ + "name": "remote_model_group_claude3", + "description": "This is an example description" +} +``` + +Sample response: +```json +{ + "model_group_id": "IMobmY8B8aiZvtEZeO_i", + "status": "CREATED" +} +``` + +## 4. Register model to model group & deploy model: + +```json +POST /_plugins/_ml/models/_register?deploy=true +{ + "name": "anthropic.claude-v3", + "function_name": "remote", + "model_group_id": "IMobmY8B8aiZvtEZeO_i", + "description": "claude v3 model", + "connector_id": "nMopmY8B8aiZvtEZLu9B" +} +``` + +Sample response: +```json +{ + "task_id": "rMormY8B8aiZvtEZIO_j", + "status": "CREATED", + "model_id": "rcormY8B8aiZvtEZIe89" +} +``` + +## 5. Test model inference + +```json +POST /_plugins/_ml/models/rcormY8B8aiZvtEZIe89/_predict +{ + "parameters": { + "inputs": "What is the meaning of life?" + } +} +``` + +Sample response: +```json +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "response": "There is no single, universally accepted answer to the meaning of life. 
It's a question that has been pondered by philosophers, theologians, and thinkers across cultures for centuries. Here are some of the major perspectives on deriving meaning in life:\n\n- Religious/spiritual views - Many religions provide a framework for finding meaning through connection to the divine, fulfilling religious teachings/duties, and an afterlife.\n\n- Existentialist philosophy - Thinkers like Sartre and Camus emphasized that we each have the freedom and responsibility to create our own subjective meaning in an objectively meaningless universe.\n\n- Hedonism - The view that the pursuit of pleasure and avoiding suffering is the highest good and most meaningful way to live.\n\n- Virtue ethics - Finding meaning through living an ethical life based on virtues like courage, temperance, justice, and wisdom.\n\n- Humanistic psychology - Psychologists like Maslow and Frankl emphasized fulfillment from reaching one's full human potential and finding a sense of purpose.\n\n- Naturalism/Nihilism - Some believe life itself has no inherent meaning beyond the physical/natural world we empirically experience.\n\nUltimately, the \"meaning of life\" is an existential question that challenges each individual to decide what makes their own life feel meaningful, based on their own worldview, beliefs, and values. There is no objectively \"correct\" universal answer." + } + } + ], + "status_code": 200 + } + ] +} +``` From 62e238f9de1c29634b8cfea752de86aa17168e21 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 12 Jun 2024 04:58:58 +0800 Subject: [PATCH 11/12] Fix bedrock embedding generation issue (#2495) * Fix bedrock connector embedding generation issue Signed-off-by: zane-neo * format code Signed-off-by: zane-neo * add IT Signed-off-by: zane-neo * add ITs Signed-off-by: zane-neo * format code Signed-off-by: zane-neo * change input to fix number format exception in local Signed-off-by: zane-neo * Add log to identify the failure IT root cause Signed-off-by: zane-neo * format code Signed-off-by: zane-neo * remove debug log Signed-off-by: zane-neo * Update plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java Co-authored-by: Yaliang Wu Signed-off-by: zane-neo * address comments Signed-off-by: zane-neo --------- Signed-off-by: zane-neo Co-authored-by: Yaliang Wu --- .../remote/RemoteConnectorExecutor.java | 12 ++- .../remote/AwsConnectorExecutorTest.java | 79 ++++++++++++++++ .../ml/rest/MLCommonsRestTestCase.java | 7 ++ .../ml/rest/RestBedRockInferenceIT.java | 91 +++++++++++++++++++ .../templates/BedRockConnectorBodies.json | 63 +++++++++++++ 5 files changed, 248 insertions(+), 4 deletions(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index e786122cbe..22c866873b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -93,7 +93,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) { @@ -117,11 +117,15 @@ private Tuple calculateChunkSize(String action, 
TextDocsInputD throw new IllegalArgumentException("no " + action + " action found");
         }
         String preProcessFunction = connectorAction.get().getPreProcessFunction();
-        if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
-            // user defined preprocess script, this case, the chunk size is always equals to text docs length.
+        if (preProcessFunction == null) {
+            // default preprocess case, consider this a batch.
+            return Tuple.tuple(1, textDocsLength);
+        } else if (MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
+            || !MLPreProcessFunction.contains(preProcessFunction)) {
+            // bedrock or user defined preprocess script: the chunk size always equals the text docs length.
             return Tuple.tuple(textDocsLength, 1);
         }
-        // consider as batch.
+        // Other cases: built-in preprocess functions other than bedrock, consider as batch.
         return Tuple.tuple(1, textDocsLength);
     }
 }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
index cb192e83f9..98d5feb7ba 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
@@ -612,6 +612,85 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre
         );
     }
 
+    @Test
+    public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddingPreProcessFunction() {
+        ConnectorAction predictAction = ConnectorAction
+            .builder()
+            .actionType(PREDICT)
+            .method("POST")
+            .url("http://openai.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT)
+            .build();
+        Map credential = ImmutableMap
+            .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
+        Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock");
+        Connector connector = AwsConnector
+            .awsConnectorBuilder()
+            .name("test connector")
+            .version("1")
+            .protocol("aws_sigv4")
+            .parameters(parameters)
+            .credential(credential)
+            .actions(Arrays.asList(predictAction))
+            .build();
+        connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
+        Settings settings = Settings.builder().build();
+        threadContext = new ThreadContext(settings);
+        when(executor.getClient()).thenReturn(client);
+        when(client.threadPool()).thenReturn(threadPool);
+        when(threadPool.getThreadContext()).thenReturn(threadContext);
+        when(executor.getScriptService()).thenReturn(scriptService);
+
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
+        executor
+            .executeAction(
+                PREDICT.name(),
+                MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(),
+                actionListener
+            );
+    }
+
+    @Test
+    public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreprocessFunction() {
+        ConnectorAction predictAction = ConnectorAction
+            .builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://openai.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .build();
+        Map credential = ImmutableMap
+            .of(ACCESS_KEY_FIELD, 
encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("aws_sigv4") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(executor.getScriptService()).thenReturn(scriptService); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); + } + @Test public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { ConnectorAction predictAction = ConnectorAction diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index cf1f87e09e..886494de3c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -911,6 +911,13 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } + public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException { + String requestBody = TestHelper.toJsonString(input); + Response response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + return parseResponseToMap(response); + } + public Consumer> verifyTextEmbeddingModelDeployed() { return (modelProfile) -> { if (modelProfile.containsKey("model_state")) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java new file mode 100644 index 0000000000..fea981afe7 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.junit.Before; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.SneakyThrows; + +public class RestBedRockInferenceIT extends MLCommonsRestTestCase { + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + @SneakyThrows + @Before + public void 
setup() throws IOException, InterruptedException { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + Thread.sleep(20000); + } + + public void test_bedrock_embedding_model() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 2, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, output.get(1) instanceof Map); + validateOutput(errorMsg, (Map) output.get(0)); + validateOutput(errorMsg, (Map) output.get(1)); + } + } + + private void validateOutput(String errorMsg, Map output) { + assertTrue(errorMsg, output.containsKey("output")); + assertTrue(errorMsg, output.get("output") instanceof List); + List outputList = (List) output.get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + } +} diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json new file mode 100644 index 0000000000..5b75b5ab72 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json @@ -0,0 +1,63 @@ +{ + "without_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": 
"connector.post_process.bedrock.embedding" + } + ] + }, + "with_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1", + "input_docs_processed_step_size": "1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +} From 06d17424d65de18b189b7c259325128b09f3a9ec Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jun 2024 15:45:50 -0700 Subject: [PATCH 12/12] add setting to allow private IP (#2534) * add setting to allow private IP Signed-off-by: Yaliang Wu * fix ut Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../remote/HttpJsonConnectorExecutor.java | 5 +- .../remote/RemoteConnectorExecutor.java | 3 ++ .../engine/algorithms/remote/RemoteModel.java | 3 ++ .../httpclient/MLHttpClientFactory.java | 13 +++-- .../remote/HttpJsonConnectorExecutorTest.java | 49 +++++++++++++++++++ .../httpclient/MLHttpClientFactoryTests.java | 37 +++++++++----- .../opensearch/ml/model/MLModelManager.java | 8 ++- .../ml/plugin/MachineLearningPlugin.java | 15 +++--- .../ml/settings/MLCommonsSettings.java | 3 ++ .../ml/settings/MLFeatureEnabledSetting.java | 12 +++++ .../ml/model/MLModelManagerTests.java | 9 +++- 11 files changed, 132 insertions(+), 25 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index ee29f67a43..5ac0245701 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -16,6 +16,7 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; @@ -62,6 +63,8 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { @Setter @Getter private MLGuard mlGuard; + @Setter + private volatile AtomicBoolean connectorPrivateIpEnabled; private SdkAsyncHttpClient httpClient; @@ -136,6 +139,6 @@ private void validateHttpClientParameters(String action, Map par String protocol = url.getProtocol(); String host = url.getHost(); int port = url.getPort(); - MLHttpClientFactory.validate(protocol, host, port); + MLHttpClientFactory.validate(protocol, host, port, connectorPrivateIpEnabled); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 22c866873b..28b7617103 100644 --- 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -15,6 +15,7 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; @@ -150,6 +151,8 @@ default void setScriptService(ScriptService scriptService) {} default void setClient(Client client) {} + default void setConnectorPrivateIpEnabled(AtomicBoolean connectorPrivateIpEnabled) {} + default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {} default void setClusterService(ClusterService clusterService) {} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index c8685c010e..0f208adb7d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -43,6 +44,7 @@ public class RemoteModel implements Predictable { public static final String RATE_LIMITER = "rate_limiter"; public static final String USER_RATE_LIMITER_MAP = "user_rate_limiter_map"; public static final String GUARDRAILS = "guardrails"; + public static final String CONNECTOR_PRIVATE_IP_ENABLED = "connectorPrivateIpEnabled"; private RemoteConnectorExecutor connectorExecutor; @@ -101,6 +103,7 @@ public void initModel(MLModel model, Map params, Encryptor encry this.connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER)); this.connectorExecutor.setUserRateLimiterMap((Map) params.get(USER_RATE_LIMITER_MAP)); this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS)); + this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED)); } catch (RuntimeException e) { log.error("Failed to init remote model.", e); throw e; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index 339523b313..ffc95c30de 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -14,6 +14,7 @@ import java.time.Duration; import java.util.Arrays; import java.util.Locale; +import java.util.concurrent.atomic.AtomicBoolean; import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @@ -43,9 +44,11 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, * @param protocol The protocol supported in remote inference, currently only http and https are supported. * @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost. * @param port The port number of the remote inference server, port number must be in range [0, 65536]. 
-     * @throws UnknownHostException
+     * @param connectorPrivateIpEnabled Whether connecting to a private IP address is allowed for the remote inference host.
+     * @throws UnknownHostException If the host name cannot be resolved to an IP address.
      */
-    public static void validate(String protocol, String host, int port) throws UnknownHostException {
+    public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
+        throws UnknownHostException {
         if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
             log.error("Remote inference protocol is not http or https: " + protocol);
             throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
@@ -62,12 +65,12 @@ public static void validate(String protocol, String host, int port) throws Unkno
             log.error("Remote inference port out of range: " + port);
             throw new IllegalArgumentException("Port out of range: " + port);
         }
-        validateIp(host);
+        validateIp(host, connectorPrivateIpEnabled);
     }
 
-    private static void validateIp(String hostName) throws UnknownHostException {
+    private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
         InetAddress[] addresses = InetAddress.getAllByName(hostName);
-        if (hasPrivateIpAddress(addresses)) {
+        if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
             log.error("Remote inference host name has private ip address: " + hostName);
             throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
         }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
index 8f920ffeba..8f27be79a6 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java
@@ -6,6 +6,8 @@
 package org.opensearch.ml.engine.algorithms.remote;
 
 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
@@ -13,6 +15,7 @@
 import java.lang.reflect.Field;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import org.junit.Before;
 import org.junit.Rule;
@@ -102,6 +105,52 @@
         assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage());
     }
 
+    @Test
+    public void invokeRemoteService_EnabledPrivateIpAddress() {
+        ConnectorAction predictAction = ConnectorAction
+            .builder()
+            .actionType(PREDICT)
+            .method("POST")
+            .url("http://127.0.0.1/mock")
+            .requestBody("{\"input\": \"${parameters.input}\"}")
+            .build();
+        Connector connector = HttpConnector
+            .builder()
+            .name("test connector")
+            .version("1")
+            .protocol("http")
+            .actions(Arrays.asList(predictAction))
+            .build();
+        HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
+        AtomicBoolean privateIpEnabled = new AtomicBoolean(true);
+        executor.setConnectorPrivateIpEnabled(privateIpEnabled);
+        executor
+            .invokeRemoteService(
PREDICT.name(), + createMLInput(), + new HashMap<>(), + "{\"input\": \"hello world\"}", + new ExecutionContext(0), + actionListener + ); + Mockito.verify(actionListener, never()).onFailure(any()); + + privateIpEnabled.set(false); + executor + .invokeRemoteService( + PREDICT.name(), + createMLInput(), + new HashMap<>(), + "{\"input\": \"hello world\"}", + new ExecutionContext(0), + actionListener + ); + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); + Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue() instanceof IllegalArgumentException; + assertEquals("Remote inference host name has private ip address: 127.0.0.1", captor.getValue().getMessage()); + } + @Test public void invokeRemoteService_Empty_payload() { ConnectorAction predictAction = ConnectorAction diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index d1d9b42dcc..1d79ac995e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertNotNull; import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Rule; import org.junit.Test; @@ -28,70 +29,84 @@ public void test_getSdkAsyncHttpClient_success() { @Test public void test_validateIp_validIp_noException() throws Exception { - MLHttpClientFactory.validate("http", "api.openai.com", 80); + AtomicBoolean privateIpEnabled = new AtomicBoolean(false); + MLHttpClientFactory.validate("http", "api.openai.com", 80, privateIpEnabled); } @Test public void test_validateIp_rarePrivateIp_throwException() throws Exception { + AtomicBoolean privateIpEnabled = new AtomicBoolean(false); try { - MLHttpClientFactory.validate("http", "0254.020.00.01", 80); + MLHttpClientFactory.validate("http", "0254.020.00.01", 80, privateIpEnabled); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validate("http", "172.1048577", 80); + MLHttpClientFactory.validate("http", "172.1048577", 80, privateIpEnabled); } catch (Exception e) { assertNotNull(e); } try { - MLHttpClientFactory.validate("http", "2886729729", 80); + MLHttpClientFactory.validate("http", "2886729729", 80, privateIpEnabled); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validate("http", "192.11010049", 80); + MLHttpClientFactory.validate("http", "192.11010049", 80, privateIpEnabled); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validate("http", "3232300545", 80); + MLHttpClientFactory.validate("http", "3232300545", 80, privateIpEnabled); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80); + MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80, privateIpEnabled); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - MLHttpClientFactory.validate("http", "153.24.76.232", 80); + MLHttpClientFactory.validate("http", "153.24.76.232", 80, privateIpEnabled); } catch (IllegalArgumentException e) { assertNotNull(e); } } + @Test + public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception { + AtomicBoolean privateIpEnabled = new AtomicBoolean(true); + 
MLHttpClientFactory.validate("http", "0254.020.00.01", 80, privateIpEnabled); + MLHttpClientFactory.validate("http", "172.1048577", 80, privateIpEnabled); + MLHttpClientFactory.validate("http", "2886729729", 80, privateIpEnabled); + MLHttpClientFactory.validate("http", "192.11010049", 80, privateIpEnabled); + MLHttpClientFactory.validate("http", "3232300545", 80, privateIpEnabled); + MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80, privateIpEnabled); + MLHttpClientFactory.validate("http", "153.24.76.232", 80, privateIpEnabled); + } + @Test public void test_validateSchemaAndPort_success() throws Exception { - MLHttpClientFactory.validate("http", "api.openai.com", 80); + MLHttpClientFactory.validate("http", "api.openai.com", 80, new AtomicBoolean(false)); } @Test public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); - MLHttpClientFactory.validate("ftp", "api.openai.com", 80); + MLHttpClientFactory.validate("ftp", "api.openai.com", 80, new AtomicBoolean(false)); } @Test public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("Port out of range: 65537"); - MLHttpClientFactory.validate("https", "api.openai.com", 65537); + MLHttpClientFactory.validate("https", "api.openai.com", 65537, new AtomicBoolean(false)); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 6c608c95cd..78d36a9975 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -27,6 +27,7 @@ import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CONNECTOR_PRIVATE_IP_ENABLED; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.GUARDRAILS; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.RATE_LIMITER; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE; @@ -128,6 +129,7 @@ import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.engine.utils.FileUtils; import org.opensearch.ml.profile.MLModelProfile; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; @@ -169,6 +171,7 @@ public class MLModelManager { private final MLTaskManager mlTaskManager; private final MLEngine mlEngine; private final DiscoveryNodeHelper nodeHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; private volatile Integer maxModelPerNode; private volatile Integer maxRegisterTasksPerNode; @@ -198,7 +201,8 @@ public MLModelManager( MLTaskManager mlTaskManager, MLModelCacheHelper modelCacheHelper, MLEngine mlEngine, - DiscoveryNodeHelper nodeHelper + DiscoveryNodeHelper nodeHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.client = client; this.threadPool = threadPool; @@ -213,6 +217,7 @@ public MLModelManager( this.mlTaskManager = mlTaskManager; this.mlEngine = mlEngine; this.nodeHelper = nodeHelper; + 
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; this.maxModelPerNode = ML_COMMONS_MAX_MODELS_PER_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MAX_MODELS_PER_NODE, it -> maxModelPerNode = it); @@ -1174,6 +1179,7 @@ private Map setUpParameterMap(String modelId) { params.put(GUARDRAILS, mlGuard); log.info("Setting up ML guard parameter for ML predictor."); } + params.put(CONNECTOR_PRIVATE_IP_ENABLED, mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); return Collections.unmodifiableMap(params); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 6d808c64bb..4ec1cc3f85 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -495,6 +495,11 @@ public Collection createComponents( mlIndicesHandler = new MLIndicesHandler(clusterService, client); mlTaskManager = new MLTaskManager(client, threadPool, mlIndicesHandler); modelHelper = new ModelHelper(mlEngine); + + mlInputDatasetHandler = new MLInputDatasetHandler(client); + modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings); + connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings); mlModelManager = new MLModelManager( clusterService, scriptService, @@ -509,12 +514,9 @@ public Collection createComponents( mlTaskManager, modelCacheHelper, mlEngine, - nodeHelper + nodeHelper, + mlFeatureEnabledSetting ); - mlInputDatasetHandler = new MLInputDatasetHandler(client); - modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings); - connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); - mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings); mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper); @@ -928,7 +930,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED, - MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE + MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, + MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index ba0c13e614..84bdef95fe 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -184,4 +184,7 @@ private MLCommonsSettings() {} // This setting is to enable/disable agent related API register/execute/delete/get/search agent. 
public static final Setting ML_COMMONS_AGENT_FRAMEWORK_ENABLED = Setting .boolSetting("plugins.ml_commons.agent_framework_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED = Setting + .boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index f636f33722..e393b97d24 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -8,9 +8,12 @@ package org.opensearch.ml.settings; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; +import java.util.concurrent.atomic.AtomicBoolean; + import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -20,11 +23,13 @@ public class MLFeatureEnabledSetting { private volatile Boolean isAgentFrameworkEnabled; private volatile Boolean isLocalModelEnabled; + private volatile AtomicBoolean isConnectorPrivateIpEnabled; public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings); isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings); + isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings)); clusterService .getClusterSettings() @@ -33,6 +38,9 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it)); } /** @@ -59,4 +67,8 @@ public boolean isLocalModelEnabled() { return isLocalModelEnabled; } + public AtomicBoolean isConnectorPrivateIpEnabled() { + return isConnectorPrivateIpEnabled; + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 189ac01876..01aee3fa04 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -61,6 +61,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; import org.junit.Before; @@ -109,6 +110,7 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import 
org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; @@ -179,6 +181,8 @@ public class MLModelManagerTests extends OpenSearchTestCase { @Mock ClusterApplierService clusterApplierService; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; @Before public void setup() throws URISyntaxException { @@ -253,6 +257,8 @@ public void setup() throws URISyntaxException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()).thenReturn(new AtomicBoolean(false)); + modelManager = spy( new MLModelManager( clusterService, @@ -268,7 +274,8 @@ public void setup() throws URISyntaxException { mlTaskManager, modelCacheHelper, mlEngine, - nodeHelper + nodeHelper, + mlFeatureEnabledSetting ) );