From 5a2fe201b54e96599c7c69800e9afa095d3ecdd1 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 29 Jan 2024 12:43:44 -0800 Subject: [PATCH] fine tune Signed-off-by: Yaliang Wu --- .../ml/common/connector/HttpConnector.java | 2 +- .../connector/MLPostProcessFunction.java | 55 +++----- .../connector/MLPreProcessFunction.java | 43 +++--- .../BedrockEmbeddingPostProcessFunction.java | 42 ++++++ .../CohereRerankPostProcessFunction.java | 57 ++++++++ .../ConnectorPostProcessFunction.java | 27 ++++ .../EmbeddingPostProcessFunction.java | 50 +++++++ .../BedrockEmbeddingPreProcessFunction.java | 34 +++++ .../CohereEmbeddingPreProcessFunction.java | 34 +++++ .../CohereRerankPreProcessFunction.java | 40 ++++++ .../ConnectorPreProcessFunction.java | 58 ++++++++ .../preprocess/DefaultPreProcessFunction.java | 72 ++++++++++ .../OpenAIEmbeddingPreProcessFunction.java | 34 +++++ .../RemoteInferencePreProcessFunction.java | 62 +++++++++ .../opensearch/ml/common/input/MLInput.java | 18 ++- .../ml/common/utils/StringUtils.java | 23 +++ .../connector/MLPostProcessFunctionTest.java | 6 +- .../ml/common/utils/StringUtilsTest.java | 6 +- .../algorithms/remote/ConnectorUtils.java | 131 +++++++++--------- .../remote/RemoteConnectorExecutor.java | 7 +- .../ml/engine/utils/ScriptUtils.java | 6 - .../algorithms/remote/ConnectorUtilsTest.java | 5 +- .../ml/engine/utils/ScriptUtilsTest.java | 8 +- 23 files changed, 674 insertions(+), 146 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ef0e4bf4a1..d5c148f5e1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -291,7 +291,7 @@ public T createPredictPayload(Map parameters) { payload = substitutor.replace(payload); if (!isJson(payload)) { - throw new IllegalArgumentException("Invalid JSON in payload"); + throw new IllegalArgumentException("Invalid payload: " + payload); } return (T) payload; } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index f0b51233fa..4fb3f75412 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.connector; -import com.google.common.collect.ImmutableList; -import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; import org.opensearch.ml.common.output.model.ModelTensor; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -20,58 +20,41 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; + public static final String COHERE_RERANK = "connector.post_process.cohere.rerank"; public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; + public static final String DEFAULT_RERANK = "connector.post_process.default.rerank"; private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); - private static final Map, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); - + private static final Map>> POST_PROCESS_FUNCTIONS = new HashMap<>(); static { + EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction(); + BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction(); + CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); - POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList()); - } - - public static Function, List> buildModelTensorList() { - return embeddings -> { - List modelTensors = new ArrayList<>(); - if (embeddings == null) { - throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); - } - if (embeddings.get(0) instanceof Number) { - embeddings = ImmutableList.of(embeddings); - } - embeddings.forEach(embedding -> { - List eachEmbedding = (List) embedding; - modelTensors.add( - ModelTensor - .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{eachEmbedding.size()}) - .data(eachEmbedding.toArray(new Number[0])) - .build() - ); - }); - return modelTensors; - }; + JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results"); + JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); + POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction); } public static String getResponseFilter(String postProcessFunction) { return JSON_PATH_EXPRESSION.get(postProcessFunction); } - public static Function, List> get(String postProcessFunction) { + public static Function> get(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.get(postProcessFunction); } public static boolean contains(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction); } -} +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 4021769806..d2d65ebdfd 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -5,43 +5,48 @@ package org.opensearch.ml.common.connector; +import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.function.Function; public class MLPreProcessFunction { - private static final Map, Map>> PRE_PROCESS_FUNCTIONS = new HashMap<>(); + private static final Map> PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; + public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank"; + public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank"; - private static Function, Map> cohereTextEmbeddingPreProcess() { - return inputs -> Map.of("parameters", Map.of("texts", inputs)); - } - - private static Function, Map> openAiTextEmbeddingPreProcess() { - return inputs -> Map.of("parameters", Map.of("input", inputs)); - } - - private static Function, Map> bedrockTextEmbeddingPreProcess() { - return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0))); - } + public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input"; + public static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string"; static { - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess()); + CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction(); + OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction(); + BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); + CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction); + PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction); } public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static Function, Map> get(String preProcessFunction) { - return PRE_PROCESS_FUNCTIONS.get(preProcessFunction); + public static Function get(String postProcessFunction) { + return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java new file mode 100644 index 0000000000..eb55253c01 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; +import java.util.List; + +public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction> { + + @Override + public void validate(Object input) { + if (!(input instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List."); + } + + List outerList = (List) input; + + if (!outerList.isEmpty() && !(((List)input).get(0) instanceof Number)) { + throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values."); + } + } + + @Override + public List process(List embedding) { + List modelTensors = new ArrayList<>(); + modelTensors.add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) + .build()); + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java new file mode 100644 index 0000000000..216fcc9d0a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction>> { + + @Override + public void validate(Object input) { + if (!(input instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List."); + } + List outerList = (List) input; + if (!outerList.isEmpty()) { + if (!(outerList.get(0) instanceof Map)) { + throw new IllegalArgumentException("Post process function input is not a List of Map."); + } + Map innerMap = (Map) outerList.get(0); + + if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevance_score")) { + throw new IllegalArgumentException("The rerank result should contain index and relevance_score."); + } + } + } + + @Override + public List process(List> rerankResults) { + List modelTensors = new ArrayList<>(); + + if (rerankResults.size() > 0) { + Double[] scores = new Double[rerankResults.size()]; + for (int i = 0; i < rerankResults.size(); i++) { + Integer index = (Integer) rerankResults.get(i).get("index"); + scores[index] = (Double) rerankResults.get(i).get("relevance_score"); + } + + for (int i = 0; i < scores.length; i++) { + modelTensors.add(ModelTensor.builder() + .name("similarity") + .shape(new long[]{1}) + .data(new Number[]{scores[i]}) + .dataType(MLResultDataType.FLOAT32) + .build()); + } + } + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java new file mode 100644 index 0000000000..9cb81099c4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.List; +import java.util.function.Function; + +public abstract class ConnectorPostProcessFunction implements Function> { + + @Override + public List apply(Object input) { + if (input == null) { + throw new IllegalArgumentException("Can't run post process function as model output is null"); + } + validate(input); + return process((T)input); + } + + public abstract void validate(Object input); + + public abstract List process(T input); +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java new file mode 100644 index 0000000000..b03c791295 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; +import java.util.List; + +public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction>> { + + @Override + public void validate(Object input) { + if (!(input instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List."); + } + + List outerList = (List) input; + + if (!outerList.isEmpty()) { + if (!(outerList.get(0) instanceof List)) { + throw new IllegalArgumentException("Post process function input is not a List of List."); + } + List innerList = (List) outerList.get(0); + + if (innerList.isEmpty() || !(innerList.get(0) instanceof Number)) { + throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values."); + } + } + } + + @Override + public List process(List> embeddings) { + List modelTensors = new ArrayList<>(); + embeddings.forEach(embedding -> modelTensors.add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) + .build() + )); + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..dae61b6c6c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public BedrockEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0))); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..d82210f4a3 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class CohereEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public CohereEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of("texts", processTextDocs(inputData))); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java new file mode 100644 index 0000000000..c975f7f329 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class CohereRerankPreProcessFunction extends ConnectorPreProcessFunction { + + public CohereRerankPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof TextSimilarityInputDataSet)) { + throw new IllegalArgumentException("This pre_process_function can only support TextSimilarityInputDataSet"); + } + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextSimilarityInputDataSet inputData = (TextSimilarityInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of( + "query", inputData.getQueryText(), + "documents", inputData.getTextDocs(), + "top_n", inputData.getTextDocs().size() + )); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java new file mode 100644 index 0000000000..72ca6ce112 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@Log4j2 +public abstract class ConnectorPreProcessFunction implements Function { + + protected boolean returnDirectlyForRemoteInferenceInput; + + @Override + public RemoteInferenceInputDataSet apply(MLInput mlInput) { + if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + return (RemoteInferenceInputDataSet)mlInput.getInputDataset(); + } else { + validate(mlInput); + return process(mlInput); + } + } + + public abstract void validate(MLInput mlInput); + + public abstract RemoteInferenceInputDataSet process(MLInput mlInput); + + List processTextDocs(TextDocsInputDataSet inputDataSet) { + List docs = new ArrayList<>(); + for (String doc : inputDataSet.getDocs()) { + if (doc != null) { + String gsonString = gson.toJson(doc); + // in 2.9, user will add " before and after string + // gson.toString(string) will add extra " before after string, so need to remove + docs.add(gsonString.substring(1, gsonString.length() - 1)); + } else { + docs.add(null); + } + } + return docs; + } + + public void validateTextDocsInput(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { + throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java new file mode 100644 index 0000000000..6f128fdd51 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.experimental.FieldDefaults; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.script.TemplateScript; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class DefaultPreProcessFunction extends ConnectorPreProcessFunction { + + ScriptService scriptService; + String preProcessFunction; + boolean convertInputToJsonString; + + @Builder + public DefaultPreProcessFunction(ScriptService scriptService, String preProcessFunction, boolean convertInputToJsonString) { + this.returnDirectlyForRemoteInferenceInput = false; + this.scriptService = scriptService; + this.preProcessFunction = preProcessFunction; + this.convertInputToJsonString = convertInputToJsonString; + } + + @Override + public void validate(MLInput mlInput) { + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + mlInput.toXContent(builder, EMPTY_PARAMS); + String inputStr = builder.toString(); + Map inputParams = gson.fromJson(inputStr, Map.class); + if (convertInputToJsonString) { + inputParams = convertScriptStringToJsonString(Map.of("parameters", gson.fromJson(inputStr, Map.class))); + } + String processedInput = executeScript(scriptService, preProcessFunction, inputParams); + if (processedInput == null) { + throw new IllegalArgumentException("Pre-process function output is null"); + } + Map map = gson.fromJson(processedInput, Map.class); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + } catch (IOException e) { + throw new IllegalArgumentException("Failed to run pre-process function: Wrong input"); + } + } + + private String executeScript(ScriptService scriptService, String painlessScript, Map params) { + Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); + TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); + return templateScript.execute(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java new file mode 100644 index 0000000000..32f294fdcc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; + + +public class OpenAIEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { + + public OpenAIEmbeddingPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = true; + } + + @Override + public void validate(MLInput mlInput) { + validateTextDocsInput(mlInput); + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + Map processedResult = Map.of("parameters", Map.of("input", processTextDocs(inputData))); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java new file mode 100644 index 0000000000..73cf91bee7 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.experimental.FieldDefaults; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.script.TemplateScript; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction { + + ScriptService scriptService; + String preProcessFunction; + + @Builder + public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction) { + this.returnDirectlyForRemoteInferenceInput = false; + this.scriptService = scriptService; + this.preProcessFunction = preProcessFunction; + } + + @Override + public void validate(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) { + throw new IllegalArgumentException("This pre_process_function can only support RemoteInferenceInputDataSet"); + } + } + + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + Map inputParams = new HashMap<>(); + inputParams.putAll(((RemoteInferenceInputDataSet)mlInput.getInputDataset()).getParameters()); + String processedInput = executeScript(scriptService, preProcessFunction, inputParams); + if (processedInput == null) { + throw new IllegalArgumentException("Input is null after processed by preprocess function"); + } + Map map = gson.fromJson(processedInput, Map.class); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + } + + String executeScript(ScriptService scriptService, String painlessScript, Map params) { + Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); + TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); + return templateScript.execute(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index acd1522736..f2d74bf8c9 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -17,6 +17,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; @@ -30,6 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -59,6 +61,7 @@ public class MLInput implements Input { public static final String TEXT_DOCS_FIELD = "text_docs"; // Input query text to compare against for text similarity model public static final String QUERY_TEXT_FIELD = "query_text"; + public static final String PARAMETERS_FIELD = "parameters"; // Algorithm name protected FunctionName algorithm; @@ -163,18 +166,23 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } break; case TEXT_SIMILARITY: - TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset; - List tdocs = ds.getTextDocs(); - String queryText = ds.getQueryText(); + TextSimilarityInputDataSet inputDataSet = (TextSimilarityInputDataSet) this.inputDataset; + List documents = inputDataSet.getTextDocs(); + String queryText = inputDataSet.getQueryText(); builder.field(QUERY_TEXT_FIELD, queryText); - if (tdocs != null && !tdocs.isEmpty()) { + if (documents != null && !documents.isEmpty()) { builder.startArray(TEXT_DOCS_FIELD); - for(String d : tdocs) { + for(String d : documents) { builder.value(d); } builder.endArray(); } break; + case REMOTE: + RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) this.inputDataset; + Map parameters = remoteInferenceInputDataSet.getParameters(); + builder.field(PARAMETERS_FIELD, parameters); + break; default: break; } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 43aa3c76ae..f66d1e58c4 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -8,6 +8,7 @@ import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonParser; +import lombok.extern.log4j.Log4j2; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; @@ -21,6 +22,7 @@ import java.util.List; import java.util.Map; +@Log4j2 public class StringUtils { public static final Gson gson; @@ -97,4 +99,25 @@ public static String toJson(Object value) { throw new RuntimeException(e); } } + + public static Map convertScriptStringToJsonString(Map processedInput) { + Map parameterStringMap = new HashMap<>(); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map parametersMap = (Map) processedInput.get("parameters"); + for (String key : parametersMap.keySet()) { + if (parametersMap.get(key) instanceof String) { + parameterStringMap.put(key, (String) parametersMap.get(key)); + } else { + parameterStringMap.put(key, gson.toJson(parametersMap.get(key))); + } + } + return null; + }); + } catch (PrivilegedActionException e) { + log.error("Error processing parameters", e); + throw new RuntimeException(e); + } + return parameterStringMap; + } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index c004c93a31..201b8e5f1e 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -43,15 +43,15 @@ public void test_getResponseFilter() { @Test public void test_buildModelTensorList() { - Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList()); + Assert.assertNotNull(MLPostProcessFunction.buildEmbeddingModelTensorList()); List> numbersList = new ArrayList<>(); numbersList.add(Collections.singletonList(1.0f)); - Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList)); + Assert.assertNotNull(MLPostProcessFunction.buildEmbeddingModelTensorList().apply(numbersList)); } @Test public void test_buildModelTensorList_exception() { exceptionRule.expect(IllegalArgumentException.class); - MLPostProcessFunction.buildModelTensorList().apply(null); + MLPostProcessFunction.buildEmbeddingModelTensorList().apply(null); } } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index a4b34d75b5..3022c97e0a 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.utils; import org.junit.Assert; @@ -87,7 +92,6 @@ public void getParameterMap() { parameters.put("key4", new int[]{10, 20}); parameters.put("key5", new Object[]{1.01, "abc"}); Map parameterMap = StringUtils.getParameterMap(parameters); - System.out.println(parameterMap); Assert.assertEquals(5, parameterMap.size()); Assert.assertEquals("value1", parameterMap.get("key1")); Assert.assertEquals("2", parameterMap.get("key2")); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 6a63d68ff5..c3e385ca4e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -7,20 +7,18 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT; import static org.opensearch.ml.common.utils.StringUtils.gson; -import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; -import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; @@ -28,6 +26,8 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.connector.MLPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -63,14 +63,11 @@ public static RemoteInferenceInputDataSet processInput( if (mlInput == null) { throw new IllegalArgumentException("Input is null"); } - RemoteInferenceInputDataSet inputData; - if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService); - } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { - inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); - } else { - throw new IllegalArgumentException("Wrong input type"); + Optional predictAction = connector.findPredictAction(); + if (predictAction.isEmpty()) { + throw new IllegalArgumentException("no predict action found"); } + RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService); if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); inputData.getParameters().forEach((key, value) -> { @@ -88,65 +85,56 @@ public static RemoteInferenceInputDataSet processInput( return inputData; } - private static RemoteInferenceInputDataSet processTextDocsInput( - TextDocsInputDataSet inputDataSet, + private static RemoteInferenceInputDataSet processMLInput( + MLInput mlInput, Connector connector, Map parameters, ScriptService scriptService ) { - Optional predictAction = connector.findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); - } - String preProcessFunction = predictAction.get().getPreProcessFunction(); - preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; - if (MLPreProcessFunction.contains(preProcessFunction)) { - Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(inputDataSet.getDocs()); - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); + String preProcessFunction = getPreprocessFunction(mlInput, connector); + if (preProcessFunction == null) { + if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + } else { + throw new IllegalArgumentException("pre_process_function not defined in connector"); + } } else { - List docs = new ArrayList<>(); - for (String doc : inputDataSet.getDocs()) { - if (doc != null) { - String gsonString = gson.toJson(doc); - // in 2.9, user will add " before and after string - // gson.toString(string) will add extra " before after string, so need to remove - docs.add(gsonString.substring(1, gsonString.length() - 1)); + preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction); + if (MLPreProcessFunction.contains(preProcessFunction)) { + Function function = MLPreProcessFunction.get(preProcessFunction); + return function.apply(mlInput); + } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT) + && Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) { + RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction); + return function.apply(mlInput); } else { - docs.add(null); + return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); } + } else { + boolean convertInputToJsonString = parameters.containsKey(CONVERT_INPUT_TO_JSON_STRING) + && Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING)); + DefaultPreProcessFunction function = DefaultPreProcessFunction + .builder() + .scriptService(scriptService) + .preProcessFunction(preProcessFunction) + .convertInputToJsonString(convertInputToJsonString) + .build(); + return function.apply(mlInput); } - if (preProcessFunction.contains("${parameters.")) { - StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); - preProcessFunction = substitutor.replace(preProcessFunction); - } - Optional processedInput = executePreprocessFunction(scriptService, preProcessFunction, docs); - if (processedInput.isEmpty()) { - throw new IllegalArgumentException("Wrong input"); - } - Map map = gson.fromJson(processedInput.get(), Map.class); - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); } } - private static Map convertScriptStringToJsonString(Map processedInput) { - Map parameterStringMap = new HashMap<>(); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - Map parametersMap = (Map) processedInput.get("parameters"); - for (String key : parametersMap.keySet()) { - if (parametersMap.get(key) instanceof String) { - parameterStringMap.put(key, (String) parametersMap.get(key)); - } else { - parameterStringMap.put(key, gson.toJson(parametersMap.get(key))); - } - } - return null; - }); - } catch (PrivilegedActionException e) { - log.error("Error processing parameters", e); - throw new RuntimeException(e); + private static String getPreprocessFunction(MLInput mlInput, Connector connector) { + Optional predictAction = connector.findPredictAction(); + String preProcessFunction = predictAction.get().getPreProcessFunction(); + if (preProcessFunction != null) { + return preProcessFunction; + } + if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { + return MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT; } - return parameterStringMap; + return null; } public static ModelTensors processOutput( @@ -165,21 +153,16 @@ public static ModelTensors processOutput( } ConnectorAction connectorAction = predictAction.get(); String postProcessFunction = connectorAction.getPostProcessFunction(); - if (postProcessFunction != null && postProcessFunction.contains("${parameters")) { - StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); - postProcessFunction = substitutor.replace(postProcessFunction); - } + postProcessFunction = fillProcessFunctionParameter(parameters, postProcessFunction); String responseFilter = parameters.get(RESPONSE_FILTER_FIELD); if (MLPostProcessFunction.contains(postProcessFunction)) { // in this case, we can use jsonpath to build a List> result from model response. if (StringUtils.isBlank(responseFilter)) responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction); - List vectors = JsonPath.read(modelResponse, responseFilter); - List processedResponse = executeBuildInPostProcessFunction( - vectors, - MLPostProcessFunction.get(postProcessFunction) - ); + + Object filteredOutput = JsonPath.read(modelResponse, responseFilter); + List processedResponse = MLPostProcessFunction.get(postProcessFunction).apply(filteredOutput); return ModelTensors.builder().mlModelTensors(processedResponse).build(); } @@ -198,6 +181,18 @@ public static ModelTensors processOutput( return ModelTensors.builder().mlModelTensors(modelTensors).build(); } + private static String fillProcessFunctionParameter(Map parameters, String processFunction) { + if (processFunction != null && processFunction.contains("${parameters.")) { + Map tmpParameters = new HashMap<>(); + for (String key : parameters.keySet()) { + tmpParameters.put(key, gson.toJson(parameters.get(key))); + } + StringSubstitutor substitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}"); + processFunction = substitutor.replace(processFunction); + } + return processFunction; + } + public static SdkHttpFullRequest signRequest( SdkHttpFullRequest request, String accessKey, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index be50af3aff..4f46c67906 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -106,14 +106,17 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List inputParameters = new HashMap<>(); if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet) inputDataset).getParameters() != null) { - parameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); + inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); } - + parameters.putAll(inputParameters); RemoteInferenceInputDataSet inputData = processInput(mlInput, connector, parameters, getScriptService()); if (inputData.getParameters() != null) { parameters.putAll(inputData.getParameters()); } + // override again to always prioritize the input parameter + parameters.putAll(inputParameters); String payload = connector.createPredictPayload(parameters); connector.validatePayload(payload); String userStr = getClient() diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java index cc721e9129..46d7794c6c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java @@ -9,9 +9,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; -import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.Script; import org.opensearch.script.ScriptService; @@ -30,10 +28,6 @@ public static Optional executePreprocessFunction( return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences))); } - public static List executeBuildInPostProcessFunction(List vectors, Function, List> function) { - return function.apply(vectors); - } - public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { Map result = StringUtils.fromJson(resultJson, "result"); if (postProcessFunction != null) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 8ad745340b..0bb587c4d4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -190,8 +190,7 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio .parameters(parameters) .actions(Arrays.asList(predictAction)) .build(); - ModelTensors tensors = ConnectorUtils - .processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of()); + ModelTensors tensors = ConnectorUtils.processOutput(connector, scriptService, ImmutableMap.of()); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); @@ -224,7 +223,7 @@ public void processOutput_PostprocessFunction() throws IOException { .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of()); + ModelTensors tensors = ConnectorUtils.processOutput(connector, scriptService, ImmutableMap.of()); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java index bea44ebf48..b9faeafafb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.engine.utils; import static org.junit.Assert.assertEquals; @@ -40,8 +45,7 @@ public void test_executePreprocessFunction() { @Test public void test_executeBuildInPostProcessFunction() { List> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 4.0f)); - List modelTensors = ScriptUtils - .executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING)); + List modelTensors = MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING).apply(input); assertNotNull(modelTensors); assertEquals(2, modelTensors.size()); }