Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Sep 14, 2024
1 parent 8d63bf5 commit 2ee6147
Show file tree
Hide file tree
Showing 12 changed files with 60 additions and 68 deletions.
4 changes: 2 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLTaskType.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
public enum MLTaskType {
TRAINING,
PREDICTION,
BATCH_PREDICTION,
TRAINING_AND_PREDICTION,
EXECUTION,
@Deprecated
Expand All @@ -17,5 +16,6 @@ public enum MLTaskType {
LOAD_MODEL,
REGISTER_MODEL,
DEPLOY_MODEL,
BATCH_INGEST
BATCH_INGEST,
BATCH_PREDICTION
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class MLToolSpec implements ToXContentObject {
public static final String TOOL_NAME_FIELD = "name";
public static final String DESCRIPTION_FIELD = "description";
public static final String PARAMETERS_FIELD = "parameters";
// public static final String CONFIGS_FIELD = "configs";
public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response";

private String type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public class MLPreProcessFunction {
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
BedrockGetBatchInferenceJobPreProcessFunction bedrockGetBatchInferenceJobPreProcessFunction = new BedrockGetBatchInferenceJobPreProcessFunction();
BedrockGetBatchInferenceJobPreProcessFunction bedrockGetBatchInferenceJobPreProcessFunction =
new BedrockGetBatchInferenceJobPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(BEDROCK_GET_BATCH_INFERENCE_JOB_INPUT, bedrockGetBatchInferenceJobPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

package org.opensearch.ml.common.connector.functions.postprocess;

import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.ModelTensor;

public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction<Map<String, String>> {
public static final String JOB_ARN = "jobArn";
public static final String PROCESSED_JOB_ARN = "processedJobArn";
Expand All @@ -34,14 +34,7 @@ public List<ModelTensor> process(Map<String, String> jobInfo) {
processedResult.putAll(jobInfo);
String jobArn = jobInfo.get(JOB_ARN);
processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F"));
modelTensors
.add(
ModelTensor
.builder()
.name("response")
.dataAsMap(processedResult)
.build()
);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(processedResult).build());
return modelTensors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,9 @@

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

/**
* This class provides a pre-processing function for bedrock batch inference job input.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> {
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction();
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job";
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel";

private MLCancelBatchJobAction() {
super(NAME, MLCancelBatchJobResponse::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.common;

import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_MAPPING;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
Expand Down Expand Up @@ -50,6 +52,11 @@ public void testWriteTo() throws IOException {
Assert.assertEquals(mlTask, task2);
}

@Test
public void testIndexMapping() {
System.out.println(ML_TASK_INDEX_MAPPING);
}

@Test
public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ private static void runTool(
Map<String, String> tmpParameters,
ActionListener<Object> nextStepListener,
String action,
String actionInput,
String actionInput,// {"match_all"}, model: "how many errors last week"
Map<String, String> toolParams
) {
if (tools.get(action).validate(toolParams)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest

MLTaskManager mlTaskManager;
MLModelCacheHelper modelCacheHelper;

volatile List<String> remoteJobStatusFields;
volatile String remoteJobCompletedStatusRegex;
volatile String remoteJobCancelledStatusRegex;
Expand Down Expand Up @@ -296,7 +296,7 @@ private void processTaskResponse(
remoteJob.putAll(remoteJobStatus);
Map<String, Object> updatedTask = new HashMap<>();
updatedTask.put(REMOTE_JOB_FIELD, remoteJob);

for (String statusField : remoteJobStatusFields) {
String statusValue = String.valueOf(remoteJob.get(statusField));
if (remoteJob.containsKey(statusField)) {
Expand All @@ -322,7 +322,7 @@ private void processTaskResponse(
}
}
}

mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> {
actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
}, e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

//TODO: Rename class and support cancelling more tasks. Now only support cancelling remote job
public class RestMLCancelBatchJobAction extends BaseRestHandler {
private static final String ML_CANCEL_BATCH_ACTION = "ml_cancel_batch_action";
public static final String ML_CANCEL_TASK_ACTION = "ml_cancel_task_action";

/**
* Constructor
Expand All @@ -33,18 +34,13 @@ public RestMLCancelBatchJobAction() {}

@Override
public String getName() {
return ML_CANCEL_BATCH_ACTION;
return ML_CANCEL_TASK_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList
.of(
new Route(
RestRequest.Method.POST,
String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel_batch", ML_BASE_URI, PARAMETER_TASK_ID)
)
);
.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel", ML_BASE_URI, PARAMETER_TASK_ID)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,44 +202,44 @@ private MLCommonsSettings() {}
.boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<List<String>> ML_COMMONS_REMOTE_JOB_STATUS_FIELD = Setting
.listSetting(
"plugins.ml_commons.remote_job.status_field",
ImmutableList
.of(
"status", // openai, bedrock, cohere
"TransformJobStatus" // sagemaker
),
Function.identity(),
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
.listSetting(
"plugins.ml_commons.remote_job.status_field",
ImmutableList
.of(
"status", // openai, bedrock, cohere
"TransformJobStatus" // sagemaker
),
Function.identity(),
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

public static final Setting<String> ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX = Setting
.simpleString(
"plugins.ml_commons.remote_job.status_regex.completed",
"(complete|completed)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
.simpleString(
"plugins.ml_commons.remote_job.status_regex.completed",
"(complete|completed)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
public static final Setting<String> ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX = Setting
.simpleString(
"plugins.ml_commons.remote_job.status_regex.cancelled",
"(stopped|cancelled)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
.simpleString(
"plugins.ml_commons.remote_job.status_regex.cancelled",
"(stopped|cancelled)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
public static final Setting<String> ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX = Setting
.simpleString(
"plugins.ml_commons.remote_job.status_regex.cancelling",
"(stopping|cancelling)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
.simpleString(
"plugins.ml_commons.remote_job.status_regex.cancelling",
"(stopping|cancelling)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
public static final Setting<String> ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX = Setting
.simpleString(
"plugins.ml_commons.remote_job.status_regex.expired",
"(expired|timeout)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
.simpleString(
"plugins.ml_commons.remote_job.status_regex.expired",
"(expired|timeout)",
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.rest.RestMLCancelBatchJobAction.ML_CANCEL_TASK_ACTION;
import static org.opensearch.ml.utils.TestHelper.getBatchIngestionRestRequest;

import java.io.IOException;
Expand Down

0 comments on commit 2ee6147

Please sign in to comment.