-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pre and post process functions for Bedrock Rerank API #3254 #3339
Changes from all commits
3f7a00f
316de2c
4af169a
8a4fdb2
ededf78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.postprocess; | ||
|
||
import java.math.BigDecimal; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.opensearch.ml.common.output.model.MLResultDataType; | ||
import org.opensearch.ml.common.output.model.ModelTensor; | ||
|
||
public class BedrockRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> { | ||
|
||
@Override | ||
public void validate(Object input) { | ||
|
||
if (!(input instanceof List)) { | ||
throw new IllegalArgumentException("Post process function input is not a List."); | ||
} | ||
|
||
List<?> outerList = (List<?>) input; | ||
|
||
if (outerList.isEmpty()) { | ||
throw new IllegalArgumentException("Post process function input is empty."); | ||
} | ||
|
||
for (Object item : outerList) { | ||
if (!(item instanceof Map)) { | ||
throw new IllegalArgumentException("Rerank result is not a Map."); | ||
} | ||
|
||
Map<?, ?> innerMap = (Map<?, ?>) item; | ||
|
||
if (innerMap.isEmpty()) { | ||
throw new IllegalArgumentException("Rerank result is empty."); | ||
} | ||
|
||
if (!innerMap.containsKey("index") || !innerMap.containsKey("relevanceScore")) { | ||
throw new IllegalArgumentException("Rerank result should have both index and relevanceScore."); | ||
} | ||
|
||
if (!(innerMap.get("relevanceScore") instanceof BigDecimal || innerMap.get("relevanceScore") instanceof Double)) { | ||
throw new IllegalArgumentException("relevanceScore is not BigDecimal or Double."); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public List<ModelTensor> process(List<Map<String, Object>> rerankResults) { | ||
List<ModelTensor> modelTensors = new ArrayList<>(); | ||
|
||
if (!rerankResults.isEmpty()) { | ||
Double[] scores = new Double[rerankResults.size()]; | ||
for (Map rerankResult : rerankResults) { | ||
Integer index = (Integer) rerankResult.get("index"); | ||
Object relevanceScore = rerankResult.get("relevanceScore"); | ||
if (relevanceScore instanceof BigDecimal) { | ||
scores[index] = ((BigDecimal) relevanceScore).doubleValue(); | ||
} else if (relevanceScore instanceof Double) { | ||
scores[index] = (Double) relevanceScore; | ||
} | ||
} | ||
for (Double score : scores) { | ||
modelTensors | ||
.add( | ||
ModelTensor | ||
.builder() | ||
.name("similarity") | ||
.shape(new long[] { 1 }) | ||
.data(new Number[] { score }) | ||
.dataType(MLResultDataType.FLOAT32) | ||
.build() | ||
); | ||
} | ||
} | ||
return modelTensors; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
public class BedrockRerankPreProcessFunction extends ConnectorPreProcessFunction { | ||
|
||
public BedrockRerankPreProcessFunction() { | ||
this.returnDirectlyForRemoteInferenceInput = true; | ||
} | ||
|
||
@Override | ||
public void validate(MLInput mlInput) { | ||
|
||
if (mlInput.getInputDataset() == null) { | ||
throw new IllegalArgumentException("Input dataset cannot be null."); | ||
} | ||
|
||
if (!(mlInput.getInputDataset() instanceof TextSimilarityInputDataSet)) { | ||
throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet"); | ||
} | ||
} | ||
|
||
@Override | ||
public RemoteInferenceInputDataSet process(MLInput mlInput) { | ||
TextSimilarityInputDataSet inputData = (TextSimilarityInputDataSet) mlInput.getInputDataset(); | ||
String queryText = inputData.getQueryText(); | ||
List<String> textDocs = inputData.getTextDocs(); | ||
|
||
List<Map<String, Object>> queries = new ArrayList<Map<String, Object>>(); | ||
queries.add(Map.of("textQuery", Map.of("text", queryText), "type", "TEXT")); | ||
|
||
List<Map<String, Object>> sources = new ArrayList<Map<String, Object>>(); | ||
inputData.getTextDocs().forEach(textDoc -> { | ||
sources.add(Map.of("inlineDocumentSource", Map.of("textDocument", Map.of("text", textDoc), "type", "TEXT"), "type", "INLINE")); | ||
}); | ||
|
||
Map<String, Object> processedResult = Map | ||
.of("parameters", Map.of("queries", queries, "sources", sources, "numberOfResults", textDocs.size())); | ||
|
||
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.postprocess; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.output.model.ModelTensor; | ||
|
||
public class BedrockRerankPostProcessFunctionTest { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
BedrockRerankPostProcessFunction function; | ||
|
||
@Before | ||
public void setUp() { | ||
function = new BedrockRerankPostProcessFunction(); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput_NotList() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Post process function input is not a List."); | ||
function.apply("abc"); | ||
} | ||
|
||
@Test | ||
public void process_EmptyInput() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Post process function input is empty."); | ||
function.apply(Arrays.asList()); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput_NotCorrectListOfMapsFormat() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Rerank result is not a Map."); | ||
function.apply(Arrays.asList("abc")); | ||
} | ||
|
||
@Test | ||
public void process_EmptyMapInput() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Rerank result is empty."); | ||
function.apply(Arrays.asList(Map.of())); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput_NotCorrectMap() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. process_WrongInput_NotCorrectListOfMapsFormat(){ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated on the commit 8a4fdb2 |
||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Rerank result should have both index and relevanceScore."); | ||
List<Map<String, Object>> rerankResults = List | ||
.of( | ||
Map.of("index", 2, "relevanceScore", 0.7711548805236816), | ||
Map.of("index", 0, "relevanceScore", 0.0025114635936915874), | ||
Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05), | ||
Map.of("test1", "value1") | ||
); | ||
function.apply(rerankResults); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput_NotCorrectRelevanceScore() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("relevanceScore is not BigDecimal or Double."); | ||
List<Map<String, Object>> rerankResults = List | ||
.of( | ||
Map.of("index", 2, "relevanceScore", 0.7711548805236816), | ||
Map.of("index", 0, "relevanceScore", 0.0025114635936915874), | ||
Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05), | ||
Map.of("index", 3, "relevanceScore", "value1") | ||
); | ||
function.apply(rerankResults); | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe lets make a null test? just so we can understand? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. null test can be implemented by referring superclass, but writing superclass's test case in test class for extended class can lead build failure when superclass is updated. |
||
@Test | ||
public void process_CorrectInput() { | ||
List<Map<String, Object>> rerankResults = List | ||
.of( | ||
Map.of("index", 2, "relevanceScore", 0.7711548805236816), | ||
Map.of("index", 0, "relevanceScore", 0.0025114635936915874), | ||
Map.of("index", 1, "relevanceScore", 2.4876489987946115e-05), | ||
Map.of("index", 3, "relevanceScore", 6.339210358419223e-06) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently this will pass when just the first is in correct format but does not check the rest. Like mentioned early if you can change the validation to check each entry is in the right format There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me add a test case to check a list having incorrect map. will update on the next commit.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated on 4af169a |
||
); | ||
List<ModelTensor> result = function.apply(rerankResults); | ||
assertEquals(4, result.size()); | ||
assertEquals(1, result.get(0).getData().length); | ||
assertEquals(0.0025114635936915874, result.get(0).getData()[0]); | ||
assertEquals(2.4876489987946115e-05, result.get(1).getData()[0]); | ||
assertEquals(0.7711548805236816, result.get(2).getData()[0]); | ||
assertEquals(6.339210358419223e-06, result.get(3).getData()[0]); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertTrue; | ||
|
||
import java.util.Arrays; | ||
|
||
import org.json.JSONArray; | ||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
public class BedrockRerankPreProcessFunctionTest { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
BedrockRerankPreProcessFunction function; | ||
|
||
TextSimilarityInputDataSet textSimilarityInputDataSet; | ||
TextDocsInputDataSet textDocsInputDataSet; | ||
|
||
@Before | ||
public void setUp() { | ||
function = new BedrockRerankPreProcessFunction(); | ||
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); | ||
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); | ||
} | ||
|
||
@Test | ||
public void process_NullInput() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Preprocess function input can't be null"); | ||
function.apply(null); | ||
} | ||
|
||
@Test | ||
public void process_WrongInput() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("This pre_process_function can only support TextSimilarityInputDataSet"); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
function.apply(mlInput); | ||
} | ||
|
||
@Test | ||
public void process_CorrectInput() { | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
assertEquals(3, dataSet.getParameters().size()); | ||
|
||
JSONArray expectedSources = new JSONArray( | ||
"[{\"type\": \"INLINE\", \"inlineDocumentSource\": {\"type\": \"TEXT\", \"textDocument\": {\"text\": \"hello\"}}}]" | ||
); | ||
JSONArray actualSources = new JSONArray(dataSet.getParameters().get("sources")); | ||
assertTrue(expectedSources.getJSONObject(0).similar(actualSources.getJSONObject(0))); | ||
|
||
JSONArray expectedQueries = new JSONArray("[{\"textQuery\": {\"text\": \"test\"}, \"type\": \"TEXT\"}]"); | ||
JSONArray actualQueries = new JSONArray(dataSet.getParameters().get("queries")); | ||
assertTrue(expectedQueries.getJSONObject(0).similar(actualQueries.getJSONObject(0))); | ||
|
||
assertEquals("1", dataSet.getParameters().get("numberOfResults")); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe also check for null before getInputDataset()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, will be updated on the next commit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've checked null check is already implemented in apply method in superclass, ConnectorPreProcessFunction and ConnectorPostProcessFunction. apply method is wrapping validate method.
https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java
https://github.com/opensearch-project/ml-commons/blob/main/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java
Thus, I think it's not necessary to implement nullcheck in validate method again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just added following validation in validate method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated on 4af169a