From 3642d05bb60fe8f72a69ea258940612d5c292961 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 13 Sep 2024 14:07:16 -0700 Subject: [PATCH] add preprocess function for batch inferernce jobArn Signed-off-by: Yaliang Wu --- .../connector/MLPreProcessFunction.java | 4 ++ ...etBatchInferenceJobPreProcessFunction.java | 59 +++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 3a5a3427a8..47fdf7c980 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -10,6 +10,7 @@ import java.util.function.Function; import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.BedrockGetBatchInferenceJobPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; @@ -24,6 +25,7 @@ public class MLPreProcessFunction { public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding"; + public static final String BEDROCK_GET_BATCH_INFERENCE_JOB_INPUT = "connector.pre_process.bedrock.get_batch_inference_job"; public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank"; public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank"; @@ -37,8 +39,10 @@ public class MLPreProcessFunction { BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); + BedrockGetBatchInferenceJobPreProcessFunction bedrockGetBatchInferenceJobPreProcessFunction = new BedrockGetBatchInferenceJobPreProcessFunction(); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(BEDROCK_GET_BATCH_INFERENCE_JOB_INPUT, bedrockGetBatchInferenceJobPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java new file mode 100644 index 0000000000..2a277ecf49 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java @@ -0,0 +1,59 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + +/** + * This class provides a pre-processing function for bedrock batch inference job input. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link RemoteInferenceInputDataSet}, which must have jobArn parameter. + * The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + */ +public class BedrockGetBatchInferenceJobPreProcessFunction extends ConnectorPreProcessFunction { + + public static final String JOB_ARN = "jobArn"; + + public BedrockGetBatchInferenceJobPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = false; + } + + @Override + public void validate(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) { + throw new IllegalArgumentException("Wrong input dataset type"); + } + RemoteInferenceInputDataSet inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + if (inputData == null) { + throw new IllegalArgumentException("No input dataset provided"); + } + String jobArn = inputData.getParameters().get(JOB_ARN); + if (jobArn == null) { + throw new IllegalArgumentException("No jobArn provided"); + } + } + + /** + * @param mlInput The input data to be processed. + * This method is to escape slash in jobArn. + */ + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + RemoteInferenceInputDataSet inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + inputData.getParameters().computeIfPresent(JOB_ARN, (k, jobArn) -> jobArn.replace("/", "%2F")); + return inputData; + } +}