Skip to content

Commit

Permalink
escape input data
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jan 31, 2024
1 parent 45e5199 commit 28bb941
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
/**
 * Casts the input dataset to {@code TextDocsInputDataSet}, places its first document under
 * {@code parameters.inputText}, converts that map with {@code convertScriptStringToJsonString}
 * and uses the result as the parameters of a new {@code RemoteInferenceInputDataSet}.
 *
 * NOTE(review): this is a diff hunk, so {@code processedResult} is shown twice: once reading from
 * {@code processTextDocs(inputData)} and once reading from {@code inputData.getDocs()} directly.
 *
 * @param mlInput input whose dataset is cast to {@code TextDocsInputDataSet}
 * @return the dataset built from the converted parameters
 */
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
/**
 * Casts the input dataset to {@code TextDocsInputDataSet}, places its document list under
 * {@code parameters.texts}, converts that map with {@code convertScriptStringToJsonString}
 * and uses the result as the parameters of a new {@code RemoteInferenceInputDataSet}.
 *
 * NOTE(review): this is a diff hunk, so {@code processedResult} is shown twice: once reading from
 * {@code processTextDocs(inputData)} and once reading from {@code inputData.getDocs()} directly.
 *
 * @param mlInput input whose dataset is cast to {@code TextDocsInputDataSet}
 * @return the dataset built from the converted parameters
 */
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("texts", processTextDocs(inputData)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("texts", inputData.getDocs()));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

import static org.opensearch.ml.common.utils.StringUtils.gson;

@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {

Expand All @@ -38,21 +34,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
 * Produces, for every document in {@code inputDataSet}, the gson-serialized text of that
 * document with the first and last character of the serialized form removed.
 * A {@code null} document is carried over as {@code null}; order is preserved.
 *
 * @param inputDataSet data set whose docs are read via {@code getDocs()}
 * @return a new list with one entry per input document
 */
List<String> processTextDocs(TextDocsInputDataSet inputDataSet) {
    List<String> escapedDocs = new ArrayList<>();
    for (String rawDoc : inputDataSet.getDocs()) {
        if (rawDoc == null) {
            escapedDocs.add(null);
            continue;
        }
        // In 2.9 the user already supplies the surrounding double quotes.
        // gson.toJson(String) wraps its output in another pair, so drop that outer pair.
        String quoted = gson.toJson(rawDoc);
        escapedDocs.add(quoted.substring(1, quoted.length() - 1));
    }
    return escapedDocs;
}

public void validateTextDocsInput(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
/**
 * Casts the input dataset to {@code TextDocsInputDataSet}, places its document list under
 * {@code parameters.input}, converts that map with {@code convertScriptStringToJsonString}
 * and uses the result as the parameters of a new {@code RemoteInferenceInputDataSet}.
 *
 * NOTE(review): this is a diff hunk, so {@code processedResult} is shown twice: once reading from
 * {@code processTextDocs(inputData)} and once reading from {@code inputData.getDocs()} directly.
 *
 * @param mlInput input whose dataset is cast to {@code TextDocsInputDataSet}
 * @return the dataset built from the converted parameters
 */
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("input", processTextDocs(inputData)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("input", inputData.getDocs()));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -120,4 +121,23 @@ public static Map<String, String> convertScriptStringToJsonString(Map<String, Ob
}
return parameterStringMap;
}

/**
 * Runs {@link #processTextDoc(String)} on each element of {@code inputDocs} and collects the
 * results, in iteration order, into a newly created list.
 *
 * @param inputDocs documents to process; iterated directly, so it must not be {@code null}
 * @return a new list with one processed entry per input document
 */
public static List<String> processTextDocs(List<String> inputDocs) {
List<String> docs = new ArrayList<>();

Check warning on line 126 in common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java#L126

Added line #L126 was not covered by tests
for (String doc : inputDocs) {
docs.add(processTextDoc(doc));
}
return docs;

Check warning on line 130 in common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java#L128-L130

Added lines #L128 - L130 were not covered by tests
}

/**
 * Serializes {@code doc} with {@code gson.toJson} and returns that text without its first and
 * last character; a {@code null} input yields {@code null}.
 *
 * @param doc text to process, may be {@code null}
 * @return the serialized text minus its outer characters, or {@code null} when {@code doc} is {@code null}
 */
public static String processTextDoc(String doc) {
if (doc != null) {
String gsonString = gson.toJson(doc);

Check warning on line 135 in common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java#L135

Added line #L135 was not covered by tests
// In 2.9 the user already adds a double quote before and after the string.
// gson.toJson(String) adds its own pair of surrounding double quotes, so remove that outer pair.
return gsonString.substring(1, gsonString.length() - 1);

Check warning on line 138 in common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java#L138

Added line #L138 was not covered by tests
} else {
return null;

Check warning on line 140 in common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java#L140

Added line #L140 was not covered by tests
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING;
import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
import static org.opensearch.ml.common.utils.StringUtils.processTextDocs;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;

import java.io.IOException;
Expand All @@ -29,6 +31,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
Expand Down Expand Up @@ -68,20 +71,7 @@ public static RemoteInferenceInputDataSet processInput(
throw new IllegalArgumentException("no predict action found");
}
RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService);
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().forEach((key, value) -> {
if (value == null) {
newParameters.put(key, null);
} else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
// no need to escape if it's already valid json
newParameters.put(key, value);
} else {
newParameters.put(key, escapeJson(value));
}
});
inputData.setParameters(newParameters);
}
escapeRemoteInferenceInputData(inputData);
return inputData;
}

Expand Down Expand Up @@ -112,6 +102,7 @@ private static RemoteInferenceInputDataSet processMLInput(
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
}
} else {
MLInput newInput = escapeMLInput(mlInput);

Check warning on line 105 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L105

Added line #L105 was not covered by tests
boolean convertInputToJsonString = parameters.containsKey(CONVERT_INPUT_TO_JSON_STRING)
&& Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING));
DefaultPreProcessFunction function = DefaultPreProcessFunction
Expand All @@ -120,11 +111,51 @@ private static RemoteInferenceInputDataSet processMLInput(
.preProcessFunction(preProcessFunction)
.convertInputToJsonString(convertInputToJsonString)
.build();
return function.apply(mlInput);
return function.apply(newInput);

Check warning on line 114 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L114

Added line #L114 was not covered by tests
}
}
}

/**
 * Returns an {@code MLInput} whose text content has been passed through
 * {@code processTextDocs} / {@code processTextDoc}.
 * <ul>
 *   <li>{@code TextDocsInputDataSet}: the docs list is replaced by its processed form.</li>
 *   <li>{@code TextSimilarityInputDataSet}: both the query text and the text docs are replaced.</li>
 *   <li>Any other dataset type: {@code mlInput} itself is returned unchanged.</li>
 * </ul>
 * In the first two cases the dataset and the input are rebuilt via {@code toBuilder()}.
 *
 * @param mlInput input whose dataset is inspected
 * @return the rebuilt input for the two text dataset types, otherwise the given {@code mlInput}
 */
private static MLInput escapeMLInput(MLInput mlInput) {
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
List<String> newDocs = processTextDocs(docs);
TextDocsInputDataSet newInputData = ((TextDocsInputDataSet) mlInput.getInputDataset()).toBuilder().docs(newDocs).build();
return mlInput.toBuilder().inputDataset(newInputData).build();

Check warning on line 124 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L121-L124

Added lines #L121 - L124 were not covered by tests
}

if (mlInput.getInputDataset() instanceof TextSimilarityInputDataSet) {
String query = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getQueryText();
String newQuery = processTextDoc(query);
List<String> docs = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getTextDocs();
List<String> newDocs = processTextDocs(docs);
TextSimilarityInputDataSet newInputData = ((TextSimilarityInputDataSet) mlInput.getInputDataset())
.toBuilder()
.queryText(newQuery)
.textDocs(newDocs)
.build();
return mlInput.toBuilder().inputDataset(newInputData).build();

Check warning on line 137 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L128-L137

Added lines #L128 - L137 were not covered by tests
}
// Neither text dataset type: nothing to process.
return mlInput;

Check warning on line 139 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L139

Added line #L139 was not covered by tests
}

/**
 * Replaces the parameter map of {@code inputData} with a new map in which every value that is
 * neither {@code null} nor already valid JSON (per {@code StringUtils.isJson}) has been passed
 * through {@code escapeJson}. Does nothing when {@code inputData} has no parameters.
 *
 * @param inputData data set whose parameters are rewritten in place via {@code setParameters}
 */
public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) {
    Map<String, String> currentParameters = inputData.getParameters();
    if (currentParameters == null) {
        return;
    }
    Map<String, String> escapedParameters = new HashMap<>();
    for (Map.Entry<String, String> entry : currentParameters.entrySet()) {
        String rawValue = entry.getValue();
        // null values and values that are already valid json are stored untouched
        boolean keepAsIs = rawValue == null || org.opensearch.ml.common.utils.StringUtils.isJson(rawValue);
        escapedParameters.put(entry.getKey(), keepAsIs ? rawValue : escapeJson(rawValue));
    }
    inputData.setParameters(escapedParameters);
}

private static String getPreprocessFunction(MLInput mlInput, Connector connector) {
Optional<ConnectorAction> predictAction = connector.findPredictAction();
String preProcessFunction = predictAction.get().getPreProcessFunction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;

import java.util.ArrayList;
Expand Down Expand Up @@ -108,6 +109,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List<ModelTenso
MLInputDataset inputDataset = mlInput.getInputDataset();
Map<String, String> inputParameters = new HashMap<>();
if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet) inputDataset).getParameters() != null) {
escapeRemoteInferenceInputData((RemoteInferenceInputDataSet) inputDataset);
inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters());
}
parameters.putAll(inputParameters);
Expand Down

0 comments on commit 28bb941

Please sign in to comment.