From 3f7a00f3d675bc19264c5342664c4488b7a30392 Mon Sep 17 00:00:00 2001 From: tkykenmt Date: Sat, 4 Jan 2025 23:20:43 +0900 Subject: [PATCH 1/5] Add pre and post process functions for Bedrock Rerank API #3254 Signed-off-by: tkykenmt --- .../connector/MLPostProcessFunction.java | 5 ++ .../connector/MLPreProcessFunction.java | 4 + .../BedrockRerankPostProcessFunction.java | 69 ++++++++++++++++++ .../BedrockRerankPreProcessFunction.java | 50 +++++++++++++ .../BedrockRerankPostProcessFunctionTest.java | 70 ++++++++++++++++++ .../BedrockRerankPreProcessFunctionTest.java | 73 +++++++++++++++++++ 6 files changed, 271 insertions(+) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunctionTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index abe56cde0e..9b94593dfc 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -12,6 +12,7 @@ import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; import org.opensearch.ml.common.output.model.ModelTensor; @@ -23,6 +24,7 @@ public class MLPostProcessFunction { public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn"; public static final String COHERE_RERANK = "connector.post_process.cohere.rerank"; + public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank"; public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; public static final String DEFAULT_RERANK = "connector.post_process.default.rerank"; @@ -35,12 +37,14 @@ public class MLPostProcessFunction { BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction(); BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction(); CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); + BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$"); JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results"); + JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); @@ -48,6 +52,7 @@ public class MLPostProcessFunction { POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 723da8c07d..e781e69e5e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -10,6 +10,7 @@ import java.util.function.Function; import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.BedrockRerankPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; @@ -28,6 +29,7 @@ public class MLPreProcessFunction { public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding"; public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank"; + public static final String TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT = "connector.pre_process.bedrock.rerank"; public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank"; public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input"; @@ -38,6 +40,7 @@ public class MLPreProcessFunction { OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction(); BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); + BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction(); MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction = new CohereMultiModalEmbeddingPreProcessFunction(); @@ -49,6 +52,7 @@ public class MLPreProcessFunction { PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT, bedrockRerankPreProcessFunction); } public static boolean contains(String functionName) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java new file mode 100644 index 0000000000..e740b91ab2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +public class BedrockRerankPostProcessFunction extends ConnectorPostProcessFunction>> { + + @Override + public void validate(Object input) { + if (!(input instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List."); + } + List outerList = (List) input; + if (!outerList.isEmpty()) { + if (!(outerList.get(0) instanceof Map)) { + throw new IllegalArgumentException("Post process function input is not a List of Map."); + } + Map innerMap = (Map) outerList.get(0); + + if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevanceScore")) { + throw new IllegalArgumentException("The rerank result should contain index and relevanceScore."); + } + } + } + + @Override + public List process(List> rerankResults) { + List modelTensors = new ArrayList<>(); + + if (rerankResults.size() > 0) { + Double[] scores = new Double[rerankResults.size()]; + for (int i = 0; i < rerankResults.size(); i++) { + Integer index = (Integer) rerankResults.get(i).get("index"); + Object relevanceScore = rerankResults.get(i).get("relevanceScore"); + scores[index] = switch (relevanceScore) { + case BigDecimal bd -> bd.doubleValue(); + case Double d -> d; + case null -> throw new IllegalArgumentException("relevanceScore is null"); + default -> throw new IllegalArgumentException("Unexpected type for relevanceScore: " + + relevanceScore.getClass().getName()); + }; + } + + for (int i = 0; i < scores.length; i++) { + modelTensors + .add( + ModelTensor + .builder() + .name("similarity") + .shape(new long[] { 1 }) + .data(new Number[] { scores[i] }) + .dataType(MLResultDataType.FLOAT32) + .build() + ); + } + } + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java new file mode 100644 index 0000000000..8db0c4fece --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +public class BedrockRerankPreProcessFunction extends ConnectorPreProcessFunction { + + public BedrockRerankPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof TextSimilarityInputDataSet)) { + throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet"); + } + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextSimilarityInputDataSet inputData = (TextSimilarityInputDataSet) mlInput.getInputDataset(); + String queryText = inputData.getQueryText(); + List textDocs = inputData.getTextDocs(); + + List> queries = new ArrayList>(); + queries.add(Map.of("textQuery", Map.of("text", queryText), "type", "TEXT")); + + List> sources = new ArrayList>(); + inputData.getTextDocs().forEach(textDoc -> { + sources.add(Map.of("inlineDocumentSource", Map.of("textDocument", Map.of("text", textDoc), "type", "TEXT"), "type", "INLINE")); + }); + + Map processedResult = Map + .of("parameters", Map.of("queries", queries, "sources", sources, "numberOfResults", textDocs.size())); + + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java new file mode 100644 index 0000000000..1007e2f974 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import static org.junit.Assert.assertEquals; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.ModelTensor; + +public class BedrockRerankPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + BedrockRerankPostProcessFunction function; + + @Before + public void setUp() { + function = new BedrockRerankPostProcessFunction(); + } + + @Test + public void process_WrongInput_NotList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is not a List."); + function.apply("abc"); + } + + @Test + public void process_WrongInput_NotCorrectList() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is not a List of Map."); + function.apply(Arrays.asList("abc")); + } + + @Test + public void process_WrongInput_NotCorrectMap() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("The rerank result should contain index and relevance_score."); + function.apply(Arrays.asList(Map.of("test1", "value1"))); + } + + @Test + public void process_CorrectInput() { + List> rerankResults = List + .of( + Map.of("index", 2, "relevanceScore", 0.7711548805236816), + Map.of("index", 0, "relevanceScore", 0.0025114635936915874), + Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05), + Map.of("index", 3, "relevanceScore", 6.339210358419223e-06) + ); + List result = function.apply(rerankResults); + assertEquals(4, result.size()); + assertEquals(1, result.get(0).getData().length); + assertEquals(0.0025114635936915874, result.get(0).getData()[0]); + assertEquals(2.4876489987946115e-05, result.get(1).getData()[0]); + assertEquals(0.7711548805236816, result.get(2).getData()[0]); + assertEquals(6.339210358419223e-06, result.get(3).getData()[0]); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunctionTest.java new file mode 100644 index 0000000000..6d57148146 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunctionTest.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import org.json.JSONArray; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +public class BedrockRerankPreProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + BedrockRerankPreProcessFunction function; + + TextSimilarityInputDataSet textSimilarityInputDataSet; + TextDocsInputDataSet textDocsInputDataSet; + + @Before + public void setUp() { + function = new BedrockRerankPreProcessFunction(); + textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); + textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); + } + + @Test + public void process_NullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Preprocess function input can't be null"); + function.apply(null); + } + + @Test + public void process_WrongInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("This pre_process_function can only support TextSimilarityInputDataSet"); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + function.apply(mlInput); + } + + @Test + public void process_CorrectInput() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(mlInput); + assertEquals(3, dataSet.getParameters().size()); + + JSONArray expectedSources = new JSONArray( + "[{\"type\": \"INLINE\", \"inlineDocumentSource\": {\"type\": \"TEXT\", \"textDocument\": {\"text\": \"hello\"}}}]" + ); + JSONArray actualSources = new JSONArray(dataSet.getParameters().get("sources")); + assertTrue(expectedSources.getJSONObject(0).similar(actualSources.getJSONObject(0))); + + JSONArray expectedQueries = new JSONArray("[{\"textQuery\": {\"text\": \"test\"}, \"type\": \"TEXT\"}]"); + JSONArray actualQueries = new JSONArray(dataSet.getParameters().get("queries")); + assertTrue(expectedQueries.getJSONObject(0).similar(actualQueries.getJSONObject(0))); + + assertEquals("1", dataSet.getParameters().get("numberOfResults")); + } +} From 316de2c96fa40042a2771caf81744f74a292a869 Mon Sep 17 00:00:00 2001 From: tkykenmt Date: Wed, 8 Jan 2025 08:12:27 +0900 Subject: [PATCH 2/5] modify format using spotlessApply Signed-off-by: tkykenmt --- .../postprocess/BedrockRerankPostProcessFunction.java | 5 +++-- .../postprocess/BedrockRerankPostProcessFunctionTest.java | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java index e740b91ab2..8449253431 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java @@ -46,8 +46,9 @@ public List process(List> rerankResults) { case BigDecimal bd -> bd.doubleValue(); case Double d -> d; case null -> throw new IllegalArgumentException("relevanceScore is null"); - default -> throw new IllegalArgumentException("Unexpected type for relevanceScore: " + - relevanceScore.getClass().getName()); + default -> throw new IllegalArgumentException( + "Unexpected type for relevanceScore: " + relevanceScore.getClass().getName() + ); }; } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java index 1007e2f974..1f92076820 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java @@ -7,7 +7,6 @@ import static org.junit.Assert.assertEquals; -import java.math.BigDecimal; import java.util.Arrays; import java.util.List; import java.util.Map; From 4af169af75b4e464d8ec7cc9d22e7f1f616d673e Mon Sep 17 00:00:00 2001 From: tkykenmt Date: Wed, 8 Jan 2025 15:15:07 +0900 Subject: [PATCH 3/5] Fix on validation/converting scores #3339 Signed-off-by: tkykenmt --- .../BedrockRerankPostProcessFunction.java | 54 +++++++++++-------- .../BedrockRerankPreProcessFunction.java | 5 ++ .../BedrockRerankPostProcessFunctionTest.java | 41 ++++++++++++-- 3 files changed, 76 insertions(+), 24 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java index 8449253431..b0643ae976 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java @@ -17,18 +17,34 @@ public class BedrockRerankPostProcessFunction extends ConnectorPostProcessFuncti @Override public void validate(Object input) { + if (!(input instanceof List)) { throw new IllegalArgumentException("Post process function input is not a List."); } + List outerList = (List) input; - if (!outerList.isEmpty()) { - if (!(outerList.get(0) instanceof Map)) { - throw new IllegalArgumentException("Post process function input is not a List of Map."); + + if (outerList.isEmpty()) { + throw new IllegalArgumentException("Post process function input is empty."); + } + + for (Object item : outerList) { + if (!(item instanceof Map)) { + throw new IllegalArgumentException("Rerank result is not a Map."); + } + + Map innerMap = (Map) item; + + if (innerMap.isEmpty()) { + throw new IllegalArgumentException("Rerank result is empty."); + } + + if (!innerMap.containsKey("index") || !innerMap.containsKey("relevanceScore")) { + throw new IllegalArgumentException("Rerank result should have both index and relevanceScore."); } - Map innerMap = (Map) outerList.get(0); - if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevanceScore")) { - throw new IllegalArgumentException("The rerank result should contain index and relevanceScore."); + if (!(innerMap.get("relevanceScore") instanceof BigDecimal || innerMap.get("relevanceScore") instanceof Double)) { + throw new IllegalArgumentException("relevanceScore is not BigDecimal or Double."); } } } @@ -37,29 +53,25 @@ public void validate(Object input) { public List process(List> rerankResults) { List modelTensors = new ArrayList<>(); - if (rerankResults.size() > 0) { + if (!rerankResults.isEmpty()) { Double[] scores = new Double[rerankResults.size()]; - for (int i = 0; i < rerankResults.size(); i++) { - Integer index = (Integer) rerankResults.get(i).get("index"); - Object relevanceScore = rerankResults.get(i).get("relevanceScore"); - scores[index] = switch (relevanceScore) { - case BigDecimal bd -> bd.doubleValue(); - case Double d -> d; - case null -> throw new IllegalArgumentException("relevanceScore is null"); - default -> throw new IllegalArgumentException( - "Unexpected type for relevanceScore: " + relevanceScore.getClass().getName() - ); - }; + for (Map rerankResult : rerankResults) { + Integer index = (Integer) rerankResult.get("index"); + Object relevanceScore = rerankResult.get("relevanceScore"); + if (relevanceScore instanceof BigDecimal) { + scores[index] = ((BigDecimal) relevanceScore).doubleValue(); + } else if (relevanceScore instanceof Double) { + scores[index] = (Double) relevanceScore; + } } - - for (int i = 0; i < scores.length; i++) { + for (Double score : scores) { modelTensors .add( ModelTensor .builder() .name("similarity") .shape(new long[] { 1 }) - .data(new Number[] { scores[i] }) + .data(new Number[] { score }) .dataType(MLResultDataType.FLOAT32) .build() ); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java index 8db0c4fece..7137bc5093 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockRerankPreProcessFunction.java @@ -23,6 +23,11 @@ public BedrockRerankPreProcessFunction() { @Override public void validate(MLInput mlInput) { + + if (mlInput.getInputDataset() == null) { + throw new IllegalArgumentException("Input dataset cannot be null."); + } + if (!(mlInput.getInputDataset() instanceof TextSimilarityInputDataSet)) { throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet"); } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java index 1f92076820..3896b76190 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java @@ -35,18 +35,53 @@ public void process_WrongInput_NotList() { function.apply("abc"); } + @Test + public void process_EmptyInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is empty."); + function.apply(Arrays.asList()); + } + @Test public void process_WrongInput_NotCorrectList() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Post process function input is not a List of Map."); + exceptionRule.expectMessage("Rerank result is not a Map."); function.apply(Arrays.asList("abc")); } + @Test + public void process_EmptyMapInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Rerank result is empty."); + function.apply(Arrays.asList(Map.of())); + } + @Test public void process_WrongInput_NotCorrectMap() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("The rerank result should contain index and relevance_score."); - function.apply(Arrays.asList(Map.of("test1", "value1"))); + exceptionRule.expectMessage("Rerank result should have both index and relevanceScore."); + List> rerankResults = List + .of( + Map.of("index", 2, "relevanceScore", 0.7711548805236816), + Map.of("index", 0, "relevanceScore", 0.0025114635936915874), + Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05), + Map.of("test1", "value1") + ); + function.apply(rerankResults); + } + + @Test + public void process_WrongInput_NotCorrectRelevanceScore() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("relevanceScore is not BigDecimal or Double."); + List> rerankResults = List + .of( + Map.of("index", 2, "relevanceScore", 0.7711548805236816), + Map.of("index", 0, "relevanceScore", 0.0025114635936915874), + Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05), + Map.of("index", 3, "relevanceScore", "value1") + ); + function.apply(rerankResults); } @Test From 8a4fdb24a5d30bb746871987a4b3acaf5348dae3 Mon Sep 17 00:00:00 2001 From: tkykenmt Date: Thu, 9 Jan 2025 09:50:39 +0900 Subject: [PATCH 4/5] Fix on method name of test case for list of maps data #3339 Signed-off-by: tkykenmt --- .../postprocess/BedrockRerankPostProcessFunctionTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java index 3896b76190..8e3f18e3a3 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunctionTest.java @@ -43,7 +43,7 @@ public void process_EmptyInput() { } @Test - public void process_WrongInput_NotCorrectList() { + public void process_WrongInput_NotCorrectListOfMapsFormat() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Rerank result is not a Map."); function.apply(Arrays.asList("abc")); From ededf78ee0604e159cad4ce469d06960cf642120 Mon Sep 17 00:00:00 2001 From: tkykenmt Date: Fri, 10 Jan 2025 14:40:42 +0900 Subject: [PATCH 5/5] remove unnecessary cast #3339 Signed-off-by: tkykenmt --- .../functions/postprocess/BedrockRerankPostProcessFunction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java index b0643ae976..b53fa486cd 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockRerankPostProcessFunction.java @@ -55,7 +55,7 @@ public List process(List> rerankResults) { if (!rerankResults.isEmpty()) { Double[] scores = new Double[rerankResults.size()]; - for (Map rerankResult : rerankResults) { + for (Map rerankResult : rerankResults) { Integer index = (Integer) rerankResult.get("index"); Object relevanceScore = rerankResult.get("relevanceScore"); if (relevanceScore instanceof BigDecimal) {