Skip to content

Commit

Permalink
fine tune
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jan 29, 2024
1 parent 152c5e2 commit 5a2fe20
Show file tree
Hide file tree
Showing 23 changed files with 674 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid JSON in payload");
throw new IllegalArgumentException("Invalid payload: " + payload);
}
return (T) payload;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

package org.opensearch.ml.common.connector;

import com.google.common.collect.ImmutableList;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -20,58 +20,41 @@ public class MLPostProcessFunction {
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String COHERE_RERANK = "connector.post_process.cohere.rerank";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();

private static final Map<String, Function<Object, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();

static {
EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction();
BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction();
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildModelTensorList());
}

public static Function<List<?>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
if (embeddings.get(0) instanceof Number) {
embeddings = ImmutableList.of(embeddings);
}
embeddings.forEach(embedding -> {
List<Number> eachEmbedding = (List<Number>) embedding;
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{eachEmbedding.size()})
.data(eachEmbedding.toArray(new Number[0]))
.build()
);
});
return modelTensors;
};
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
}

public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
public static Function<Object, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

public static boolean contains(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,48 @@

package org.opensearch.ml.common.connector;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MLPreProcessFunction {

private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("texts", inputs));
}

private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("inputText", inputs.get(0)));
}
public static final String PROCESS_REMOTE_INFERENCE_INPUT = "pre_process_function.process_remote_inference_input";
public static final String CONVERT_INPUT_TO_JSON_STRING = "pre_process_function.convert_input_to_json_string";

static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
CohereEmbeddingPreProcessFunction cohereEmbeddingPreProcessFunction = new CohereEmbeddingPreProcessFunction();
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_DEFAULT_INPUT, cohereRerankPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT, cohereRerankPreProcessFunction);
}

public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
public static Function<MLInput, RemoteInferenceInputDataSet> get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;

public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<Float>> {

@Override
public void validate(Object input) {
if (!(input instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List.");
}

List<?> outerList = (List<?>) input;

if (!outerList.isEmpty() && !(((List<?>)input).get(0) instanceof Number)) {
throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
}
}

@Override
public List<ModelTensor> process(List<Float> embedding) {
List<ModelTensor> modelTensors = new ArrayList<>();
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build());
return modelTensors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction<List<Map<String, Object>>> {

@Override
public void validate(Object input) {
if (!(input instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List.");
}
List<?> outerList = (List<?>) input;
if (!outerList.isEmpty()) {
if (!(outerList.get(0) instanceof Map)) {
throw new IllegalArgumentException("Post process function input is not a List of Map.");
}
Map innerMap = (Map) outerList.get(0);

if (innerMap.isEmpty() || !innerMap.containsKey("index") || !innerMap.containsKey("relevance_score")) {
throw new IllegalArgumentException("The rerank result should contain index and relevance_score.");
}
}
}

@Override
public List<ModelTensor> process(List<Map<String, Object>> rerankResults) {
List<ModelTensor> modelTensors = new ArrayList<>();

if (rerankResults.size() > 0) {
Double[] scores = new Double[rerankResults.size()];
for (int i = 0; i < rerankResults.size(); i++) {
Integer index = (Integer) rerankResults.get(i).get("index");
scores[index] = (Double) rerankResults.get(i).get("relevance_score");
}

for (int i = 0; i < scores.length; i++) {
modelTensors.add(ModelTensor.builder()
.name("similarity")
.shape(new long[]{1})
.data(new Number[]{scores[i]})
.dataType(MLResultDataType.FLOAT32)
.build());
}
}
return modelTensors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.List;
import java.util.function.Function;

public abstract class ConnectorPostProcessFunction<T> implements Function<Object, List<ModelTensor>> {

@Override
public List<ModelTensor> apply(Object input) {
if (input == null) {
throw new IllegalArgumentException("Can't run post process function as model output is null");
}
validate(input);
return process((T)input);
}

public abstract void validate(Object input);

public abstract List<ModelTensor> process(T input);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.List;

public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction<List<List<Float>>> {

@Override
public void validate(Object input) {
if (!(input instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List.");
}

List<?> outerList = (List<?>) input;

if (!outerList.isEmpty()) {
if (!(outerList.get(0) instanceof List)) {
throw new IllegalArgumentException("Post process function input is not a List of List.");
}
List<?> innerList = (List<?>) outerList.get(0);

if (innerList.isEmpty() || !(innerList.get(0) instanceof Number)) {
throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values.");
}
}
}

@Override
public List<ModelTensor> process(List<List<Float>> embeddings) {
List<ModelTensor> modelTensors = new ArrayList<>();
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
return modelTensors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;


public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public BedrockEmbeddingPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
}

@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Loading

0 comments on commit 5a2fe20

Please sign in to comment.