Skip to content

Commit

Permalink
support local model
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed May 27, 2024
1 parent 1d7b16d commit fd54211
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,10 @@ public void loadExtensions(ExtensionLoader loader) {
public Map<String, org.opensearch.ingest.Processor.Factory> getProcessors(org.opensearch.ingest.Processor.Parameters parameters) {
Map<String, org.opensearch.ingest.Processor.Factory> processors = new HashMap<>();
processors
.put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client));
.put(
MLInferenceIngestProcessor.TYPE,
new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry)
);
return Collections.unmodifiableMap(processors);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -19,11 +21,13 @@
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.ValueSource;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
Expand All @@ -46,19 +50,24 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
public static final String DOT_SYMBOL = ".";
private final InferenceProcessorAttributes inferenceProcessorAttributes;
private final boolean ignoreMissing;
private final String functionName;
private final boolean fullResponsePath;
private final boolean ignoreFailure;
private final String modelInput;
private final ScriptService scriptService;
private static Client client;
public static final String TYPE = "ml_inference";
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
// allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the
// prediction outcomes, return the whole prediction outcome by skipping filtering
public static final String IGNORE_MISSING = "ignore_missing";
public static final String FUNCTION_NAME = "function_name";
public static final String FULL_RESPONSE_PATH = "full_response_path";
public static final String MODEL_INPUT = "model_input";
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
private final NamedXContentRegistry xContentRegistry;

private Configuration suppressExceptionConfiguration = Configuration
.builder()
Expand All @@ -74,10 +83,13 @@ protected MLInferenceIngestProcessor(
String tag,
String description,
boolean ignoreMissing,
String functionName,
boolean fullResponsePath,
boolean ignoreFailure,
String modelInput,
ScriptService scriptService,
Client client
Client client,
NamedXContentRegistry xContentRegistry
) {
super(tag, description);
this.inferenceProcessorAttributes = new InferenceProcessorAttributes(
Expand All @@ -88,10 +100,13 @@ protected MLInferenceIngestProcessor(
maxPredictionTask
);
this.ignoreMissing = ignoreMissing;
this.functionName = functionName;
this.fullResponsePath = fullResponsePath;
this.ignoreFailure = ignoreFailure;
this.modelInput = modelInput;
this.scriptService = scriptService;
this.client = client;
this.xContentRegistry = xContentRegistry;
}

/**
Expand Down Expand Up @@ -167,10 +182,13 @@ private void processPredictions(
List<Map<String, String>> processOutputMap,
int inputMapIndex,
int inputMapSize
) {
) throws IOException {
Map<String, String> modelParameters = new HashMap<>();
Map<String, String> modelConfigs = new HashMap<>();

if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
}
// when no input mapping is provided, default to read all fields from documents as model input
if (inputMapSize == 0) {
Expand All @@ -189,7 +207,22 @@ private void processPredictions(
}
}

ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId());
Set<String> inputMapKeys = new HashSet<>(modelParameters.keySet());
inputMapKeys.removeAll(modelConfigs.keySet());

Map<String, String> inputMappings = new HashMap<>();
for (String k : inputMapKeys) {
inputMappings.put(k, modelParameters.get(k));
}
ActionRequest request = getRemoteModelInferenceRequest(
xContentRegistry,
modelParameters,
modelConfigs,
inputMappings,
inferenceProcessorAttributes.getModelId(),
functionName,
modelInput
);

client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() {

Expand Down Expand Up @@ -429,16 +462,18 @@ public static class Factory implements Processor.Factory {

private final ScriptService scriptService;
private final Client client;
private final NamedXContentRegistry xContentRegistry;

/**
 * Constructs a new instance of the Factory class.
 *
 * @param scriptService    the ScriptService instance used to evaluate templated document fields
 * @param client           the Client instance used to execute ML prediction requests
 * @param xContentRegistry the NamedXContentRegistry used to parse the model-input payload into MLInput
 */
public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) {
    this.scriptService = scriptService;
    this.client = client;
    this.xContentRegistry = xContentRegistry;
}

/**
Expand All @@ -465,6 +500,10 @@ public MLInferenceIngestProcessor create(
int maxPredictionTask = ConfigurationUtils
.readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
String functionName = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
String modelInput = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, false);
boolean ignoreFailure = ConfigurationUtils
.readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
Expand Down Expand Up @@ -496,10 +535,13 @@ public MLInferenceIngestProcessor create(
processorTag,
description,
ignoreMissing,
functionName,
fullResponsePath,
ignoreFailure,
modelInput,
scriptService,
client
client,
xContentRegistry
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,27 @@

package org.opensearch.ml.processor;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
Expand Down Expand Up @@ -51,17 +57,57 @@ public interface ModelExecutor {
* @return an ActionRequest instance for remote model inference
* @throws IllegalArgumentException if the input parameters are null
*/
default <T> ActionRequest getRemoteModelInferenceRequest(Map<String, String> parameters, String modelId) {
default <T> ActionRequest getRemoteModelInferenceRequest(
NamedXContentRegistry xContentRegistry,
Map<String, String> parameters,
Map<String, String> modelConfigs,
Map<String, String> inputMappings,
String modelId,
String functionNameStr,
String modelInput
) throws IOException {
if (parameters == null) {
throw new IllegalArgumentException("wrong input. The model input cannot be empty.");
}
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
FunctionName functionName = FunctionName.REMOTE;
if (functionNameStr != null) {
functionName = FunctionName.from(functionNameStr);
}
// RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();

Map<String, Object> inputParams = new HashMap<>();
if (FunctionName.REMOTE == functionName) {
inputParams.put("parameters", StringUtils.toJson(parameters));
} else {
inputParams.putAll(parameters);
}

String payload = modelInput;
// payload = fillNullParameters(parameters, payload);
StringSubstitutor modelConfigSubstitutor = new StringSubstitutor(modelConfigs, "${model_config.", "}");
payload = modelConfigSubstitutor.replace(payload);
StringSubstitutor inputMapSubstitutor = new StringSubstitutor(inputMappings, "${input_map.", "}");
payload = inputMapSubstitutor.replace(payload);
StringSubstitutor parametersSubstitutor = new StringSubstitutor(inputParams, "${ml_inference.", "}");
payload = parametersSubstitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
}

MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
// String jsonStr;
// try {
// jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(inputParams));
// } catch (PrivilegedActionException e) {
// throw new IllegalArgumentException("wrong connector");
// }
XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, payload);

ActionRequest request = new MLPredictionTaskRequest(modelId, mlInput, null);
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLInput mlInput = MLInput.parse(parser, functionName.name());
// MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();

return request;
return new MLPredictionTaskRequest(modelId, mlInput);

}

Expand Down

0 comments on commit fd54211

Please sign in to comment.