From 1f550e16ba40a97107266ea71c23f6d2fa59c366 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 4 Jun 2024 15:56:45 +0800 Subject: [PATCH 01/11] Fix bedrock connector embedding generation issue Signed-off-by: zane-neo --- .../remote/RemoteConnectorExecutor.java | 12 ++-- .../remote/AwsConnectorExecutorTest.java | 71 +++++++++++++++++++ 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index e786122cbe..51afaa4b85 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -93,7 +93,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) { @@ -117,11 +117,15 @@ private Tuple calculateChunkSize(String action, TextDocsInputD throw new IllegalArgumentException("no " + action + " action found"); } String preProcessFunction = connectorAction.get().getPreProcessFunction(); - if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) { - // user defined preprocess script, this case, the chunk size is always equals to text docs length. + if (preProcessFunction == null) { + // default preprocess case, consider this a batch. + return Tuple.tuple(1, textDocsLength); + } else if (MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction) + || !MLPreProcessFunction.contains(preProcessFunction)) { + // bedrock and user defined preprocess script, the chunk size is always equals to text docs length. return Tuple.tuple(textDocsLength, 1); } - // consider as batch. + //Other cases: non-bedrock and user defined preprocess script, consider as batch. return Tuple.tuple(1, textDocsLength); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index cb192e83f9..e35cfa4cef 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -612,6 +612,77 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre ); } + @Test + public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddingPreProcessFunction() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http://openai.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT) + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("aws_sigv4") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(executor.getScriptService()).thenReturn(scriptService); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + } + + @Test + public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreprocessFunction() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://openai.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("aws_sigv4") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(executor.getScriptService()).thenReturn(scriptService); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + } + @Test public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { ConnectorAction predictAction = ConnectorAction From 1f5e21bc5cb62c8a2b0806782bf0b0248aa3d0cc Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 4 Jun 2024 16:09:57 +0800 Subject: [PATCH 02/11] format code Signed-off-by: zane-neo --- .../ml/engine/algorithms/remote/RemoteConnectorExecutor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 51afaa4b85..22c866873b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -125,7 +125,7 @@ private Tuple calculateChunkSize(String action, TextDocsInputD // bedrock and user defined preprocess script, the chunk size is always equals to text docs length. return Tuple.tuple(textDocsLength, 1); } - //Other cases: non-bedrock and user defined preprocess script, consider as batch. + // Other cases: non-bedrock and user defined preprocess script, consider as batch. return Tuple.tuple(1, textDocsLength); } } From 9639f6d7ed00aa115b0ceaf2b21a71f55df425e5 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 12:17:11 +0800 Subject: [PATCH 03/11] add IT Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 8 ++ .../ml/rest/RestBedRockInferenceIT.java | 96 +++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index cf1f87e09e..6c6a2f8195 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -67,6 +67,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaType; @@ -87,6 +88,7 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -911,6 +913,12 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } + public ModelTensorOutput predictRemoteModel(String modelId, MLInput input) throws IOException { + Response response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, TestHelper.toJsonString(input), null); + return new ModelTensorOutput(StreamInput.wrap(response.getEntity().getContent().readAllBytes())); + } + public Consumer> verifyTextEmbeddingModelDeployed() { return (modelProfile) -> { if (modelProfile.containsKey("model_state")) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java new file mode 100644 index 0000000000..2811e85e6c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import com.google.common.collect.ImmutableList; +import com.jayway.jsonpath.JsonPath; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.InputDataSet; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.utils.TestHelper; +import org.w3c.dom.Text; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.opensearch.ml.utils.TestHelper.makeRequest; + +public class RestBedRockInferenceIT extends MLCommonsRestTestCase { + private String bedrockEmbeddingModelId; + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private final String bedrockEmbeddingModelConnectorEntity = "{\n" + + " \"name\": \"Amazon Bedrock Connector: embedding\",\n" + + " \"description\": \"The connector to bedrock Titan embedding model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"model_name\": \"amazon.titan-embed-text-v1\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n" + + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n" + + " }\n" + + " ]\n" + + "}"; + + @Before + public void setup() throws IOException, InterruptedException { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + Thread.sleep(20000); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); + } + + + public void test_bedrock_embedding_model() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + return; + } + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + ModelTensorOutput output = predictRemoteModel(bedrockEmbeddingModelId, mlInput); + assertEquals(2, output.getMlModelOutputs().size()); + assertEquals(1536, output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length); + } +} From 9774a2c6e6a34da6921bd5369b8819858595f490 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 15:10:11 +0800 Subject: [PATCH 04/11] add ITs Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 8 +- .../ml/rest/RestBedRockInferenceIT.java | 89 ++++++------------- .../templates/BedRockConnectorBodies.json | 63 +++++++++++++ 3 files changed, 97 insertions(+), 63 deletions(-) create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 6c6a2f8195..b6ca54e9ce 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -913,10 +913,12 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } - public ModelTensorOutput predictRemoteModel(String modelId, MLInput input) throws IOException { + public Map predictRemoteModel(String modelId, MLInput input) throws IOException { + String requestBody = TestHelper.toJsonString(input); + System.out.println("############################## request body is:" + requestBody); Response response = TestHelper - .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, TestHelper.toJsonString(input), null); - return new ModelTensorOutput(StreamInput.wrap(response.getEntity().getContent().readAllBytes())); + .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + return parseResponseToMap(response); } public Consumer> verifyTextEmbeddingModelDeployed() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 2811e85e6c..4d831ef032 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -5,80 +5,31 @@ package org.opensearch.ml.rest; -import com.google.common.collect.ImmutableList; -import com.jayway.jsonpath.JsonPath; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.message.BasicHeader; -import org.junit.Assert; +import lombok.SneakyThrows; import org.junit.Before; -import org.opensearch.client.Request; -import org.opensearch.client.Response; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.utils.TestHelper; -import org.w3c.dom.Text; +import org.opensearch.ml.common.utils.StringUtils; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.List; +import java.util.Locale; import java.util.Map; -import static org.opensearch.ml.utils.TestHelper.makeRequest; - public class RestBedRockInferenceIT extends MLCommonsRestTestCase { - private String bedrockEmbeddingModelId; private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); private static final String GITHUB_CI_AWS_REGION = "us-west-2"; - private final String bedrockEmbeddingModelConnectorEntity = "{\n" - + " \"name\": \"Amazon Bedrock Connector: embedding\",\n" - + " \"description\": \"The connector to bedrock Titan embedding model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"" - + GITHUB_CI_AWS_REGION - + "\",\n" - + " \"service_name\": \"bedrock\",\n" - + " \"model_name\": \"amazon.titan-embed-text-v1\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"" - + AWS_ACCESS_KEY_ID - + "\",\n" - + " \"secret_key\": \"" - + AWS_SECRET_ACCESS_KEY - + "\",\n" - + " \"session_token\": \"" - + AWS_SESSION_TOKEN - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-amz-content-sha256\": \"required\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n" - + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n" - + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n" - + " }\n" - + " ]\n" - + "}"; - + @SneakyThrows @Before public void setup() throws IOException, InterruptedException { RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); Thread.sleep(20000); - String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); - this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); } @@ -87,10 +38,28 @@ public void test_bedrock_embedding_model() throws Exception { if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { return; } - TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); - MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - ModelTensorOutput output = predictRemoteModel(bedrockEmbeddingModelId, mlInput); - assertEquals(2, output.getMlModelOutputs().size()); - assertEquals(1536, output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length); + String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI())); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictRemoteModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 2, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json new file mode 100644 index 0000000000..9dc1ed69fd --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json @@ -0,0 +1,63 @@ +{ + "without_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "with_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1", + "input_docs_processed_step_size": 1 + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +} From 613364d27fc4eb2e9b52200d4190cccda5b2c857 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 15:13:01 +0800 Subject: [PATCH 05/11] format code Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 3 -- .../ml/rest/RestBedRockInferenceIT.java | 40 ++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index b6ca54e9ce..482f38bc4a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -67,7 +67,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.rest.SecureRestClientBuilder; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaType; @@ -88,7 +87,6 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -915,7 +913,6 @@ public Map predictTextEmbedding(String modelId) throws IOException { public Map predictRemoteModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); - System.out.println("############################## request body is:" + requestBody); Response response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); return parseResponseToMap(response); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 4d831ef032..cda9ae54fa 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -5,13 +5,6 @@ package org.opensearch.ml.rest; -import lombok.SneakyThrows; -import org.junit.Before; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.utils.StringUtils; - import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; @@ -19,6 +12,14 @@ import java.util.Locale; import java.util.Map; +import org.junit.Before; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.SneakyThrows; + public class RestBedRockInferenceIT extends MLCommonsRestTestCase { private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); @@ -32,19 +33,38 @@ public void setup() throws IOException, InterruptedException { Thread.sleep(20000); } - public void test_bedrock_embedding_model() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { return; } - String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI())); + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") + .toURI() + ) + ); Map templateMap = StringUtils.gson.fromJson(templates, Map.class); for (Map.Entry templateEntry : templateMap.entrySet()) { String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); String testCaseName = templateEntry.getKey(); String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); - String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); From 7209ad683e92b4bed09abbd132ab81583996b9c1 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 16:07:16 +0800 Subject: [PATCH 06/11] change input to fix number format exception in local Signed-off-by: zane-neo --- .../opensearch/ml/rest/templates/BedRockConnectorBodies.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json index 9dc1ed69fd..5b75b5ab72 100644 --- a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json @@ -38,7 +38,7 @@ "region": "%s", "service_name": "bedrock", "model_name": "amazon.titan-embed-text-v1", - "input_docs_processed_step_size": 1 + "input_docs_processed_step_size": "1" }, "credential": { "access_key": "%s", From 5d4fe505cf2d85cee9496a0211088beb45ec8f5f Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 16:35:10 +0800 Subject: [PATCH 07/11] Add log to identify the failure IT root cause Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/MLCommonsRestTestCase.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 482f38bc4a..3ac7743852 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -38,6 +38,7 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import lombok.extern.log4j.Log4j2; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; @@ -103,6 +104,7 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; +@Log4j2 public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase { protected Gson gson = new Gson(); public static long CUSTOM_MODEL_TIMEOUT = 20_000; // 20 seconds @@ -856,6 +858,7 @@ public static Map parseResponseToMap(Response response) throws IOException { HttpEntity entity = response.getEntity(); assertNotNull(response); String entityString = TestHelper.httpEntityToString(entity); + log.info("response: {}", entityString); return StringUtils.gson.fromJson(entityString, Map.class); } From 22f5bf7b6d5760c59b5e22c994d9e494b4d47da6 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 16:45:02 +0800 Subject: [PATCH 08/11] format code Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/MLCommonsRestTestCase.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 3ac7743852..d44874fe10 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -38,7 +38,6 @@ import java.util.function.Consumer; import java.util.stream.Collectors; -import lombok.extern.log4j.Log4j2; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; @@ -104,6 +103,8 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; +import lombok.extern.log4j.Log4j2; + @Log4j2 public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase { protected Gson gson = new Gson(); From aae1ab0ccf18c30946af78517c7694ff0edc57aa Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 6 Jun 2024 11:13:15 +0800 Subject: [PATCH 09/11] remove debug log Signed-off-by: zane-neo --- .../java/org/opensearch/ml/rest/MLCommonsRestTestCase.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index d44874fe10..482f38bc4a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -103,9 +103,6 @@ import com.google.gson.Gson; import com.google.gson.JsonArray; -import lombok.extern.log4j.Log4j2; - -@Log4j2 public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase { protected Gson gson = new Gson(); public static long CUSTOM_MODEL_TIMEOUT = 20_000; // 20 seconds @@ -859,7 +856,6 @@ public static Map parseResponseToMap(Response response) throws IOException { HttpEntity entity = response.getEntity(); assertNotNull(response); String entityString = TestHelper.httpEntityToString(entity); - log.info("response: {}", entityString); return StringUtils.gson.fromJson(entityString, Map.class); } From 35e81339420b75a76fde5f36cffdeed53272f992 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 11 Jun 2024 08:58:43 +0800 Subject: [PATCH 10/11] Update plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java Co-authored-by: Yaliang Wu Signed-off-by: zane-neo --- .../test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 482f38bc4a..886494de3c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -911,7 +911,7 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } - public Map predictRemoteModel(String modelId, MLInput input) throws IOException { + public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); Response response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); From a9d09636d980987f97fa87169c618b07955019d1 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 11 Jun 2024 09:45:03 +0800 Subject: [PATCH 11/11] address comments Signed-off-by: zane-neo --- .../remote/AwsConnectorExecutorTest.java | 16 +++++++++++---- .../ml/rest/RestBedRockInferenceIT.java | 20 ++++++++++++------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index e35cfa4cef..98d5feb7ba 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -634,7 +634,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -645,7 +645,11 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test @@ -669,7 +673,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -680,7 +684,11 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index cda9ae54fa..fea981afe7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -68,18 +68,24 @@ public void test_bedrock_embedding_model() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - Map inferenceResult = predictRemoteModel(modelId, mlInput); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 2, output.size()); assertTrue(errorMsg, output.get(0) instanceof Map); - assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); - List outputList = (List) ((Map) output.get(0)).get("output"); - assertEquals(errorMsg, 1, outputList.size()); - assertTrue(errorMsg, outputList.get(0) instanceof Map); - assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); - assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + assertTrue(errorMsg, output.get(1) instanceof Map); + validateOutput(errorMsg, (Map) output.get(0)); + validateOutput(errorMsg, (Map) output.get(1)); } + } + private void validateOutput(String errorMsg, Map output) { + assertTrue(errorMsg, output.containsKey("output")); + assertTrue(errorMsg, output.get("output") instanceof List); + List outputList = (List) output.get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); } }