Skip to content

Commit

Permalink
add postprocess function for batch inferernce jobArn
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Sep 13, 2024
1 parent a060cfa commit 4490b8f
Showing 1 changed file with 13 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction<List<Float>> {
public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction<Map<String, String>> {
public static final String JOB_ARN = "jobArn";
public static final String PROCESSED_JOB_ARN = "processedJobArn";

@Override
public void validate(Object input) {
Expand All @@ -22,22 +25,25 @@ public void validate(Object input) {

Map<String, String> jobInfo = (Map<String, String>) input;

if (!(jobInfo.containsKey("jobArn"))) {
if (!(jobInfo.containsKey(JOB_ARN))) {
throw new IllegalArgumentException("Bedrock batch job arn missing.");
}
}

@Override
public List<ModelTensor> process(List<Float> embedding) {
public List<ModelTensor> process(Map<String, String> jobInfo) {
List<ModelTensor> modelTensors = new ArrayList<>();

Map<String, String> processedResult = new HashMap<>();
processedResult.putAll(jobInfo);
String jobArn = jobInfo.get(JOB_ARN);
processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F"));
modelTensors
.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[] { embedding.size() })
.data(embedding.toArray(new Number[0]))
.name("response")
.dataAsMap(processedResult)
.build()
);
return modelTensors;
Expand Down

0 comments on commit 4490b8f

Please sign in to comment.