From c19785c3d7a9a0c8749ff15fbfd20ba505241746 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 4 Jun 2024 15:56:45 +0800 Subject: [PATCH 01/10] Fix bedrock connector embedding generation issue Signed-off-by: zane-neo --- .../remote/RemoteConnectorExecutor.java | 12 ++-- .../remote/AwsConnectorExecutorTest.java | 71 +++++++++++++++++++ 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 18dda5eadd..f75e7ac398 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -94,7 +94,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) { @@ -118,11 +118,15 @@ private Tuple calculateChunkSize(String action, TextDocsInputD throw new IllegalArgumentException("no " + action + " action found"); } String preProcessFunction = connectorAction.get().getPreProcessFunction(); - if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) { - // user defined preprocess script, this case, the chunk size is always equals to text docs length. + if (preProcessFunction == null) { + // default preprocess case, consider this a batch. + return Tuple.tuple(1, textDocsLength); + } else if (MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction) + || !MLPreProcessFunction.contains(preProcessFunction)) { + // bedrock and user defined preprocess script, the chunk size is always equals to text docs length. return Tuple.tuple(textDocsLength, 1); } - // consider as batch. + //Other cases: non-bedrock and user defined preprocess script, consider as batch. return Tuple.tuple(1, textDocsLength); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index cb192e83f9..e35cfa4cef 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -612,6 +612,77 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre ); } + @Test + public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddingPreProcessFunction() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http://openai.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT) + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("aws_sigv4") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(executor.getScriptService()).thenReturn(scriptService); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + } + + @Test + public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreprocessFunction() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://openai.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "bedrock"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("aws_sigv4") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(executor.getScriptService()).thenReturn(scriptService); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + } + @Test public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { ConnectorAction predictAction = ConnectorAction From 3d44bc8d795d4a8b262c081beb6192e9e33f707c Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 4 Jun 2024 16:09:57 +0800 Subject: [PATCH 02/10] format code Signed-off-by: zane-neo --- .../ml/engine/algorithms/remote/RemoteConnectorExecutor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index f75e7ac398..28b7617103 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -126,7 +126,7 @@ private Tuple calculateChunkSize(String action, TextDocsInputD // bedrock and user defined preprocess script, the chunk size is always equals to text docs length. return Tuple.tuple(textDocsLength, 1); } - //Other cases: non-bedrock and user defined preprocess script, consider as batch. + // Other cases: non-bedrock and user defined preprocess script, consider as batch. return Tuple.tuple(1, textDocsLength); } } From f16fe9491e74a326c7c595d841589c4496abcea8 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 12:17:11 +0800 Subject: [PATCH 03/10] add IT Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 8 ++ .../ml/rest/RestBedRockInferenceIT.java | 96 +++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index ff2648f74e..0087c5b2d5 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -60,6 +60,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaType; @@ -80,6 +81,7 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -896,6 +898,12 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } + public ModelTensorOutput predictRemoteModel(String modelId, MLInput input) throws IOException { + Response response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, TestHelper.toJsonString(input), null); + return new ModelTensorOutput(StreamInput.wrap(response.getEntity().getContent().readAllBytes())); + } + public Consumer> verifyTextEmbeddingModelDeployed() { return (modelProfile) -> { if (modelProfile.containsKey("model_state")) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java new file mode 100644 index 0000000000..2811e85e6c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import com.google.common.collect.ImmutableList; +import com.jayway.jsonpath.JsonPath; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.InputDataSet; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.utils.TestHelper; +import org.w3c.dom.Text; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.opensearch.ml.utils.TestHelper.makeRequest; + +public class RestBedRockInferenceIT extends MLCommonsRestTestCase { + private String bedrockEmbeddingModelId; + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private final String bedrockEmbeddingModelConnectorEntity = "{\n" + + " \"name\": \"Amazon Bedrock Connector: embedding\",\n" + + " \"description\": \"The connector to bedrock Titan embedding model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"model_name\": \"amazon.titan-embed-text-v1\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n" + + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n" + + " }\n" + + " ]\n" + + "}"; + + @Before + public void setup() throws IOException, InterruptedException { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + Thread.sleep(20000); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); + } + + + public void test_bedrock_embedding_model() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + return; + } + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + ModelTensorOutput output = predictRemoteModel(bedrockEmbeddingModelId, mlInput); + assertEquals(2, output.getMlModelOutputs().size()); + assertEquals(1536, output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length); + } +} From 46fbff61a2c64e1c33a0a7c9656121d5fb25dce4 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 15:10:11 +0800 Subject: [PATCH 04/10] add ITs Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 8 +- .../ml/rest/RestBedRockInferenceIT.java | 89 ++++++------------- .../templates/BedRockConnectorBodies.json | 63 +++++++++++++ 3 files changed, 97 insertions(+), 63 deletions(-) create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 0087c5b2d5..7ae58a2479 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -898,10 +898,12 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } - public ModelTensorOutput predictRemoteModel(String modelId, MLInput input) throws IOException { + public Map predictRemoteModel(String modelId, MLInput input) throws IOException { + String requestBody = TestHelper.toJsonString(input); + System.out.println("############################## request body is:" + requestBody); Response response = TestHelper - .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, TestHelper.toJsonString(input), null); - return new ModelTensorOutput(StreamInput.wrap(response.getEntity().getContent().readAllBytes())); + .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + return parseResponseToMap(response); } public Consumer> verifyTextEmbeddingModelDeployed() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 2811e85e6c..4d831ef032 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -5,80 +5,31 @@ package org.opensearch.ml.rest; -import com.google.common.collect.ImmutableList; -import com.jayway.jsonpath.JsonPath; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.message.BasicHeader; -import org.junit.Assert; +import lombok.SneakyThrows; import org.junit.Before; -import org.opensearch.client.Request; -import org.opensearch.client.Response; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.utils.TestHelper; -import org.w3c.dom.Text; +import org.opensearch.ml.common.utils.StringUtils; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.List; +import java.util.Locale; import java.util.Map; -import static org.opensearch.ml.utils.TestHelper.makeRequest; - public class RestBedRockInferenceIT extends MLCommonsRestTestCase { - private String bedrockEmbeddingModelId; private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); private static final String GITHUB_CI_AWS_REGION = "us-west-2"; - private final String bedrockEmbeddingModelConnectorEntity = "{\n" - + " \"name\": \"Amazon Bedrock Connector: embedding\",\n" - + " \"description\": \"The connector to bedrock Titan embedding model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"" - + GITHUB_CI_AWS_REGION - + "\",\n" - + " \"service_name\": \"bedrock\",\n" - + " \"model_name\": \"amazon.titan-embed-text-v1\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"" - + AWS_ACCESS_KEY_ID - + "\",\n" - + " \"secret_key\": \"" - + AWS_SECRET_ACCESS_KEY - + "\",\n" - + " \"session_token\": \"" - + AWS_SESSION_TOKEN - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-amz-content-sha256\": \"required\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n" - + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n" - + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n" - + " }\n" - + " ]\n" - + "}"; - + @SneakyThrows @Before public void setup() throws IOException, InterruptedException { RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); Thread.sleep(20000); - String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); - this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); } @@ -87,10 +38,28 @@ public void test_bedrock_embedding_model() throws Exception { if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { return; } - TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); - MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - ModelTensorOutput output = predictRemoteModel(bedrockEmbeddingModelId, mlInput); - assertEquals(2, output.getMlModelOutputs().size()); - assertEquals(1536, output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length); + String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI())); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictRemoteModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 2, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json new file mode 100644 index 0000000000..9dc1ed69fd --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json @@ -0,0 +1,63 @@ +{ + "without_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "with_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1", + "input_docs_processed_step_size": 1 + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +} From 9a75bc00b5bc3405453c6dee4b315cd3131d3799 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 15:13:01 +0800 Subject: [PATCH 05/10] format code Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 3 -- .../ml/rest/RestBedRockInferenceIT.java | 40 ++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 7ae58a2479..519487e05f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -60,7 +60,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.rest.SecureRestClientBuilder; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaType; @@ -81,7 +80,6 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -900,7 +898,6 @@ public Map predictTextEmbedding(String modelId) throws IOException { public Map predictRemoteModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); - System.out.println("############################## request body is:" + requestBody); Response response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); return parseResponseToMap(response); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 4d831ef032..cda9ae54fa 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -5,13 +5,6 @@ package org.opensearch.ml.rest; -import lombok.SneakyThrows; -import org.junit.Before; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.utils.StringUtils; - import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; @@ -19,6 +12,14 @@ import java.util.Locale; import java.util.Map; +import org.junit.Before; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.SneakyThrows; + public class RestBedRockInferenceIT extends MLCommonsRestTestCase { private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); @@ -32,19 +33,38 @@ public void setup() throws IOException, InterruptedException { Thread.sleep(20000); } - public void test_bedrock_embedding_model() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { return; } - String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI())); + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") + .toURI() + ) + ); Map templateMap = StringUtils.gson.fromJson(templates, Map.class); for (Map.Entry templateEntry : templateMap.entrySet()) { String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); String testCaseName = templateEntry.getKey(); String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); - String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); From 4aff887ede36c2605f9d7302145f8f29e7b8a9ec Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 16:07:16 +0800 Subject: [PATCH 06/10] change input to fix number format exception in local Signed-off-by: zane-neo --- .../opensearch/ml/rest/templates/BedRockConnectorBodies.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json index 9dc1ed69fd..5b75b5ab72 100644 --- a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json @@ -38,7 +38,7 @@ "region": "%s", "service_name": "bedrock", "model_name": "amazon.titan-embed-text-v1", - "input_docs_processed_step_size": 1 + "input_docs_processed_step_size": "1" }, "credential": { "access_key": "%s", From ec13eea8b163224172f1550ae31c16b1aaf2728a Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 5 Jun 2024 16:35:10 +0800 Subject: [PATCH 07/10] Add log to identify the failure IT root cause Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 519487e05f..b74048cbea 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -5,6 +5,8 @@ package org.opensearch.ml.rest; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; @@ -36,18 +38,23 @@ import java.util.function.Consumer; import java.util.stream.Collectors; -import org.apache.http.Header; -import org.apache.http.HttpEntity; -import org.apache.http.HttpHeaders; -import org.apache.http.HttpHost; -import org.apache.http.auth.AuthScope; -import org.apache.http.auth.UsernamePasswordCredentials; -import org.apache.http.client.CredentialsProvider; -import org.apache.http.conn.ssl.NoopHostnameVerifier; -import org.apache.http.impl.client.BasicCredentialsProvider; -import org.apache.http.message.BasicHeader; -import org.apache.http.ssl.SSLContextBuilder; -import org.apache.http.util.EntityUtils; +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; +import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.nio.ssl.TlsStrategy; +import org.apache.hc.core5.ssl.SSLContextBuilder; +import org.apache.hc.core5.util.Timeout; import org.junit.After; import org.junit.Before; import org.opensearch.client.Request; From 18ad3208e783b9b13c0e59e2014ef0af207ea953 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 11 Jun 2024 08:58:43 +0800 Subject: [PATCH 08/10] Update plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java Co-authored-by: Yaliang Wu Signed-off-by: zane-neo --- .../test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index b74048cbea..083a72daf2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -903,7 +903,7 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } - public Map predictRemoteModel(String modelId, MLInput input) throws IOException { + public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); Response response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); From 624b415753ab6e79fb4b9eacb74c90df99790305 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 11 Jun 2024 09:45:03 +0800 Subject: [PATCH 09/10] address comments Signed-off-by: zane-neo --- .../remote/AwsConnectorExecutorTest.java | 16 +++++++++++---- .../ml/rest/RestBedRockInferenceIT.java | 20 ++++++++++++------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index e35cfa4cef..98d5feb7ba 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -634,7 +634,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -645,7 +645,11 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_bedRockEmbeddi MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test @@ -669,7 +673,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -680,7 +684,11 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPreproces MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index cda9ae54fa..fea981afe7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -68,18 +68,24 @@ public void test_bedrock_embedding_model() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - Map inferenceResult = predictRemoteModel(modelId, mlInput); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 2, output.size()); assertTrue(errorMsg, output.get(0) instanceof Map); - assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); - List outputList = (List) ((Map) output.get(0)).get("output"); - assertEquals(errorMsg, 1, outputList.size()); - assertTrue(errorMsg, outputList.get(0) instanceof Map); - assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); - assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + assertTrue(errorMsg, output.get(1) instanceof Map); + validateOutput(errorMsg, (Map) output.get(0)); + validateOutput(errorMsg, (Map) output.get(1)); } + } + private void validateOutput(String errorMsg, Map output) { + assertTrue(errorMsg, output.containsKey("output")); + assertTrue(errorMsg, output.get("output") instanceof List); + List outputList = (List) output.get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); } } From ccc48b4801f08500391e6d4e0f1bc151a4b93a3c Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 12 Jun 2024 08:35:08 +0800 Subject: [PATCH 10/10] fix backport incompatibility Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 083a72daf2..1b9c29279f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -5,8 +5,6 @@ package org.opensearch.ml.rest; -import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; -import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; @@ -38,23 +36,18 @@ import java.util.function.Consumer; import java.util.stream.Collectors; -import org.apache.hc.client5.http.auth.AuthScope; -import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; -import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; -import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager; -import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; -import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; -import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; -import org.apache.hc.core5.http.Header; -import org.apache.hc.core5.http.HttpEntity; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.ParseException; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.message.BasicHeader; -import org.apache.hc.core5.http.nio.ssl.TlsStrategy; -import org.apache.hc.core5.ssl.SSLContextBuilder; -import org.apache.hc.core5.util.Timeout; +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.HttpHost; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.message.BasicHeader; +import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.util.EntityUtils; import org.junit.After; import org.junit.Before; import org.opensearch.client.Request;