Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

escape input data #1970

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("texts", processTextDocs(inputData)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("texts", inputData.getDocs()));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

import static org.opensearch.ml.common.utils.StringUtils.gson;

@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {

Expand All @@ -38,21 +34,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
 * JSON-escapes every document of the given text-docs dataset.
 *
 * <p>Each non-null document is serialized with {@code gson.toJson} and the first and last
 * characters of the serialized form (the quotes the serializer wraps around a string) are
 * dropped, so only the escaped content remains. {@code null} documents are kept as
 * {@code null} at the same position.
 *
 * @param inputDataSet dataset whose docs are read via {@code getDocs()}
 * @return a new list with one entry per input document, in the original order
 */
List<String> processTextDocs(TextDocsInputDataSet inputDataSet) {
    final List<String> escaped = new ArrayList<>();
    for (final String text : inputDataSet.getDocs()) {
        if (text == null) {
            // Preserve null placeholders so indexes still line up with the input.
            escaped.add(null);
            continue;
        }
        // Since 2.9 callers put their own quotes around the value, so the pair of quotes
        // that gson adds around a serialized string has to be stripped again.
        final String quoted = gson.toJson(text);
        escaped.add(quoted.substring(1, quoted.length() - 1));
    }
    return escaped;
}

public void validateTextDocsInput(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("input", processTextDocs(inputData)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("input", inputData.getDocs()));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -120,4 +121,23 @@ public static Map<String, String> convertScriptStringToJsonString(Map<String, Ob
}
return parameterStringMap;
}

/**
 * Applies {@link #processTextDoc(String)} to every element of the given list.
 *
 * @param inputDocs documents to escape; {@code null} elements are passed through as
 *                  {@code null}
 * @return a new list holding the processed documents in the same order as the input
 */
public static List<String> processTextDocs(List<String> inputDocs) {
    final List<String> escapedDocs = new ArrayList<>();
    inputDocs.forEach(text -> escapedDocs.add(processTextDoc(text)));
    return escapedDocs;
}

/**
 * JSON-escapes a single document string.
 *
 * <p>The value is serialized with {@code gson.toJson} and the first and last characters of
 * the result are removed, leaving the escaped content without the surrounding quotes.
 *
 * @param doc the raw text, may be {@code null}
 * @return the escaped text, or {@code null} when {@code doc} is {@code null}
 */
public static String processTextDoc(String doc) {
    if (doc == null) {
        return null;
    }
    // Since 2.9 callers put their own quotes around the value, so the pair of quotes that
    // gson adds around a serialized string has to be stripped again.
    final String quoted = gson.toJson(doc);
    return quoted.substring(1, quoted.length() - 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.opensearch.script.ScriptService;

import java.util.Arrays;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.junit.Assert;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -99,4 +100,13 @@ public void getParameterMap() {
Assert.assertEquals("[10,20]", parameterMap.get("key4"));
Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5"));
}

// Checks that newlines and double quotes are escaped, that null entries stay null, and that
// a JSON-looking document has its inner quotes escaped as well.
@Test
public void processTextDocs() {
    List<String> rawDocs = Arrays.asList("abc \n\n123\"4", null, "[1.01,\"abc\"]");
    String[] expectedDocs = { "abc \\n\\n123\\\"4", null, "[1.01,\\\"abc\\\"]" };

    List<String> actualDocs = StringUtils.processTextDocs(rawDocs);

    Assert.assertEquals(3, actualDocs.size());
    for (int i = 0; i < expectedDocs.length; i++) {
        Assert.assertEquals(expectedDocs[i], actualDocs.get(i));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING;
import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
import static org.opensearch.ml.common.utils.StringUtils.processTextDocs;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;

import java.io.IOException;
Expand All @@ -29,6 +31,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
Expand Down Expand Up @@ -68,20 +71,7 @@
throw new IllegalArgumentException("no predict action found");
}
RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService);
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().forEach((key, value) -> {
if (value == null) {
newParameters.put(key, null);
} else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
// no need to escape if it's already valid json
newParameters.put(key, value);
} else {
newParameters.put(key, escapeJson(value));
}
});
inputData.setParameters(newParameters);
}
escapeRemoteInferenceInputData(inputData);
return inputData;
}

Expand Down Expand Up @@ -112,6 +102,7 @@
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
}
} else {
MLInput newInput = escapeMLInput(mlInput);

Check warning on line 105 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L105

Added line #L105 was not covered by tests
boolean convertInputToJsonString = parameters.containsKey(CONVERT_INPUT_TO_JSON_STRING)
&& Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING));
DefaultPreProcessFunction function = DefaultPreProcessFunction
Expand All @@ -120,11 +111,51 @@
.preProcessFunction(preProcessFunction)
.convertInputToJsonString(convertInputToJsonString)
.build();
return function.apply(mlInput);
return function.apply(newInput);

Check warning on line 114 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L114

Added line #L114 was not covered by tests
}
}
}

/**
 * Returns an {@code MLInput} whose text fields have been passed through the
 * {@code StringUtils.processTextDoc(s)} helpers (statically imported above).
 *
 * <ul>
 *   <li>{@code TextDocsInputDataSet}: every doc is processed.</li>
 *   <li>{@code TextSimilarityInputDataSet}: the query text and every text doc are processed.</li>
 *   <li>Any other dataset type: the input is returned unchanged.</li>
 * </ul>
 *
 * For the two handled types the result is produced through {@code toBuilder()}, first on the
 * dataset and then on {@code mlInput}, with the processed values substituted.
 *
 * @param mlInput the model input to process
 * @return the rebuilt input for the two handled dataset types, otherwise {@code mlInput} itself
 */
private static MLInput escapeMLInput(MLInput mlInput) {
// Case 1: plain list of text documents.
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
List<String> newDocs = processTextDocs(docs);
TextDocsInputDataSet newInputData = ((TextDocsInputDataSet) mlInput.getInputDataset()).toBuilder().docs(newDocs).build();
return mlInput.toBuilder().inputDataset(newInputData).build();

Check warning on line 124 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L121-L124

Added lines #L121 - L124 were not covered by tests
}

// Case 2: text-similarity input, which carries a query string plus a list of documents.
if (mlInput.getInputDataset() instanceof TextSimilarityInputDataSet) {
String query = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getQueryText();
String newQuery = processTextDoc(query);
List<String> docs = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getTextDocs();
List<String> newDocs = processTextDocs(docs);
TextSimilarityInputDataSet newInputData = ((TextSimilarityInputDataSet) mlInput.getInputDataset())
.toBuilder()
.queryText(newQuery)
.textDocs(newDocs)
.build();
return mlInput.toBuilder().inputDataset(newInputData).build();

Check warning on line 137 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L128-L137

Added lines #L128 - L137 were not covered by tests
}
// Fallback: dataset types without handling above are passed through untouched.
return mlInput;

Check warning on line 139 in ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

View check run for this annotation

Codecov / codecov/patch

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java#L139

Added line #L139 was not covered by tests
}

/**
 * Replaces the parameter map of the given dataset with a copy whose values are JSON-safe.
 *
 * <p>For every entry: a {@code null} value is kept as {@code null}; a value that
 * {@code StringUtils.isJson} accepts is kept unchanged; any other value is passed through
 * {@code escapeJson}. When the dataset has no parameter map, nothing is changed.
 *
 * @param inputData the dataset whose parameters are rewritten in place via
 *                  {@code setParameters}
 */
public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) {
    final Map<String, String> currentParameters = inputData.getParameters();
    if (currentParameters == null) {
        return;
    }

    final Map<String, String> escapedParameters = new HashMap<>();
    for (final Map.Entry<String, String> entry : currentParameters.entrySet()) {
        final String rawValue = entry.getValue();
        // Null values and values that are already valid JSON must not be escaped again.
        final boolean keepAsIs = rawValue == null || org.opensearch.ml.common.utils.StringUtils.isJson(rawValue);
        escapedParameters.put(entry.getKey(), keepAsIs ? rawValue : escapeJson(rawValue));
    }
    inputData.setParameters(escapedParameters);
}

private static String getPreprocessFunction(MLInput mlInput, Connector connector) {
Optional<ConnectorAction> predictAction = connector.findPredictAction();
String preProcessFunction = predictAction.get().getPreProcessFunction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;

import java.util.ArrayList;
Expand Down Expand Up @@ -108,6 +109,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List<ModelTenso
MLInputDataset inputDataset = mlInput.getInputDataset();
Map<String, String> inputParameters = new HashMap<>();
if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet) inputDataset).getParameters() != null) {
escapeRemoteInferenceInputData((RemoteInferenceInputDataSet) inputDataset);
inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters());
}
parameters.putAll(inputParameters);
Expand Down
Loading