Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.19] support batch task management by periodically polling the remote task via a cron job #3472

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class CommonValue {
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ public enum MLTaskState {
CANCELLED,
COMPLETED_WITH_ERROR,
CANCELLING,
EXPIRED
EXPIRED,
UNREACHABLE
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,28 @@
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = CommonValue.VERSION_2_16_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG = CommonValue.VERSION_2_19_0;
@Setter
private Map<String, String> parameters;
@Setter
private ActionType actionType;
@Setter
private Map<String, String> dlq;

@Builder(toBuilder = true)
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType, Map<String, String> dlq) {
super(MLInputDataType.REMOTE);
this.parameters = parameters;
this.actionType = actionType;
this.dlq = dlq;
}

/**
 * Creates a remote inference dataset without a {@code dlq} map. Delegates to the
 * three-argument constructor, passing {@code null} for {@code dlq}.
 *
 * @param parameters value stored in the {@code parameters} field
 * @param actionType value stored in the {@code actionType} field
 */
public RemoteInferenceInputDataSet(Map<String, String> parameters, ActionType actionType) {
this(parameters, actionType, null);
}

public RemoteInferenceInputDataSet(Map<String, String> parameters) {
this(parameters, null);
this(parameters, null, null);
}

public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
Expand All @@ -55,6 +63,13 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
this.actionType = null;
}
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
if (streamInput.readBoolean()) {
dlq = streamInput.readMap(s -> s.readString(), s -> s.readString());
} else {
this.dlq = null;
}
}
}

@Override
Expand All @@ -75,6 +90,14 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeBoolean(false);
}
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DLQ_CONFIG)) {
if (dlq != null) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(dlq, StreamOutput::writeString, StreamOutput::writeString);
} else {
streamOutput.writeBoolean(false);
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
public class RemoteInferenceMLInput extends MLInput {
public static final String PARAMETERS_FIELD = "parameters";
public static final String ACTION_TYPE_FIELD = "action_type";
public static final String DLQ_FIELD = "dlq";

public RemoteInferenceMLInput(StreamInput in) throws IOException {
super(in);
Expand All @@ -37,6 +38,7 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
super();
this.algorithm = functionName;
Map<String, String> parameters = null;
Map<String, String> dlq = null;
ActionType actionType = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -50,12 +52,15 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName)
case ACTION_TYPE_FIELD:
actionType = ActionType.from(parser.text());
break;
case DLQ_FIELD:
dlq = StringUtils.getParameterMap(parser.map());
break;
default:
parser.skipChildren();
break;
}
}
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType);
inputDataset = new RemoteInferenceInputDataSet(parameters, actionType, dlq);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,34 @@
public class MLTaskGetRequest extends ActionRequest {
@Getter
String taskId;

@Getter
String tenantId;

// This is to identify if the get request is initiated by user or not. During batch task polling job,
// we also perform get operation. This field is to distinguish between
// these two situations.
@Getter
boolean isUserInitiatedGetTaskRequest;

/**
 * Creates a get-task request that is marked as user initiated. Delegates to the
 * three-argument constructor with the flag set to {@code true}.
 *
 * NOTE(review): this class puts {@code @Builder} on two constructors; confirm which
 * constructor the generated {@code build()} actually invokes.
 *
 * @param taskId value stored in the {@code taskId} field
 * @param tenantId value stored in the {@code tenantId} field
 */
@Builder
public MLTaskGetRequest(String taskId, String tenantId) {
this(taskId, tenantId, true);
}

/**
 * Creates a get-task request, recording whether it was initiated by a user.
 *
 * @param taskId value stored in the {@code taskId} field
 * @param tenantId value stored in the {@code tenantId} field; serialized as an optional string
 * @param isUserInitiatedGetTaskRequest whether the request comes from a user; {@code null} is
 *        treated as {@code true}, the same default the two-argument constructor applies
 */
@Builder
public MLTaskGetRequest(String taskId, String tenantId, Boolean isUserInitiatedGetTaskRequest) {
    this.taskId = taskId;
    this.tenantId = tenantId;
    // The parameter is a boxed Boolean while the field is a primitive boolean: assigning a
    // null reference directly would throw a NullPointerException during unboxing, so an
    // absent value falls back to the "user initiated" default.
    this.isUserInitiatedGetTaskRequest = isUserInitiatedGetTaskRequest == null || isUserInitiatedGetTaskRequest;
}

/**
 * Deserializes a request from a transport stream.
 *
 * Reads, in order: the task id, the optional tenant id (only when the stream version is
 * on or after {@code VERSION_2_19_0}, otherwise {@code null} is used), and the
 * user-initiated flag.
 *
 * NOTE(review): unlike the tenant id, the boolean flag is read regardless of the stream
 * version. Confirm whether peers older than the version that introduced the flag can send
 * this request; if so both this constructor and writeTo need a matching version gate.
 *
 * @param in stream to read the request from
 * @throws IOException if reading from the stream fails
 */
public MLTaskGetRequest(StreamInput in) throws IOException {
    super(in);
    this.taskId = in.readString();
    if (in.getVersion().onOrAfter(VERSION_2_19_0)) {
        this.tenantId = in.readOptionalString();
    } else {
        this.tenantId = null;
    }
    this.isUserInitiatedGetTaskRequest = in.readBoolean();
}

@Override
Expand All @@ -51,6 +65,7 @@ public void writeTo(StreamOutput out) throws IOException {
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
out.writeBoolean(isUserInitiatedGetTaskRequest);
}

@Override
Expand Down
8 changes: 5 additions & 3 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ dependencies {
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
}
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.29.12'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,10 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.engine.annotation.Ingester;

import com.google.common.annotations.VisibleForTesting;
import org.opensearch.ml.engine.utils.S3Utils;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
Expand All @@ -54,7 +48,12 @@ public S3DataIngestion(Client client) {

@Override
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
S3Client s3 = initS3Client(mlBatchIngestionInput);
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);

S3Client s3 = S3Utils.initS3Client(accessKey, secretKey, region, sessionToken);

List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
if (Objects.isNull(s3Uris) || s3Uris.isEmpty()) {
Expand All @@ -77,8 +76,8 @@ public double ingestSingleSource(
boolean isSoleSource,
int bulkSize
) {
String bucketName = getS3BucketName(s3Uri);
String keyName = getS3KeyName(s3Uri);
String bucketName = S3Utils.getS3BucketName(s3Uri);
String keyName = S3Utils.getS3KeyName(s3Uri);
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(keyName).build();
double successRate = 0;

Expand Down Expand Up @@ -153,55 +152,4 @@ public double ingestSingleSource(

return successRate;
}

private String getS3BucketName(String s3Uri) {
// Remove the "s3://" prefix
String uriWithoutPrefix = s3Uri.substring(5);
// Find the first slash after the bucket name
int slashIndex = uriWithoutPrefix.indexOf('/');
// If there is no slash, the entire remaining string is the bucket name
if (slashIndex == -1) {
return uriWithoutPrefix;
}
// Otherwise, the bucket name is the substring up to the first slash
return uriWithoutPrefix.substring(0, slashIndex);
}

private String getS3KeyName(String s3Uri) {
String uriWithoutPrefix = s3Uri.substring(5);
// Find the first slash after the bucket name
int slashIndex = uriWithoutPrefix.indexOf('/');
// If there is no slash, it means there is no key, return an empty string or handle as needed
if (slashIndex == -1) {
return "";
}
// The key name is the substring after the first slash
return uriWithoutPrefix.substring(slashIndex + 1);
}

@VisibleForTesting
public S3Client initS3Client(MLBatchIngestionInput mlBatchIngestionInput) {
String accessKey = mlBatchIngestionInput.getCredential().get(ACCESS_KEY_FIELD);
String secretKey = mlBatchIngestionInput.getCredential().get(SECRET_KEY_FIELD);
String sessionToken = mlBatchIngestionInput.getCredential().get(SESSION_TOKEN_FIELD);
String region = mlBatchIngestionInput.getCredential().get(REGION_FIELD);

AwsCredentials credentials = sessionToken == null
? AwsBasicCredentials.create(accessKey, secretKey)
: AwsSessionCredentials.create(accessKey, secretKey, sessionToken);

try {
S3Client s3 = AccessController
.doPrivileged(
(PrivilegedExceptionAction<S3Client>) () -> S3Client
.builder()
.region(Region.of(region)) // Specify the region here
.credentialsProvider(StaticCredentialsProvider.create(credentials))
.build()
);
return s3;
} catch (PrivilegedActionException e) {
throw new RuntimeException("Can't load credentials", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.utils;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

/**
 * Static helpers for talking to Amazon S3: building a client from static credentials,
 * uploading a string as an object, and splitting an {@code s3://bucket/key} URI into its
 * bucket and key parts.
 */
@Log4j2
public class S3Utils {

    /** Scheme prefix that every S3 URI handled by this class must start with. */
    private static final String S3_URI_PREFIX = "s3://";

    /** Utility class; not meant to be instantiated. */
    private S3Utils() {}

    /**
     * Builds an {@link S3Client} backed by static credentials.
     *
     * NOTE(review): {@code sessionToken} comes BEFORE {@code region}. Both are plain strings,
     * so a transposed call compiles silently. The {@code S3DataIngestion#ingest} call site in
     * this change passes {@code (accessKey, secretKey, region, sessionToken)} and should be
     * checked against this parameter order.
     *
     * @param accessKey access key id
     * @param secretKey secret access key
     * @param sessionToken session token; when {@code null}, basic (non-session) credentials are used
     * @param region region identifier handed to {@code Region.of}
     * @return the newly built client
     * @throws RuntimeException if building the client inside the privileged block raises a checked exception
     */
    @VisibleForTesting
    public static S3Client initS3Client(String accessKey, String secretKey, String sessionToken, String region) {
        AwsCredentials credentials = sessionToken == null
            ? AwsBasicCredentials.create(accessKey, secretKey)
            : AwsSessionCredentials.create(accessKey, secretKey, sessionToken);

        try {
            return AccessController
                .doPrivileged(
                    (PrivilegedExceptionAction<S3Client>) () -> S3Client
                        .builder()
                        .region(Region.of(region))
                        .credentialsProvider(StaticCredentialsProvider.create(credentials))
                        .build()
                );
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load credentials", e);
        }
    }

    /**
     * Uploads {@code content} as the object {@code key} in {@code bucketName}.
     *
     * @param s3Client client used to perform the upload
     * @param bucketName destination bucket
     * @param key destination object key
     * @param content string written as the object body
     * @throws RuntimeException if the privileged upload raises a checked exception
     */
    public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
        try {
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();

                s3Client.putObject(request, RequestBody.fromString(content));
                log.debug("Successfully uploaded file to S3: s3://{}/{}", bucketName, key);
                return null; // Void return type for doPrivileged
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
        }
    }

    /**
     * Extracts the bucket name from an S3 URI.
     *
     * @param s3Uri URI of the form {@code s3://bucket} or {@code s3://bucket/key}
     * @return the text between the {@code s3://} prefix and the first following slash, or
     *         everything after the prefix when there is no slash
     * @throws IllegalArgumentException if {@code s3Uri} is {@code null} or does not start with {@code s3://}
     */
    public static String getS3BucketName(String s3Uri) {
        String uriWithoutPrefix = stripS3Prefix(s3Uri);
        int slashIndex = uriWithoutPrefix.indexOf('/');
        // No slash: the whole remainder is the bucket name.
        return slashIndex == -1 ? uriWithoutPrefix : uriWithoutPrefix.substring(0, slashIndex);
    }

    /**
     * Extracts the object key from an S3 URI.
     *
     * @param s3Uri URI of the form {@code s3://bucket} or {@code s3://bucket/key}
     * @return the text after the first slash that follows the bucket name, or an empty
     *         string when the URI contains no key
     * @throws IllegalArgumentException if {@code s3Uri} is {@code null} or does not start with {@code s3://}
     */
    public static String getS3KeyName(String s3Uri) {
        String uriWithoutPrefix = stripS3Prefix(s3Uri);
        int slashIndex = uriWithoutPrefix.indexOf('/');
        // No slash: the URI names only a bucket, so there is no key.
        return slashIndex == -1 ? "" : uriWithoutPrefix.substring(slashIndex + 1);
    }

    /**
     * Validates that the URI carries the {@code s3://} prefix and returns the remainder.
     * The prefix is matched case-insensitively so URIs such as {@code S3://bucket/key},
     * which the previous blind {@code substring(5)} happened to accept, keep working.
     */
    private static String stripS3Prefix(String s3Uri) {
        if (s3Uri == null || !s3Uri.regionMatches(true, 0, S3_URI_PREFIX, 0, S3_URI_PREFIX.length())) {
            throw new IllegalArgumentException("Invalid S3 URI, expected it to start with s3:// but got: " + s3Uri);
        }
        return s3Uri.substring(S3_URI_PREFIX.length());
    }

}
Loading
Loading