diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 298ea24f98..784fb4961a 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -1001,7 +1001,10 @@ public void loadExtensions(ExtensionLoader loader) { public Map getProcessors(org.opensearch.ingest.Processor.Parameters parameters) { Map processors = new HashMap<>(); processors - .put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client)); + .put( + MLInferenceIngestProcessor.TYPE, + new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry) + ); return Collections.unmodifiableMap(processors); } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index 9d4870c979..d7d58184c8 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -6,9 +6,11 @@ import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -19,11 +21,13 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; import org.opensearch.ingest.ValueSource; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.MLTaskResponse; @@ -46,8 +50,10 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod public static final String DOT_SYMBOL = "."; private final InferenceProcessorAttributes inferenceProcessorAttributes; private final boolean ignoreMissing; + private final String functionName; private final boolean fullResponsePath; private final boolean ignoreFailure; + private final String modelInput; private final ScriptService scriptService; private static Client client; public static final String TYPE = "ml_inference"; @@ -55,10 +61,13 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod // allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the // prediction outcomes, return the whole prediction outcome by skipping filtering public static final String IGNORE_MISSING = "ignore_missing"; + public static final String FUNCTION_NAME = "function_name"; public static final String FULL_RESPONSE_PATH = "full_response_path"; + public static final String MODEL_INPUT = "model_input"; // At default, ml inference processor allows maximum 10 prediction tasks running in parallel // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + private final NamedXContentRegistry xContentRegistry; private Configuration suppressExceptionConfiguration = Configuration .builder() @@ -74,10 +83,13 @@ protected MLInferenceIngestProcessor( String tag, String description, boolean ignoreMissing, + String functionName, boolean fullResponsePath, boolean ignoreFailure, + String modelInput, ScriptService scriptService, - Client client + Client client, + NamedXContentRegistry xContentRegistry ) { super(tag, description); this.inferenceProcessorAttributes = new InferenceProcessorAttributes( @@ -88,10 +100,13 @@ protected MLInferenceIngestProcessor( maxPredictionTask ); this.ignoreMissing = ignoreMissing; + this.functionName = functionName; this.fullResponsePath = fullResponsePath; this.ignoreFailure = ignoreFailure; + this.modelInput = modelInput; this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -167,10 +182,13 @@ private void processPredictions( List> processOutputMap, int inputMapIndex, int inputMapSize - ) { + ) throws IOException { Map modelParameters = new HashMap<>(); + Map modelConfigs = new HashMap<>(); + if (inferenceProcessorAttributes.getModelConfigMaps() != null) { modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); } // when no input mapping is provided, default to read all fields from documents as model input if (inputMapSize == 0) { @@ -189,7 +207,22 @@ private void processPredictions( } } - ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId()); + Set inputMapKeys = new HashSet<>(modelParameters.keySet()); + inputMapKeys.removeAll(modelConfigs.keySet()); + + Map inputMappings = new HashMap<>(); + for (String k : inputMapKeys) { + inputMappings.put(k, modelParameters.get(k)); + } + ActionRequest request = getRemoteModelInferenceRequest( + xContentRegistry, + modelParameters, + modelConfigs, + inputMappings, + inferenceProcessorAttributes.getModelId(), + functionName, + modelInput + ); client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { @@ -429,6 +462,7 @@ public static class Factory implements Processor.Factory { private final ScriptService scriptService; private final Client client; + private final NamedXContentRegistry xContentRegistry; /** * Constructs a new instance of the Factory class. @@ -436,9 +470,10 @@ public static class Factory implements Processor.Factory { * @param scriptService the ScriptService instance to be used by the Factory * @param client the Client instance to be used by the Factory */ - public Factory(ScriptService scriptService, Client client) { + public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) { this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -465,6 +500,10 @@ public MLInferenceIngestProcessor create( int maxPredictionTask = ConfigurationUtils .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); + String functionName = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name()); + String modelInput = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }"); boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, false); boolean ignoreFailure = ConfigurationUtils .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false); @@ -496,10 +535,13 @@ public MLInferenceIngestProcessor create( processorTag, description, ignoreMissing, + functionName, fullResponsePath, ignoreFailure, + modelInput, scriptService, - client + client, + xContentRegistry ); } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index 19fb367799..2c934e6db4 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -5,21 +5,27 @@ package org.opensearch.ml.processor; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelTensor; @@ -51,17 +57,57 @@ public interface ModelExecutor { * @return an ActionRequest instance for remote model inference * @throws IllegalArgumentException if the input parameters are null */ - default ActionRequest getRemoteModelInferenceRequest(Map parameters, String modelId) { + default ActionRequest getRemoteModelInferenceRequest( + NamedXContentRegistry xContentRegistry, + Map parameters, + Map modelConfigs, + Map inputMappings, + String modelId, + String functionNameStr, + String modelInput + ) throws IOException { if (parameters == null) { throw new IllegalArgumentException("wrong input. The model input cannot be empty."); } - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + FunctionName functionName = FunctionName.REMOTE; + if (functionNameStr != null) { + functionName = FunctionName.from(functionNameStr); + } + // RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + + Map inputParams = new HashMap<>(); + if (FunctionName.REMOTE == functionName) { + inputParams.put("parameters", StringUtils.toJson(parameters)); + } else { + inputParams.putAll(parameters); + } + + String payload = modelInput; + // payload = fillNullParameters(parameters, payload); + StringSubstitutor modelConfigSubstitutor = new StringSubstitutor(modelConfigs, "${model_config.", "}"); + payload = modelConfigSubstitutor.replace(payload); + StringSubstitutor inputMapSubstitutor = new StringSubstitutor(inputMappings, "${input_map.", "}"); + payload = inputMapSubstitutor.replace(payload); + StringSubstitutor parametersSubstitutor = new StringSubstitutor(inputParams, "${ml_inference.", "}"); + payload = parametersSubstitutor.replace(payload); + + if (!isJson(payload)) { + throw new IllegalArgumentException("Invalid payload: " + payload); + } - MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + // String jsonStr; + // try { + // jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(inputParams)); + // } catch (PrivilegedActionException e) { + // throw new IllegalArgumentException("wrong connector"); + // } + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, payload); - ActionRequest request = new MLPredictionTaskRequest(modelId, mlInput, null); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLInput mlInput = MLInput.parse(parser, functionName.name()); + // MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); - return request; + return new MLPredictionTaskRequest(modelId, mlInput); }