diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java index 179bf152cd..aafff5b50e 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java @@ -8,7 +8,6 @@ public enum MLTaskType { TRAINING, PREDICTION, - BATCH_PREDICTION, TRAINING_AND_PREDICTION, EXECUTION, @Deprecated @@ -17,5 +16,6 @@ public enum MLTaskType { LOAD_MODEL, REGISTER_MODEL, DEPLOY_MODEL, - BATCH_INGEST + BATCH_INGEST, + BATCH_PREDICTION } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 98f7e1f33c..d92d1f69bb 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -28,6 +28,7 @@ public class MLToolSpec implements ToXContentObject { public static final String TOOL_NAME_FIELD = "name"; public static final String DESCRIPTION_FIELD = "description"; public static final String PARAMETERS_FIELD = "parameters"; + // public static final String CONFIGS_FIELD = "configs"; public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; private String type; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 47fdf7c980..513bdb5616 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -39,7 +39,8 @@ public class MLPreProcessFunction { BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction(); CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction(); MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction(); - BedrockGetBatchInferenceJobPreProcessFunction bedrockGetBatchInferenceJobPreProcessFunction = new BedrockGetBatchInferenceJobPreProcessFunction(); + BedrockGetBatchInferenceJobPreProcessFunction bedrockGetBatchInferenceJobPreProcessFunction = + new BedrockGetBatchInferenceJobPreProcessFunction(); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction); PRE_PROCESS_FUNCTIONS.put(BEDROCK_GET_BATCH_INFERENCE_JOB_INPUT, bedrockGetBatchInferenceJobPreProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java index 7976b1a6ef..a858afdd28 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.connector.functions.postprocess; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.opensearch.ml.common.output.model.ModelTensor; + public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction> { public static final String JOB_ARN = "jobArn"; public static final String PROCESSED_JOB_ARN = "processedJobArn"; @@ -34,14 +34,7 @@ public List process(Map jobInfo) { processedResult.putAll(jobInfo); String jobArn = jobInfo.get(JOB_ARN); processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F")); - modelTensors - .add( - ModelTensor - .builder() - .name("response") - .dataAsMap(processedResult) - .build() - ); + modelTensors.add(ModelTensor.builder().name("response").dataAsMap(processedResult).build()); return modelTensors; } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java index 435ff221a0..3c2a0a2c93 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockGetBatchInferenceJobPreProcessFunction.java @@ -7,16 +7,9 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; - /** * This class provides a pre-processing function for bedrock batch inference job input. * It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java index 6ea26c9eb3..5c75e4c8d2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java @@ -9,7 +9,7 @@ public class MLCancelBatchJobAction extends ActionType { public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction(); - public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job"; + public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel"; private MLCancelBatchJobAction() { super(NAME, MLCancelBatchJobResponse::new); diff --git a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java index 2ffdc32679..2fa75db471 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_MAPPING; + import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -50,6 +52,11 @@ public void testWriteTo() throws IOException { Assert.assertEquals(mlTask, task2); } + @Test + public void testIndexMapping() { + System.out.println(ML_TASK_INDEX_MAPPING); + } + @Test public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 4b14f1af17..e52435db93 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -454,7 +454,7 @@ private static void runTool( Map tmpParameters, ActionListener nextStepListener, String action, - String actionInput, + String actionInput,// {"match_all"}, model: "how many errors last week" Map toolParams ) { if (tools.get(action).validate(toolParams)) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index ce0fdd0eef..5c136f9853 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -90,7 +90,7 @@ public class GetTaskTransportAction extends HandledTransportAction remoteJobStatusFields; volatile String remoteJobCompletedStatusRegex; volatile String remoteJobCancelledStatusRegex; @@ -296,7 +296,7 @@ private void processTaskResponse( remoteJob.putAll(remoteJobStatus); Map updatedTask = new HashMap<>(); updatedTask.put(REMOTE_JOB_FIELD, remoteJob); - + for (String statusField : remoteJobStatusFields) { String statusValue = String.valueOf(remoteJob.get(statusField)); if (remoteJob.containsKey(statusField)) { @@ -322,7 +322,7 @@ private void processTaskResponse( } } } - + mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> { actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); }, e -> { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java index 33c7314be2..8b2c024cd7 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java @@ -23,8 +23,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +//TODO: Rename class and support cancelling more tasks. Now only support cancelling remote job public class RestMLCancelBatchJobAction extends BaseRestHandler { - private static final String ML_CANCEL_BATCH_ACTION = "ml_cancel_batch_action"; + public static final String ML_CANCEL_TASK_ACTION = "ml_cancel_task_action"; /** * Constructor @@ -33,18 +34,13 @@ public RestMLCancelBatchJobAction() {} @Override public String getName() { - return ML_CANCEL_BATCH_ACTION; + return ML_CANCEL_TASK_ACTION; } @Override public List routes() { return ImmutableList - .of( - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel_batch", ML_BASE_URI, PARAMETER_TASK_ID) - ) - ); + .of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel", ML_BASE_URI, PARAMETER_TASK_ID))); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 5af8e3875b..70de7f481f 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -202,44 +202,44 @@ private MLCommonsSettings() {} .boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting> ML_COMMONS_REMOTE_JOB_STATUS_FIELD = Setting - .listSetting( - "plugins.ml_commons.remote_job.status_field", - ImmutableList - .of( - "status", // openai, bedrock, cohere - "TransformJobStatus" // sagemaker - ), - Function.identity(), - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .listSetting( + "plugins.ml_commons.remote_job.status_field", + ImmutableList + .of( + "status", // openai, bedrock, cohere + "TransformJobStatus" // sagemaker + ), + Function.identity(), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.completed", - "(complete|completed)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.completed", + "(complete|completed)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.cancelled", - "(stopped|cancelled)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.cancelled", + "(stopped|cancelled)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.cancelling", - "(stopping|cancelling)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.cancelling", + "(stopping|cancelling)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX = Setting - .simpleString( - "plugins.ml_commons.remote_job.status_regex.expired", - "(expired|timeout)", - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); + .simpleString( + "plugins.ml_commons.remote_job.status_regex.expired", + "(expired|timeout)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java index 98c7795dd9..77fe42ef41 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLBatchIngestionActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.rest.RestMLCancelBatchJobAction.ML_CANCEL_TASK_ACTION; import static org.opensearch.ml.utils.TestHelper.getBatchIngestionRestRequest; import java.io.IOException;