diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java index 7ecaf15f32..88f7cad061 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java @@ -9,10 +9,13 @@ import org.opensearch.ml.common.output.model.ModelTensor; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; -public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction> { +public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction> { + public static final String JOB_ARN = "jobArn"; + public static final String PROCESSED_JOB_ARN = "processedJobArn"; @Override public void validate(Object input) { @@ -22,22 +25,25 @@ public void validate(Object input) { Map jobInfo = (Map) input; - if (!(jobInfo.containsKey("jobArn"))) { + if (!(jobInfo.containsKey(JOB_ARN))) { throw new IllegalArgumentException("Bedrock batch job arn missing."); } } @Override - public List process(List embedding) { + public List process(Map jobInfo) { List modelTensors = new ArrayList<>(); + + Map processedResult = new HashMap<>(); + processedResult.putAll(jobInfo); + String jobArn = jobInfo.get(JOB_ARN); + processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F")); modelTensors .add( ModelTensor .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[] { embedding.size() }) - .data(embedding.toArray(new Number[0])) + .name("response") + .dataAsMap(processedResult) .build() ); return modelTensors;