diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index eae2cb6524..5701e0fa3a 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -9,9 +9,17 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.ScriptType; +import org.opensearch.script.TemplateScript; +import java.util.Collections; +import java.util.Map; import java.util.function.Function; +import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; + @Log4j2 public abstract class ConnectorPreProcessFunction implements Function { @@ -39,4 +47,11 @@ public void validateTextDocsInput(MLInput mlInput) { throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); } } + + protected String executeScript(ScriptService scriptService, String painlessScript, Map params) { + Script script = new Script(ScriptType.INLINE, "painless", addDefaultMethod(painlessScript), Collections.emptyMap()); + TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); + return templateScript.execute(); + } + } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java index 6b66b6eeb4..fac2b5bc94 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java @@ -12,13 +12,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.script.Script; import org.opensearch.script.ScriptService; -import org.opensearch.script.ScriptType; -import org.opensearch.script.TemplateScript; import java.io.IOException; -import java.util.Collections; import java.util.Map; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -64,9 +60,4 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) { } } - private String executeScript(ScriptService scriptService, String painlessScript, Map params) { - Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); - TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); - return templateScript.execute(); - } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java index a8c549ea3b..882c1409f6 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java @@ -10,29 +10,30 @@ import lombok.experimental.FieldDefaults; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.script.Script; import org.opensearch.script.ScriptService; -import org.opensearch.script.ScriptType; -import org.opensearch.script.TemplateScript; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction { + public static final String CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT = "pre_process_function.convert_remote_inference_param_to_object"; ScriptService scriptService; String preProcessFunction; + Map params; + @Builder - public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction) { + public RemoteInferencePreProcessFunction(ScriptService scriptService, String preProcessFunction, Map params) { this.returnDirectlyForRemoteInferenceInput = false; this.scriptService = scriptService; this.preProcessFunction = preProcessFunction; + this.params = params; } @Override @@ -45,7 +46,19 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { Map inputParams = new HashMap<>(); - inputParams.putAll(((RemoteInferenceInputDataSet)mlInput.getInputDataset()).getParameters()); + Map parameters = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters(); + if (params.containsKey(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT) && + Boolean.parseBoolean(params.get(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT))) { + for (String key : parameters.keySet()) { + if (isJson(parameters.get(key))) { + inputParams.put(key, gson.fromJson(parameters.get(key), Object.class)); + } else { + inputParams.put(key, parameters.get(key)); + } + } + } else { + inputParams.putAll(parameters); + } String processedInput = executeScript(scriptService, preProcessFunction, inputParams); if (processedInput == null) { throw new IllegalArgumentException("Preprocess function output is null"); @@ -54,9 +67,4 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) { return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); } - String executeScript(ScriptService scriptService, String painlessScript, Map params) { - Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); - TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); - return templateScript.execute(); - } } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index fbad16003a..cd58292672 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -23,10 +23,23 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; @Log4j2 public class StringUtils { + public static final String DEFAULT_ESCAPE_FUNCTION = "\n String escape(def input) { \n" + + " if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n" + + " if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n" + + " if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n" + + " if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n" + + " if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n" + + " if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n" + + " if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n" + + " return input;" + + "\n }\n"; + public static final Gson gson; static { @@ -154,4 +167,25 @@ public static String processTextDoc(String doc) { return null; } } + + public static String addDefaultMethod(String functionScript) { + if (!containsEscapeMethod(functionScript) && isEscapeUsed(functionScript)) { + return DEFAULT_ESCAPE_FUNCTION + functionScript; + } + return functionScript; + } + + public static boolean patternExist(String input, String patternString) { + Pattern pattern = Pattern.compile(patternString); + Matcher matcher = pattern.matcher(input); + return matcher.find(); + } + + public static boolean isEscapeUsed(String input) { + return patternExist(input,"(? predictParameter; @Before public void setUp() { MockitoAnnotations.openMocks(this); preProcessFunction = ""; - function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction); + predictParameter = new HashMap<>(); + function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction, predictParameter); remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index 3154672cf1..f238d12c91 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -123,4 +123,34 @@ public void processTextDocs() { Assert.assertNull(processedDocs.get(1)); Assert.assertEquals("[1.01,\\\"abc\\\"]", processedDocs.get(2)); } + + @Test + public void isEscapeUsed() { + Assert.assertFalse(StringUtils.isEscapeUsed("String escape")); + Assert.assertTrue(StringUtils.isEscapeUsed(" escape(\"abc\n123\")")); + } + + @Test + public void containsEscapeMethod() { + Assert.assertFalse(StringUtils.containsEscapeMethod("String escape")); + Assert.assertFalse(StringUtils.containsEscapeMethod("String escape()")); + Assert.assertFalse(StringUtils.containsEscapeMethod(" escape(\"abc\n123\")")); + Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(def abc)")); + Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(String input)")); + } + + @Test + public void addDefaultMethod_NoEscape() { + String input = "return 123;"; + String result = StringUtils.addDefaultMethod(input); + Assert.assertEquals(input, result); + } + + @Test + public void addDefaultMethod_Escape() { + String input = "return escape(\"abc\n123\");"; + String result = StringUtils.addDefaultMethod(input); + Assert.assertNotEquals(input, result); + Assert.assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION)); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 893f923fbd..49e6ef7d69 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -96,7 +96,14 @@ private static RemoteInferenceInputDataSet processMLInput( } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT) && Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) { - RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction(scriptService, preProcessFunction); + Map params = new HashMap<>(); + params.putAll(connector.getParameters()); + params.putAll(parameters); + RemoteInferencePreProcessFunction function = new RemoteInferencePreProcessFunction( + scriptService, + preProcessFunction, + params + ); return function.apply(mlInput); } else { return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java index 46d7794c6c..43a46d06ac 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java @@ -5,6 +5,8 @@ package org.opensearch.ml.engine.utils; +import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; + import java.util.Collections; import java.util.List; import java.util.Map; @@ -31,7 +33,7 @@ public static Optional executePreprocessFunction( public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { Map result = StringUtils.fromJson(resultJson, "result"); if (postProcessFunction != null) { - return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); + return Optional.ofNullable(executeScript(scriptService, addDefaultMethod(postProcessFunction), result)); } return Optional.empty(); }