diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java
index 5ba465b15a..9e4b2301bb 100644
--- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java
+++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java
@@ -10,6 +10,7 @@
 import java.util.Map;
 import java.util.function.Function;
 
+import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction;
 import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
 import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
 import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
@@ -20,6 +21,7 @@ public class MLPostProcessFunction {
     public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
     public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
     public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
+    public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn";
     public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
     public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
     public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
@@ -31,6 +33,7 @@ public class MLPostProcessFunction {
     static {
         EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
         BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
+        BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
         CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
         JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
         JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
@@ -42,6 +45,7 @@ public class MLPostProcessFunction {
         POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
         POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
         POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
+        POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction);
         POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
         POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
     }
diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java
new file mode 100644
index 0000000000..3b42f1c089
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.connector.functions.postprocess;
+
+import org.opensearch.ml.common.output.model.MLResultDataType;
+import org.opensearch.ml.common.output.model.ModelTensor;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Post-process function registered in {@code MLPostProcessFunction} under {@code BEDROCK_BATCH_JOB_ARN}.
+ *
+ * <p>NOTE(review): nothing in this class is ARN-specific; it accepts a list of numbers and wraps it in a
+ * single {@code sentence_embedding} tensor. Confirm that this is the intended behavior for batch jobs.
+ */
+public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction<List<Float>> {
+
+    /**
+     * Checks that the model output is a {@link List} whose first element, if present, is a {@link Number}.
+     * An empty list passes validation, and only the first element is inspected.
+     *
+     * @param input candidate model output to validate
+     * @throws IllegalArgumentException if {@code input} is not a list or its first element is not a number
+     */
+    @Override
+    public void validate(Object input) {
+        if (!(input instanceof List)) {
+            throw new IllegalArgumentException("Post process function input is not a List.");
+        }
+
+        List<?> outerList = (List<?>) input;
+
+        if (!outerList.isEmpty() && !(outerList.get(0) instanceof Number)) {
+            throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
+        }
+    }
+
+    /**
+     * Wraps the given values in a single FLOAT32 tensor named {@code sentence_embedding}.
+     *
+     * @param embedding values exposed as the tensor data; the list size becomes the tensor shape
+     * @return a list holding exactly one {@link ModelTensor}
+     */
+    @Override
+    public List<ModelTensor> process(List<Float> embedding) {
+        List<ModelTensor> modelTensors = new ArrayList<>();
+        modelTensors
+            .add(
+                ModelTensor
+                    .builder()
+                    .name("sentence_embedding")
+                    .dataType(MLResultDataType.FLOAT32)
+                    .shape(new long[] { embedding.size() })
+                    .data(embedding.toArray(new Number[0]))
+                    .build()
+            );
+        return modelTensors;
+    }
+}