Skip to content

Commit

Permalink
add preprocess function for batch inferernce jobArn
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Sep 13, 2024
1 parent ce593fa commit 3642d05
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.BedrockGetBatchInferenceJobPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
Expand All @@ -24,6 +25,7 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
public static final String BEDROCK_GET_BATCH_INFERENCE_JOB_INPUT = "connector.pre_process.bedrock.get_batch_inference_job";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";
Expand All @@ -37,8 +39,10 @@ public class MLPreProcessFunction {
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
BedrockGetBatchInferenceJobPreProcessFunction bedrockGetBatchInferenceJobPreProcessFunction = new BedrockGetBatchInferenceJobPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(BEDROCK_GET_BATCH_INFERENCE_JOB_INPUT, bedrockGetBatchInferenceJobPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

/**
* This class provides a pre-processing function for bedrock batch inference job input.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link RemoteInferenceInputDataSet}, which must have jobArn parameter.
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
*/
public class BedrockGetBatchInferenceJobPreProcessFunction extends ConnectorPreProcessFunction {

public static final String JOB_ARN = "jobArn";

public BedrockGetBatchInferenceJobPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = false;
}

@Override
public void validate(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) {
throw new IllegalArgumentException("Wrong input dataset type");
}
RemoteInferenceInputDataSet inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
if (inputData == null) {
throw new IllegalArgumentException("No input dataset provided");
}
String jobArn = inputData.getParameters().get(JOB_ARN);
if (jobArn == null) {
throw new IllegalArgumentException("No jobArn provided");
}
}

/**
* @param mlInput The input data to be processed.
* This method is to escape slash in jobArn.
*/
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
RemoteInferenceInputDataSet inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset();
inputData.getParameters().computeIfPresent(JOB_ARN, (k, jobArn) -> jobArn.replace("/", "%2F"));
return inputData;
}
}

0 comments on commit 3642d05

Please sign in to comment.