From 784bcb8dff024f5e9565b1b3caa3fa6c791e1522 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 13 Sep 2024 22:02:01 -0700 Subject: [PATCH] test Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/common/MLTaskType.java | 4 +- .../ml/common/agent/MLToolSpec.java | 1 + ...BedrockBatchJobArnPostProcessFunction.java | 13 +--- ...etBatchInferenceJobPreProcessFunction.java | 54 ++++++++++++++ .../task/MLCancelBatchJobAction.java | 2 +- .../org/opensearch/ml/common/MLTaskTests.java | 7 ++ .../algorithms/agent/MLChatAgentRunner.java | 2 +- .../action/tasks/GetTaskTransportAction.java | 6 +- .../ml/rest/RestMLCancelBatchJobAction.java | 12 ++-- .../ml/settings/MLCommonsSettings.java | 70 +++++++++---------- .../rest/RestMLBatchIngestionActionTests.java | 1 + 11 files changed, 112 insertions(+), 60 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java index 179bf152cd..aafff5b50e 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java @@ -8,7 +8,6 @@ public enum MLTaskType { TRAINING, PREDICTION, - BATCH_PREDICTION, TRAINING_AND_PREDICTION, EXECUTION, @Deprecated @@ -17,5 +16,6 @@ public enum MLTaskType { LOAD_MODEL, REGISTER_MODEL, DEPLOY_MODEL, - BATCH_INGEST + BATCH_INGEST, + BATCH_PREDICTION } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 98f7e1f33c..d92d1f69bb 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -28,6 +28,7 @@ public class MLToolSpec implements ToXContentObject { public static final String TOOL_NAME_FIELD = "name"; public static final String DESCRIPTION_FIELD = "description"; public static final String PARAMETERS_FIELD = "parameters"; + // public static final String CONFIGS_FIELD = "configs"; public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; private String type; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java index 7976b1a6ef..a858afdd28 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.connector.functions.postprocess; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.opensearch.ml.common.output.model.ModelTensor; + public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction> { public static final String JOB_ARN = "jobArn"; public static final String PROCESSED_JOB_ARN = "processedJobArn"; @@ -34,14 +34,7 @@ public List process(Map jobInfo) { processedResult.putAll(jobInfo); String jobArn = jobInfo.get(JOB_ARN); processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F")); - modelTensors - .add( - ModelTensor - .builder() - .name("response") - .dataAsMap(processedResult) - .build() - ); + modelTensors.add(ModelTensor.builder().name("response").dataAsMap(processedResult).build()); return modelTensors; } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java new file mode 100644 index 0000000000..3c2a0a2c93 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java @@ -0,0 +1,54 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +/** + * This class provides a pre-processing function for bedrock batch inference job input. + * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. + * The input data is expected to be of type {@link RemoteInferenceInputDataSet}, which must have jobArn parameter. + * The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. + */ +public class BedrockGetBatchInferenceJobPreProcessFunction extends ConnectorPreProcessFunction { + + public static final String JOB_ARN = "jobArn"; + public static final String PROCESSED_JOB_ARN = "processedJobArn"; + + public BedrockGetBatchInferenceJobPreProcessFunction() { + this.returnDirectlyForRemoteInferenceInput = false; + } + + @Override + public void validate(MLInput mlInput) { + if (!(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) { + throw new IllegalArgumentException("Wrong input dataset type"); + } + RemoteInferenceInputDataSet inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + if (inputData == null) { + throw new IllegalArgumentException("No input dataset provided"); + } + String jobArn = inputData.getParameters().get(JOB_ARN); + if (jobArn == null) { + throw new IllegalArgumentException("No jobArn provided"); + } + } + + /** + * @param mlInput The input data to be processed. + * This method is to escape slash in jobArn. + */ + @Override + public RemoteInferenceInputDataSet process(MLInput mlInput) { + RemoteInferenceInputDataSet inputData = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + String jobArn = inputData.getParameters().get(JOB_ARN); + inputData.getParameters().put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F")); + return inputData; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java index 6ea26c9eb3..5c75e4c8d2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java @@ -9,7 +9,7 @@ public class MLCancelBatchJobAction extends ActionType { public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction(); - public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job"; + public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel"; private MLCancelBatchJobAction() { super(NAME, MLCancelBatchJobResponse::new); diff --git a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java index 2ffdc32679..2fa75db471 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_MAPPING; + import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -50,6 +52,11 @@ public void testWriteTo() throws IOException { Assert.assertEquals(mlTask, task2); } + @Test + public void testIndexMapping() { + System.out.println(ML_TASK_INDEX_MAPPING); + } + @Test public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 4b14f1af17..e52435db93 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -454,7 +454,7 @@ private static void runTool( Map tmpParameters, ActionListener nextStepListener, String action, - String actionInput, + String actionInput,// {"match_all"}, model: "how many errors last week" Map toolParams ) { if (tools.get(action).validate(toolParams)) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index ce0fdd0eef..5c136f9853 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -90,7 +90,7 @@ public class GetTaskTransportAction extends HandledTransportAction remoteJobStatusFields; volatile String remoteJobCompletedStatusRegex; volatile String remoteJobCancelledStatusRegex; @@ -296,7 +296,7 @@ private void processTaskResponse( remoteJob.putAll(remoteJobStatus); Map updatedTask = new HashMap<>(); updatedTask.put(REMOTE_JOB_FIELD, remoteJob); - + for (String statusField : remoteJobStatusFields) { String statusValue = String.valueOf(remoteJob.get(statusField)); if (remoteJob.containsKey(statusField)) { @@ -322,7 +322,7 @@ private void processTaskResponse( } } } - + mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> { actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); }, e -> { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java index 33c7314be2..8b2c024cd7 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java @@ -23,8 +23,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +//TODO: Rename class and support cancelling more tasks. Now only support cancelling remote job public class RestMLCancelBatchJobAction extends BaseRestHandler { - private static final String ML_CANCEL_BATCH_ACTION = "ml_cancel_batch_action"; + public static final String ML_CANCEL_TASK_ACTION = "ml_cancel_task_action"; /** * Constructor @@ -33,18 +34,13 @@ public RestMLCancelBatchJobAction() {} @Override public String getName() { - return ML_CANCEL_BATCH_ACTION; + return ML_CANCEL_TASK_ACTION; } @Override public List routes() { return ImmutableList - .of( - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel_batch", ML_BASE_URI, PARAMETER_TASK_ID) - ) - ); + .of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel", ML_BASE_URI, PARAMETER_TASK_ID))); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 5af8e3875b..70de7f481f 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -202,44 +202,44 @@ private MLCommonsSettings() {} .boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting> ML_COMMONS_REMOTE_JOB_STATUS_FIELD = Setting - .listSetting( - "plugins.ml_commons.remote_job.status_field", - ImmutableList - .of( - "status", // openai, bedrock, cohere - "TransformJobStatus" // sagemaker - ), - Function.identity(), - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .listSetting( + "plugins.ml_commons.remote_job.status_field", + ImmutableList + .of( + "status", // openai, bedrock, cohere + "TransformJobStatus" // sagemaker + ), + Function.identity(), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.completed", - "(complete|completed)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.completed", + "(complete|completed)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.cancelled", - "(stopped|cancelled)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.cancelled", + "(stopped|cancelled)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.cancelling", - "(stopping|cancelling)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.cancelling", + "(stopping|cancelling)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.expired", - "(expired|timeout)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.expired", + "(expired|timeout)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java index 98c7795dd9..77fe42ef41 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.rest.RestMLCancelBatchJobAction.ML_CANCEL_TASK_ACTION; import static org.opensearch.ml.utils.TestHelper.getBatchIngestionRestRequest; import java.io.IOException;