renamed files and added more tests
Signed-off-by: Bhavana Goud Ramaram <[email protected]>
rbhavna committed Jan 28, 2025
1 parent d0a9973 commit b553d8e
Showing 15 changed files with 483 additions and 125 deletions.
6 changes: 3 additions & 3 deletions ml-algorithms/build.gradle
@@ -73,9 +73,9 @@ dependencies {
}
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'

implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
compileOnly group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation group: 'org.json', name: 'json', version: '20231013'
@@ -44,16 +44,8 @@ public static S3Client initS3Client(String accessKey, String secretKey, String s
}
}

public static void putObject(
String accessKey,
String secretKey,
String sessionToken,
String region,
String bucketName,
String key,
String content
) {
try (S3Client s3Client = initS3Client(accessKey, secretKey, sessionToken, region)) {
public static void putObject(S3Client s3Client, String bucketName, String key, String content) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build();

@@ -63,8 +55,6 @@ public static void putObject(
});
} catch (PrivilegedActionException e) {
throw new RuntimeException("Failed to upload file to S3: s3://" + bucketName + "/" + key, e);
} catch (Exception e) {
log.error("Unexpected error during S3 upload", e);
}
}

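For reference, a minimal caller sketch of the refactored helper, based on the signatures in this diff (not part of the commit; the credential, bucket, key, and content values are made up). Passing the client in lets callers control its lifecycle and lets tests substitute a mock:

import org.opensearch.ml.engine.utils.S3Utils;

import software.amazon.awssdk.services.s3.S3Client;

public class S3UploadCallerSketch {
    public static void main(String[] args) {
        // try-with-resources closes the client once the upload finishes.
        try (S3Client s3Client = S3Utils.initS3Client("access-key", "secret-key", null, "us-west-2")) {
            S3Utils.putObject(s3Client, "example-bucket", "BatchJobFailure_example-job", "example failure details");
        }
    }
}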
@@ -0,0 +1,89 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.utils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;

public class S3UtilsTest {

@Mock
private S3Client s3Client;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
}

@Test
public void testInitS3Client() {
String accessKey = "test-access-key";
String secretKey = "test-secret-key";
String sessionToken = "test-session-token";
String region = "us-west-2";

S3Client client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region);
assertNotNull(client);
}

@Test
public void testInitS3ClientWithoutSessionToken() {
String accessKey = "test-access-key";
String secretKey = "test-secret-key";
String region = "us-west-2";

S3Client client = S3Utils.initS3Client(accessKey, secretKey, null, region);
assertNotNull(client);
}

@Test
public void testPutObject() {
String bucketName = "test-bucket";
String key = "test-key";
String content = "test-content";

when(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))).thenReturn(PutObjectResponse.builder().build());

S3Utils.putObject(s3Client, bucketName, key, content);

verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), any(RequestBody.class));
}

@Test
public void testGetS3BucketName() {
String s3Uri = "s3://test-bucket/path/to/file";
assertEquals("test-bucket", S3Utils.getS3BucketName(s3Uri));

s3Uri = "s3://test-bucket";
assertEquals("test-bucket", S3Utils.getS3BucketName(s3Uri));
}

@Test
public void testGetS3KeyName() {
String s3Uri = "s3://test-bucket/path/to/file";
assertEquals("path/to/file", S3Utils.getS3KeyName(s3Uri));

s3Uri = "s3://test-bucket";
assertEquals("", S3Utils.getS3KeyName(s3Uri));
}
}
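If stricter verification is wanted later, a captor-based variant of testPutObject could also assert the bucket and key on the captured request. A sketch only, assuming an org.mockito.ArgumentCaptor import and an illustrative method name:

@Test
public void testPutObjectCapturesBucketAndKey() {
    when(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))).thenReturn(PutObjectResponse.builder().build());

    S3Utils.putObject(s3Client, "test-bucket", "test-key", "test-content");

    // Capture the request built inside putObject and check its bucket and key.
    ArgumentCaptor<PutObjectRequest> captor = ArgumentCaptor.forClass(PutObjectRequest.class);
    verify(s3Client).putObject(captor.capture(), any(RequestBody.class));
    assertEquals("test-bucket", captor.getValue().bucket());
    assertEquals("test-key", captor.getValue().key());
}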
14 changes: 13 additions & 1 deletion plugin/build.gradle
@@ -59,6 +59,16 @@ dependencies {
implementation project(':opensearch-ml-memory')
compileOnly "com.google.guava:guava:32.1.3-jre"

implementation group: 'software.amazon.awssdk', name: 'aws-core', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 's3', version: '2.29.12'
implementation group: 'software.amazon.awssdk', name: 'regions', version: '2.29.12'

implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: '2.29.12'

implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: '2.29.12'

implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: '2.29.12'

zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
@@ -353,7 +363,9 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.action.models.DeleteModelTransportAction.2',
'org.opensearch.ml.model.MLModelCacheHelper',
'org.opensearch.ml.model.MLModelCacheHelper.1',
'org.opensearch.ml.action.tasks.CancelBatchJobTransportAction'
'org.opensearch.ml.action.tasks.CancelBatchJobTransportAction',
'org.opensearch.ml.jobs.MLBatchTaskUpdateExtension',
'org.opensearch.ml.jobs.MLBatchTaskUpdateJobRunner'

]

@@ -102,6 +102,8 @@
import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.S3Exception;

@Log4j2
public class GetTaskTransportAction extends HandledTransportAction<ActionRequest, MLTaskGetResponse> {
@@ -129,7 +131,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
volatile Pattern remoteJobFailedStatusRegexPattern;
private final MLEngine mlEngine;

private Map<String, String> decryptedCredential;
// private Map<String, String> decryptedCredential;

@Inject
public GetTaskTransportAction(
@@ -456,19 +458,25 @@ private void executeConnector(
connector.addAction(connectorAction);
}

decryptedCredential = connector.getDecryptedCredential();

if (decryptedCredential == null || decryptedCredential.isEmpty()) {
decryptedCredential = mlEngine.getConnectorCredential(connector);
}

final Map<String, String> decryptedCredential = connector.getDecryptedCredential() != null
&& !connector.getDecryptedCredential().isEmpty()
? connector.getDecryptedCredential()
: mlEngine.getConnectorCredential(connector);
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, isUserInitiatedGetTaskRequest, taskResponse, remoteJob, actionListener);
processTaskResponse(
mlTask,
taskId,
isUserInitiatedGetTaskRequest,
taskResponse,
remoteJob,
decryptedCredential,
actionListener
);
}, e -> {
// When the request to remote service fails, we will retry the request for next 10 minutes (10 runs).
// If it fails even then, we mark it as unreachable in task index and send message to DLQ
@@ -500,6 +508,7 @@ protected void processTaskResponse(
Boolean isUserInitiatedGetTaskRequest,
MLTaskResponse taskResponse,
Map<String, Object> remoteJob,
Map<String, String> decryptedCredential,
ActionListener<MLTaskGetResponse> actionListener
) {
try {
@@ -566,15 +575,18 @@ protected void updateDLQ(MLTask mlTask, Map<String, String> decryptedCredential)
log.error("Failed to get the bucket name and region from batch predict request");
}
remoteJobDetails.remove("dlq");

String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name"));
String s3ObjectKey = "BatchJobFailure_" + jobName;
String content = mlTask.getState().equals(UNREACHABLE)
? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError())
: remoteJobDetails.toString();

S3Utils.putObject(accessKey, secretKey, sessionToken, region, bucketName, s3ObjectKey, content);
log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now());
try (S3Client s3Client = S3Utils.initS3Client(accessKey, secretKey, sessionToken, region)) {
String jobName = (String) remoteJobDetails.getOrDefault("TransformJobName", remoteJob.get("job_name"));
String s3ObjectKey = "BatchJobFailure_" + jobName;
String content = mlTask.getState().equals(UNREACHABLE)
? String.format("Unable to reach the Job: %s. Error Message: %s", jobName, mlTask.getError())
: remoteJobDetails.toString();

S3Utils.putObject(s3Client, bucketName, s3ObjectKey, content);
log.debug("Task status successfully uploaded to S3 for task ID: {} at {}", taskId, Instant.now());
}
} catch (S3Exception e) {
log.error("Failed to update task status for task: {}. S3 Exception: {}", taskId, e.awsErrorDetails().errorMessage());
} catch (Exception e) {
log.error("Failed to update task status for task: " + taskId, e);
}
@@ -5,6 +5,8 @@

package org.opensearch.ml.jobs;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.time.Instant;

@@ -16,7 +18,7 @@
import org.opensearch.jobscheduler.spi.schedule.ScheduleParser;
import org.opensearch.ml.common.CommonValue;

public class BatchPredictTaskUpdateJob implements JobSchedulerExtension {
public class MLBatchTaskUpdateExtension implements JobSchedulerExtension {

@Override
public String getJobType() {
@@ -31,32 +33,32 @@ public ScheduledJobRunner getJobRunner() {
@Override
public ScheduledJobParser getJobParser() {
return (parser, id, jobDocVersion) -> {
MLBatchPredictTaskUpdateJobParameter jobParameter = new MLBatchPredictTaskUpdateJobParameter();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLBatchTaskUpdateJobParameter jobParameter = new MLBatchTaskUpdateJobParameter();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);

while (!parser.nextToken().equals(XContentParser.Token.END_OBJECT)) {
String fieldName = parser.currentName();
parser.nextToken();
switch (fieldName) {
case MLBatchPredictTaskUpdateJobParameter.NAME_FIELD:
case MLBatchTaskUpdateJobParameter.NAME_FIELD:
jobParameter.setJobName(parser.text());
break;
case MLBatchPredictTaskUpdateJobParameter.ENABLED_FILED:
case MLBatchTaskUpdateJobParameter.ENABLED_FILED:
jobParameter.setEnabled(parser.booleanValue());
break;
case MLBatchPredictTaskUpdateJobParameter.ENABLED_TIME_FILED:
case MLBatchTaskUpdateJobParameter.ENABLED_TIME_FILED:
jobParameter.setEnabledTime(parseInstantValue(parser));
break;
case MLBatchPredictTaskUpdateJobParameter.LAST_UPDATE_TIME_FIELD:
case MLBatchTaskUpdateJobParameter.LAST_UPDATE_TIME_FIELD:
jobParameter.setLastUpdateTime(parseInstantValue(parser));
break;
case MLBatchPredictTaskUpdateJobParameter.SCHEDULE_FIELD:
case MLBatchTaskUpdateJobParameter.SCHEDULE_FIELD:
jobParameter.setSchedule(ScheduleParser.parse(parser));
break;
case MLBatchPredictTaskUpdateJobParameter.LOCK_DURATION_SECONDS:
case MLBatchTaskUpdateJobParameter.LOCK_DURATION_SECONDS:
jobParameter.setLockDurationSeconds(parser.longValue());
break;
case MLBatchPredictTaskUpdateJobParameter.JITTER:
case MLBatchTaskUpdateJobParameter.JITTER:
jobParameter.setJitter(parser.doubleValue());
break;
default:
@@ -18,7 +18,7 @@
* It adds an additional "indexToWatch" field to {@link ScheduledJobParameter}, which stores the index
* the job runner will watch.
*/
public class MLBatchPredictTaskUpdateJobParameter implements ScheduledJobParameter {
public class MLBatchTaskUpdateJobParameter implements ScheduledJobParameter {
public static final String NAME_FIELD = "name";
public static final String ENABLED_FILED = "enabled";
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
@@ -38,9 +38,9 @@ public class MLBatchPredictTaskUpdateJobParameter implements ScheduledJobParamet
private Long lockDurationSeconds;
private Double jitter;

public MLBatchPredictTaskUpdateJobParameter() {}
public MLBatchTaskUpdateJobParameter() {}

public MLBatchPredictTaskUpdateJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter) {
public MLBatchTaskUpdateJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter) {
this.jobName = name;
this.schedule = schedule;
this.lockDurationSeconds = lockDurationSeconds;
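As a usage sketch of the renamed parameter class (not part of this commit; the job name, interval, lock duration, and jitter are illustrative values), the constructor shown above could be called like this:

import java.time.Instant;
import java.time.temporal.ChronoUnit;

import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
import org.opensearch.ml.jobs.MLBatchTaskUpdateJobParameter;

public class JobParameterSketch {
    // Poll remote batch task status every 10 minutes, with a 20-second lock and 50% jitter.
    public static MLBatchTaskUpdateJobParameter buildJobParameter() {
        return new MLBatchTaskUpdateJobParameter(
            "ml_batch_task_update_job",
            new IntervalSchedule(Instant.now(), 10, ChronoUnit.MINUTES),
            20L,
            0.5
        );
    }
}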