Skip to content

Commit

Permalink
add escape method for process function
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 8, 2024
1 parent 671457b commit d758fd5
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
import org.opensearch.script.ScriptType;
import org.opensearch.script.TemplateScript;

import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod;

@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {

Expand Down Expand Up @@ -39,4 +47,11 @@ public void validateTextDocsInput(MLInput mlInput) {
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
}
}

/**
 * Compiles a Painless source as an inline template script and runs it.
 * The source is first passed through {@code addDefaultMethod}.
 *
 * @param scriptService service used to compile the script
 * @param painlessScript Painless source to execute
 * @param params values handed to the compiled script factory when the instance is created
 * @return whatever the template script's {@code execute()} returns
 */
protected String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
    String source = addDefaultMethod(painlessScript);
    Script inlineScript = new Script(ScriptType.INLINE, "painless", source, Collections.emptyMap());
    return scriptService.compile(inlineScript, TemplateScript.CONTEXT).newInstance(params).execute();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
import org.opensearch.script.ScriptType;
import org.opensearch.script.TemplateScript;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
Expand Down Expand Up @@ -64,9 +60,4 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) {
}
}

/**
 * Compiles the given Painless source as an inline template script and runs it.
 *
 * @param scriptService service used to compile the script
 * @param painlessScript Painless source, used exactly as given
 * @param params values handed to the compiled script factory when the instance is created
 * @return whatever the template script's {@code execute()} returns
 */
private String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
    Script inlineScript = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap());
    return scriptService.compile(inlineScript, TemplateScript.CONTEXT).newInstance(params).execute();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,30 @@
import lombok.experimental.FieldDefaults;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
import org.opensearch.script.ScriptType;
import org.opensearch.script.TemplateScript;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction {

public static final String CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT = "pre_process_function.convert_remote_inference_param_to_object";
ScriptService scriptService;
String preProcessFunction;

Map<String, String> params;

@Builder
public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction) {
// Creates a pre-process function that keeps the script service, the script source and the
// caller-supplied options for later use.
//   scriptService      - stored as-is
//   preProcessFunction - script source, stored as-is
//   params             - option map; null is treated as "no options"
public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map<String, String> params) {
    this.returnDirectlyForRemoteInferenceInput = false;
    this.scriptService = scriptService;
    this.preProcessFunction = preProcessFunction;
    // A builder call that never sets params arrives here as null; fall back to an empty map so
    // later lookups such as params.containsKey(...) do not throw a NullPointerException.
    this.params = params == null ? Map.of() : params;
}

@Override
Expand All @@ -45,7 +46,19 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
Map<String, Object> inputParams = new HashMap<>();
inputParams.putAll(((RemoteInferenceInputDataSet)mlInput.getInputDataset()).getParameters());
Map<String, String> parameters = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters();
if (params.containsKey(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT) &&
Boolean.parseBoolean(params.get(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT))) {
for (String key : parameters.keySet()) {
if (isJson(parameters.get(key))) {
inputParams.put(key, gson.fromJson(parameters.get(key), Object.class));
} else {
inputParams.put(key, parameters.get(key));
}
}
} else {
inputParams.putAll(parameters);
}
String processedInput = executeScript(scriptService, preProcessFunction, inputParams);
if (processedInput == null) {
throw new IllegalArgumentException("Preprocess function output is null");
Expand All @@ -54,9 +67,4 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) {
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build();
}

/**
 * Executes a Painless source through the script service as an inline template script.
 *
 * @param scriptService service used to compile the script
 * @param painlessScript Painless source, used exactly as given
 * @param params values handed to the compiled script factory when the instance is created
 * @return whatever the template script's {@code execute()} returns
 */
String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
    Map<String, Object> noOptions = Collections.emptyMap();
    Script inline = new Script(ScriptType.INLINE, "painless", painlessScript, noOptions);
    TemplateScript.Factory compiled = scriptService.compile(inline, TemplateScript.CONTEXT);
    return compiled.newInstance(params).execute();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,23 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@Log4j2
public class StringUtils {

/**
 * Script source text declaring a helper {@code String escape(def input)}.
 * {@code addDefaultMethod} prepends this text to a script that calls {@code escape(...)} but
 * does not declare such a method itself.
 * The helper applies a fixed sequence of {@code contains}/{@code replace} steps to {@code input}
 * (backslash, double quote, carriage return, tab text, newline, backspace, form feed) and
 * returns the result.
 * NOTE(review): the carriage-return step contains a doubled assignment
 * ({@code input = input = ...}); it looks redundant, confirm and tidy.
 * NOTE(review): the tab step appears to search for the two-character text backslash + t rather
 * than a tab character, and it runs after backslashes were already doubled; confirm intended.
 */
public static final String DEFAULT_ESCAPE_FUNCTION = "\n String escape(def input) { \n" +
" if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n" +
" if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n" +
" if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n" +
" if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n" +
" if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n" +
" if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n" +
" if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n" +
" return input;" +
"\n }\n";

public static final Gson gson;

static {
Expand Down Expand Up @@ -154,4 +167,25 @@ public static String processTextDoc(String doc) {
return null;
}
}

/**
 * Prepends {@link #DEFAULT_ESCAPE_FUNCTION} to a script that calls {@code escape(...)} without
 * declaring its own {@code escape} method; any other script is returned unchanged.
 *
 * @param functionScript script source to inspect
 * @return the script, with the default escape helper in front when it is needed
 */
public static String addDefaultMethod(String functionScript) {
    // Nothing to add when the script already declares escape, or never calls it.
    if (containsEscapeMethod(functionScript) || !isEscapeUsed(functionScript)) {
        return functionScript;
    }
    return DEFAULT_ESCAPE_FUNCTION + functionScript;
}

/**
 * Tells whether a regular expression matches anywhere inside the given text.
 *
 * @param input text to search
 * @param patternString regular expression, compiled on every call
 * @return {@code true} when at least one match is found
 */
public static boolean patternExist(String input, String patternString) {
    return Pattern.compile(patternString).matcher(input).find();
}

// Compiled once instead of on every call. Matches the word "escape" followed by optional
// whitespace and "("; the negative lookbehind is intended to skip the name inside a
// "String escape(" declaration.
private static final Pattern ESCAPE_CALL_PATTERN = Pattern.compile("(?<!\\bString\\s+)\\bescape\\s*\\(");

/**
 * Tells whether the script text contains what looks like a call to {@code escape(...)}.
 *
 * @param input script source to inspect
 * @return {@code true} when the call pattern is found anywhere in {@code input}
 */
public static boolean isEscapeUsed(String input) {
    return ESCAPE_CALL_PATTERN.matcher(input).find();
}

// Compiled once instead of on every call. Matches a declaration of the form
// "String escape(def x)" or "String escape(String x)", optionally followed by "{".
private static final Pattern ESCAPE_DECLARATION_PATTERN =
    Pattern.compile("String\\s+escape\\s*\\(\\s*(def|String)\\s+.*?\\)\\s*\\{?");

/**
 * Tells whether the script text already declares its own {@code escape} method.
 *
 * @param input script source to inspect
 * @return {@code true} when a matching declaration is found anywhere in {@code input}
 */
public static boolean containsEscapeMethod(String input) {
    return ESCAPE_DECLARATION_PATTERN.matcher(input).find();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.script.ScriptService;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
Expand All @@ -39,12 +39,14 @@ public class RemoteInferencePreProcessFunctionTest {

RemoteInferenceInputDataSet remoteInferenceInputDataSet;
TextDocsInputDataSet textDocsInputDataSet;
Map<String, String> predictParameter;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
preProcessFunction = "";
function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction);
predictParameter = new HashMap<>();
function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction, predictParameter);
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,34 @@ public void processTextDocs() {
Assert.assertNull(processedDocs.get(1));
Assert.assertEquals("[1.01,\\\"abc\\\"]", processedDocs.get(2));
}

@Test
public void isEscapeUsed() {
    // The bare words "String escape" carry no opening parenthesis, so this is not a call.
    String declarationFragment = "String escape";
    Assert.assertFalse(StringUtils.isEscapeUsed(declarationFragment));

    // An invocation is detected even when its string argument spans a line break.
    String invocation = " escape(\"abc\n123\")";
    Assert.assertTrue(StringUtils.isEscapeUsed(invocation));
}

@Test
public void containsEscapeMethod() {
    // None of these is a declaration with a typed parameter.
    String[] notDeclarations = { "String escape", "String escape()", " escape(\"abc\n123\")" };
    for (String candidate : notDeclarations) {
        Assert.assertFalse(StringUtils.containsEscapeMethod(candidate));
    }

    // Declarations taking a def or a String parameter are both recognised.
    String[] declarations = { "String escape(def abc)", "String escape(String input)" };
    for (String candidate : declarations) {
        Assert.assertTrue(StringUtils.containsEscapeMethod(candidate));
    }
}

@Test
public void addDefaultMethod_NoEscape() {
    // A script that never calls escape(...) must come back untouched.
    String script = "return 123;";
    Assert.assertEquals(script, StringUtils.addDefaultMethod(script));
}

@Test
public void addDefaultMethod_Escape() {
    // The script calls escape(...) without declaring it, so the default helper is prepended.
    String script = "return escape(\"abc\n123\");";
    String augmented = StringUtils.addDefaultMethod(script);
    Assert.assertNotEquals(script, augmented);
    boolean helperFirst = augmented.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION);
    Assert.assertTrue(helperFirst);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,14 @@ private static RemoteInferenceInputDataSet processMLInput(
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT)
&& Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {
RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction);
Map<String, String> params = new HashMap<>();
params.putAll(connector.getParameters());
params.putAll(parameters);
RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(
scriptService,
preProcessFunction,
params
);
return function.apply(mlInput);
} else {
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.engine.utils;

import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod;

import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand All @@ -31,7 +33,7 @@ public static Optional<String> executePreprocessFunction(
public static Optional<String> executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) {
Map<String, Object> result = StringUtils.fromJson(resultJson, "result");
if (postProcessFunction != null) {
return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result));
return Optional.ofNullable(executeScript(scriptService, addDefaultMethod(postProcessFunction), result));
}
return Optional.empty();
}
Expand Down

This file was deleted.

0 comments on commit d758fd5

Please sign in to comment.