Skip to content

Commit

Permalink
[Backport 2.x] support batch task management by periodically polling …
Browse files Browse the repository at this point in the history
…the remote task via a cron job (#3458)

* support batch task management by periodically polling the remote task via a cron job (#3421)

* support batch task management by periodically polling the remote task via a cron job

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* address comments and resolve dependencies to avoid conflicts

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add unit tests

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* renamed files and added more tests

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

---------

Signed-off-by: Bhavana Goud Ramaram <[email protected]>
(cherry picked from commit 161d789)

* fix failing BWC tests

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* fix missing path in failing BWC tests

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* fix failing BWC tests

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add missing braces

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add missing braces

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add missing braces

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add missing braces

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add missing braces

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add to yml file

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add to yml file

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* add to yml file

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

* refactored code

Signed-off-by: Bhavana Goud Ramaram <[email protected]>

---------

Signed-off-by: Bhavana Goud Ramaram <[email protected]>
Co-authored-by: Bhavana Goud Ramaram <[email protected]>
(cherry picked from commit f083b7e)
  • Loading branch information
opensearch-trigger-bot[bot] authored and github-actions[bot] committed Jan 30, 2025
1 parent 7552412 commit 600990c
Show file tree
Hide file tree
Showing 27 changed files with 1,544 additions and 101 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test_bwc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ jobs:
echo plugin_version $plugin_version
./gradlew assemble
echo "Creating ./plugin/src/test/resources/org/opensearch/ml/bwc..."
mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc
mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc/job-scheduler
mkdir -p ./plugin/src/test/resources/org/opensearch/ml/bwc/ml
- name: Run MLCommons Backwards Compatibility Tests
run: |
echo "Running backwards compatibility tests ..."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class CommonValue {
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ public enum MLTaskState {
CANCELLED,
COMPLETED_WITH_ERROR,
CANCELLING,
EXPIRED
EXPIRED,
UNREACHABLE
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,28 @@
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG = CommonValue.VERSION_2_19_0;
@Setter
private Map<String, String> parameters;
@Setter
private ActionType actionType;
@Setter
private Map<String, String> dlq;

@Builder(toBuilder = true)
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType, Map<String, String> dlq) {
super(MLInputDataType.REMOTE);
this.parameters = parameters;
this.actionType = actionType;
this.dlq = dlq;
}

/**
 * Creates a remote inference input data set without a {@code dlq} map.
 * Delegates to the three-argument constructor, passing {@code null} for {@code dlq}.
 *
 * @param parameters parameters stored on this data set
 * @param actionType action type stored on this data set
 */
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
this(parameters, actionType, null);
}

public RemoteInferenceInputDataSet(Map<String, String> parameters) {
this(parameters, null);
this(parameters, null, null);
}

public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
Expand All @@ -55,6 +63,13 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
this.actionType = null;
}
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
if (streamInput.readBoolean()) {
dlq = streamInput.readMap(s -> s.readString(), s -> s.readString());
} else {
this.dlq = null;
}
}
}

@Override
Expand All @@ -75,6 +90,14 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeBoolean(false);
}
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
if (dlq != null) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(dlq, StreamOutput::writeString, StreamOutput::writeString);
} else {
streamOutput.writeBoolean(false);
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
public class RemoteInferenceMLInput extends MLInput {
public static final String PARAMETERS_FIELD = "parameters";
public static final String ACTION_TYPE_FIELD = "action_type";
public static final String DLQ_FIELD = "dlq";

public RemoteInferenceMLInput(StreamInput in) throws IOException {
super(in);
Expand All @@ -37,6 +38,7 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
super();
this.algorithm = functionName;
Map<String, String> parameters = null;
Map<String, String> dlq = null;
ActionType actionType = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -50,12 +52,15 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
case ACTION_TYPE_FIELD:
actionType = ActionType.from(parser.text());
break;
case DLQ_FIELD:
dlq = StringUtils.getParameterMap(parser.map());
break;
default:
parser.skipChildren();
break;
}
}
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType, dlq);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,34 @@
public class MLTaskGetRequest extends ActionRequest {
@Getter
String taskId;

@Getter
String tenantId;

// This is to identify if the get request is initiated by user or not. During batch task polling job,
// we also perform get operation. This field is to distinguish between
// these two situations.
@Getter
boolean isUserInitiatedGetTaskRequest;

/**
 * Creates a get-task request that is flagged as user initiated.
 * Delegates to the three-argument constructor with the flag set to {@code true}.
 *
 * @param taskId   id of the task to fetch
 * @param tenantId tenant id stored on the request
 */
@Builder
public MLTaskGetRequest(String taskId, String tenantId) {
    this(taskId, tenantId, Boolean.TRUE);
}

/**
 * Creates a get-task request.
 *
 * @param taskId                        id of the task to fetch
 * @param tenantId                      tenant id stored on the request
 * @param isUserInitiatedGetTaskRequest whether the request was initiated by a user, as opposed to
 *                                      the batch task polling job; {@code null} is treated as
 *                                      {@code true}
 */
@Builder
public MLTaskGetRequest(String taskId, String tenantId, Boolean isUserInitiatedGetTaskRequest) {
    this.taskId = taskId;
    this.tenantId = tenantId;
    // The parameter is a boxed Boolean but the field is primitive: unboxing a null would throw a
    // NullPointerException. Fall back to "user initiated", the same default the two-argument
    // constructor uses.
    this.isUserInitiatedGetTaskRequest = isUserInitiatedGetTaskRequest == null || isUserInitiatedGetTaskRequest;
}

/**
 * Deserializes a get-task request from a transport stream.
 *
 * <p>The tenant id and the user-initiated flag are only read when the stream version is
 * {@code VERSION_2_19_0} or later. For older streams the tenant id is {@code null} and the request
 * is treated as user initiated.
 *
 * @param in stream to read the request from
 * @throws IOException if reading from the stream fails
 */
public MLTaskGetRequest(StreamInput in) throws IOException {
    super(in);
    Version streamInputVersion = in.getVersion();
    this.taskId = in.readString();
    if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
        this.tenantId = in.readOptionalString();
        this.isUserInitiatedGetTaskRequest = in.readBoolean();
    } else {
        // Older streams carry neither field; reading the boolean unconditionally would
        // consume bytes that were never written.
        // NOTE(review): writeTo should apply the same version gate before writing the flag.
        this.tenantId = null;
        this.isUserInitiatedGetTaskRequest = true;
    }
}

@Override
Expand All @@ -51,6 +65,7 @@ public void writeTo(StreamOutput out) throws IOException {
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
out.writeBoolean(isUserInitiatedGetTaskRequest);
}

@Override
Expand Down
8 changes: 5 additions & 3 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ dependencies {
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
}
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.29.12'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,10 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.engine.annotation.Ingester;

import com.google.common.annotations.VisibleForTesting;
import org.opensearch.ml.engine.utils.S3Utils;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
Expand All @@ -54,7 +48,12 @@ public S3DataIngestion(Client client) {

@Override
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
S3Client s3 = initS3Client(mlBatchIngestionInput);
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);

S3Client s3 = S3Utils.initS3Client(accessKey, secretKey, region, sessionToken);

List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) {
Expand All @@ -77,8 +76,8 @@ public double ingestSingleSource(
boolean isSoleSource,
int bulkSize
) {
String bucketName = getS3BucketName(s3Uri);
String keyName = getS3KeyName(s3Uri);
String bucketName = S3Utils.getS3BucketName(s3Uri);
String keyName = S3Utils.getS3KeyName(s3Uri);
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
double successRate = 0;

Expand Down Expand Up @@ -153,55 +152,4 @@ public double ingestSingleSource(

return successRate;
}

private String getS3BucketName(String s3Uri) {
// Remove the "s3://" prefix
String uriWithoutPrefix = s3Uri.substring(5);
// Find the first slash after the bucket name
int slashIndex = uriWithoutPrefix.indexOf('/');
// If there is no slash, the entire remaining string is the bucket name
if (slashIndex == -1) {
return uriWithoutPrefix;
}
// Otherwise, the bucket name is the substring up to the first slash
return uriWithoutPrefix.substring(0, slashIndex);
}

private String getS3KeyName(String s3Uri) {
String uriWithoutPrefix = s3Uri.substring(5);
// Find the first slash after the bucket name
int slashIndex = uriWithoutPrefix.indexOf('/');
// If there is no slash, it means there is no key, return an empty string or handle as needed
if (slashIndex == -1) {
return "";
}
// The key name is the substring after the first slash
return uriWithoutPrefix.substring(slashIndex + 1);
}

@VisibleForTesting
public S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);

AwsCredentials credentials = sessionToken == null
? AwsBasicCredentials.create(accessKey, secretKey)
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);

try {
S3Client s3 = AccessController
.doPrivileged(
(PrivilegedExceptionAction<S3Client>) () -> S3Client
.builder()
.region(Region.of(region)) // Specify the region here
.credentialsProvider(StaticCredentialsProvider.create(credentials))
.build()
);
return s3;
} catch (PrivilegedActionException e) {
throw new RuntimeException("Can't load credentials", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.utils;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

@Log4j2
public class S3Utils {
    /** Scheme prefix every S3 URI handled by this class must start with. */
    private static final String S3_URI_PREFIX = "s3://";

    /**
     * Builds an {@link S3Client} with static credentials inside a privileged block.
     *
     * <p>Basic credentials are used when {@code sessionToken} is {@code null}; otherwise session
     * credentials are used. Mind the parameter order: {@code sessionToken} comes before
     * {@code region}. Both are plain strings, so swapping them at a call site still compiles.
     *
     * @param accessKey    AWS access key
     * @param secretKey    AWS secret key
     * @param sessionToken optional session token; {@code null} selects basic credentials
     * @param region       region identifier passed to {@code Region.of}
     * @return the constructed client
     * @throws RuntimeException if the privileged action fails with a checked exception
     */
    @VisibleForTesting
    public static S3Client initS3Client(String accessKey, String secretKey, String sessionToken, String region) {
        AwsCredentials credentials = sessionToken == null
            ? AwsBasicCredentials.create(accessKey, secretKey)
            : AwsSessionCredentials.create(accessKey, secretKey, sessionToken);

        try {
            return AccessController
                .doPrivileged(
                    (PrivilegedExceptionAction<S3Client>) () -> S3Client
                        .builder()
                        .region(Region.of(region)) // Specify the region here
                        .credentialsProvider(StaticCredentialsProvider.create(credentials))
                        .build()
                );
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load credentials", e);
        }
    }

    /**
     * Uploads {@code content} as an object to {@code s3://bucketName/key} inside a privileged block.
     *
     * @param s3Client   client used for the upload
     * @param bucketName destination bucket
     * @param key        destination object key
     * @param content    string content used as the object body
     * @throws RuntimeException if the privileged action fails with a checked exception
     */
    public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
        try {
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();

                s3Client.putObject(request, RequestBody.fromString(content));
                log.debug("Successfully uploaded file to S3: s3://{}/{}", bucketName, key);
                return null; // Void return type for doPrivileged
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
        }
    }

    /**
     * Extracts the bucket name from an S3 URI of the form {@code s3://bucket/key}.
     *
     * @param s3Uri URI starting with {@code s3://}
     * @return the text between the prefix and the first following slash, or everything after the
     *         prefix when there is no slash
     * @throws IllegalArgumentException if {@code s3Uri} is {@code null} or does not start with
     *                                  {@code s3://}
     */
    public static String getS3BucketName(String s3Uri) {
        String uriWithoutPrefix = stripS3Prefix(s3Uri);
        int slashIndex = uriWithoutPrefix.indexOf('/');
        // No slash: the entire remainder is the bucket name.
        return slashIndex == -1 ? uriWithoutPrefix : uriWithoutPrefix.substring(0, slashIndex);
    }

    /**
     * Extracts the object key from an S3 URI of the form {@code s3://bucket/key}.
     *
     * @param s3Uri URI starting with {@code s3://}
     * @return the text after the first slash that follows the bucket name, or an empty string when
     *         the URI contains no key
     * @throws IllegalArgumentException if {@code s3Uri} is {@code null} or does not start with
     *                                  {@code s3://}
     */
    public static String getS3KeyName(String s3Uri) {
        String uriWithoutPrefix = stripS3Prefix(s3Uri);
        int slashIndex = uriWithoutPrefix.indexOf('/');
        // No slash: the URI names a bucket only, so there is no key.
        return slashIndex == -1 ? "" : uriWithoutPrefix.substring(slashIndex + 1);
    }

    /**
     * Validates that {@code s3Uri} starts with {@code s3://} (case-insensitive) and returns the rest.
     * Fails fast instead of blindly dropping the first five characters, which mis-parsed other
     * schemes and threw index errors on short input.
     */
    private static String stripS3Prefix(String s3Uri) {
        if (s3Uri == null || !s3Uri.regionMatches(true, 0, S3_URI_PREFIX, 0, S3_URI_PREFIX.length())) {
            throw new IllegalArgumentException("Invalid S3 URI, expected it to start with \"s3://\": " + s3Uri);
        }
        return s3Uri.substring(S3_URI_PREFIX.length());
    }

}
Loading

0 comments on commit 600990c

Please sign in to comment.