Skip to content

Commit

Permalink
Modify S3 Source to create multiple SqsWorkers (#4239)
Browse files Browse the repository at this point in the history
* Modify S3 Source to create multiple SqsWorkers

Signed-off-by: Krishna Kondaka <[email protected]>

* Addressed review comments and added integration test case

Signed-off-by: Krishna Kondaka <[email protected]>

---------

Signed-off-by: Krishna Kondaka <[email protected]>
Co-authored-by: Krishna Kondaka <[email protected]>
  • Loading branch information
kkondaka and Krishna Kondaka authored Mar 6, 2024
1 parent 5d1edb6 commit 6a30c6f
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.s3;

import com.linecorp.armeria.client.retry.Backoff;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Timer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.plugins.source.s3.configuration.NotificationSourceOption;
import org.opensearch.dataprepper.plugins.source.s3.configuration.OnErrorOption;
import org.opensearch.dataprepper.plugins.source.s3.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.source.s3.configuration.SqsOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.sqs.SqsClient;

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.StringStartsWith.startsWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class SqsServiceIT {
private SqsClient sqsClient;
private S3Service s3Service;
@Mock
private SqsOptions sqsOptions;
private S3SourceConfig s3SourceConfig;
private PluginMetrics pluginMetrics;
private S3ObjectGenerator s3ObjectGenerator;
private String bucket;
private AcknowledgementSetManager acknowledgementSetManager;
private AwsAuthenticationOptions awsAuthenticationOptions;
private AtomicInteger deletedCount;
private Counter deletedCounter;
private Counter numMessagesCounter;
private AtomicInteger numMessages;

@BeforeEach
void setUp() {
String receivedMessages = SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME;
acknowledgementSetManager = mock(AcknowledgementSetManager.class);
awsAuthenticationOptions = mock(AwsAuthenticationOptions.class);
final S3Client s3Client = S3Client.builder()
.region(Region.of(System.getProperty("tests.s3source.region")))
.build();
bucket = System.getProperty("tests.s3source.bucket");
s3ObjectGenerator = new S3ObjectGenerator(s3Client, bucket);
s3Service = mock(S3Service.class);

sqsClient = SqsClient.builder()
.region(Region.of(System.getProperty("tests.s3source.region")))
.build();

deletedCount = new AtomicInteger(0);
numMessages = new AtomicInteger(0);
pluginMetrics = mock(PluginMetrics.class);
final DistributionSummary distributionSummary = mock(DistributionSummary.class);
final Timer sqsMessageDelayTimer = mock(Timer.class);

deletedCounter = mock(Counter.class);
numMessagesCounter = mock(Counter.class);
lenient().when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(numMessagesCounter);
lenient().when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(deletedCounter);
lenient().when(pluginMetrics.summary(anyString())).thenReturn(distributionSummary);
when(pluginMetrics.timer(anyString())).thenReturn(sqsMessageDelayTimer);
lenient().doAnswer((val) -> {
int x = numMessages.addAndGet(((Double)val.getArgument(0)).intValue());
return null;
}).when(numMessagesCounter).increment(any(Double.class));

s3SourceConfig = mock(S3SourceConfig.class);
sqsOptions = mock(SqsOptions.class);
when(sqsOptions.getSqsUrl()).thenReturn(System.getProperty("tests.s3source.queue.url"));
when(sqsOptions.getVisibilityTimeout()).thenReturn(Duration.ofSeconds(60));
when(sqsOptions.getMaximumMessages()).thenReturn(10);
when(sqsOptions.getWaitTime()).thenReturn(Duration.ofSeconds(10));
when(s3SourceConfig.getSqsOptions()).thenReturn(sqsOptions);
lenient().when(s3SourceConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES);

when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(System.getProperty("tests.s3source.region")));
when(s3SourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
when(s3SourceConfig.getAcknowledgements()).thenReturn(false);
lenient().when(s3SourceConfig.getNotificationSource()).thenReturn(NotificationSourceOption.S3);

// Clear SQS queue messages before running each test
clearSqsQueue();
numMessages = new AtomicInteger(0);
}

private SqsService createObjectUnderTest() {
final AwsCredentialsProvider awsCredentialsProvider = DefaultCredentialsProvider.create();
return new SqsService(acknowledgementSetManager, s3SourceConfig, s3Service, pluginMetrics, awsCredentialsProvider);
}

private void writeToS3(final int numberOfObjectsToWrite) throws IOException {
final int numberOfRecords = 100;
final NewlineDelimitedRecordsGenerator newlineDelimitedRecordsGenerator = new NewlineDelimitedRecordsGenerator();
for (int i = 0; i < numberOfObjectsToWrite; i++) {
final String key = "s3 source/sqs/" + UUID.randomUUID() + "_" + Instant.now().toString() + newlineDelimitedRecordsGenerator.getFileExtension();
// isCompressionEnabled is set to false since we test for compression in S3ObjectWorkerIT
s3ObjectGenerator.write(numberOfRecords, key, newlineDelimitedRecordsGenerator, false);
}
}

@ParameterizedTest
@ValueSource(ints = {1, 2, 5})
public void test_sqsService(int numWorkers) throws IOException {
int numberOfObjectsToWrite = 5;

when(s3SourceConfig.getNumWorkers()).thenReturn(numWorkers);
final SqsService objectUnderTest = createObjectUnderTest();
writeToS3(numberOfObjectsToWrite);
numMessages.set(0);
objectUnderTest.start();
await().atMost(Duration.ofSeconds(15))
.untilAsserted(() -> {
assertThat(numMessages.get(), equalTo(numberOfObjectsToWrite));
});
final ArgumentCaptor<S3ObjectReference> s3ObjectReferenceArgumentCaptor = ArgumentCaptor.forClass(S3ObjectReference.class);
verify(s3Service, atLeastOnce()).addS3Object(s3ObjectReferenceArgumentCaptor.capture(), eq(null));
assertThat(s3ObjectReferenceArgumentCaptor.getValue().getBucketName(), equalTo(bucket));
assertThat(s3ObjectReferenceArgumentCaptor.getValue().getKey(), startsWith("s3 source/sqs/"));
objectUnderTest.stop();
}

private void clearSqsQueue() {
Backoff backoff = Backoff.exponential(SqsService.INITIAL_DELAY, SqsService.MAXIMUM_DELAY).withJitter(SqsService.JITTER_RATE)
.withMaxAttempts(Integer.MAX_VALUE);
final SqsWorker sqsWorker = new SqsWorker(acknowledgementSetManager, sqsClient, s3Service, s3SourceConfig, pluginMetrics, backoff);
//final SqsService objectUnderTest = createObjectUnderTest();
int sqsMessagesProcessed;
do {
sqsMessagesProcessed = sqsWorker.processSqsMessages();
}
while (sqsMessagesProcessed > 0);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import jakarta.validation.Valid;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.Max;
import org.opensearch.dataprepper.model.configuration.PluginModel;
import org.opensearch.dataprepper.plugins.codec.CompressionOption;
import org.opensearch.dataprepper.plugins.source.s3.configuration.AwsAuthenticationOptions;
Expand All @@ -24,6 +26,7 @@

public class S3SourceConfig {
static final Duration DEFAULT_BUFFER_TIMEOUT = Duration.ofSeconds(10);
static final int DEFAULT_NUMBER_OF_WORKERS = 1;
static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 100;
static final String DEFAULT_METADATA_ROOT_KEY = "s3/";

Expand All @@ -43,6 +46,12 @@ public class S3SourceConfig {
@Valid
private SqsOptions sqsOptions;

@JsonProperty("workers")
@Min(1)
@Max(1000)
@Valid
private int numWorkers = DEFAULT_NUMBER_OF_WORKERS;

@JsonProperty("aws")
@NotNull
@Valid
Expand Down Expand Up @@ -101,6 +110,10 @@ boolean getAcknowledgements() {
return acknowledgments;
}

public int getNumWorkers() {
return numWorkers;
}

public CompressionOption getCompression() {
return compression;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
import software.amazon.awssdk.services.sqs.SqsClient;

import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;

public class SqsService {
private static final Logger LOG = LoggerFactory.getLogger(SqsService.class);
static final long SHUTDOWN_TIMEOUT = 30L;
static final long INITIAL_DELAY = Duration.ofSeconds(20).toMillis();
static final long MAXIMUM_DELAY = Duration.ofMinutes(5).toMillis();
static final double JITTER_RATE = 0.20;
Expand All @@ -28,9 +32,7 @@ public class SqsService {
private final SqsClient sqsClient;
private final PluginMetrics pluginMetrics;
private final AcknowledgementSetManager acknowledgementSetManager;

private Thread sqsWorkerThread;
private SqsWorker sqsWorker;
private ExecutorService executorService;

public SqsService(final AcknowledgementSetManager acknowledgementSetManager,
final S3SourceConfig s3SourceConfig,
Expand All @@ -42,14 +44,15 @@ public SqsService(final AcknowledgementSetManager acknowledgementSetManager,
this.pluginMetrics = pluginMetrics;
this.acknowledgementSetManager = acknowledgementSetManager;
this.sqsClient = createSqsClient(credentialsProvider);
executorService = Executors.newFixedThreadPool(s3SourceConfig.getNumWorkers());
}

public void start() {
final Backoff backoff = Backoff.exponential(INITIAL_DELAY, MAXIMUM_DELAY).withJitter(JITTER_RATE)
.withMaxAttempts(Integer.MAX_VALUE);
sqsWorker = new SqsWorker(acknowledgementSetManager, sqsClient, s3Accessor, s3SourceConfig, pluginMetrics, backoff);
sqsWorkerThread = new Thread(sqsWorker);
sqsWorkerThread.start();
for (int i = 0; i < s3SourceConfig.getNumWorkers(); i++) {
executorService.submit(new SqsWorker(acknowledgementSetManager, sqsClient, s3Accessor, s3SourceConfig, pluginMetrics, backoff));
}
}

SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) {
Expand All @@ -64,6 +67,19 @@ SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) {
}

public void stop() {
sqsWorker.stop();
sqsClient.close();
executorService.shutdown();
try {
if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT, TimeUnit.SECONDS)) {
LOG.warn("Failed to terminate SqsWorkers");
executorService.shutdownNow();
}
} catch (InterruptedException e) {
if (e.getCause() instanceof InterruptedException) {
LOG.error("Interrupted during shutdown, exiting uncleanly...", e);
executorService.shutdownNow();
Thread.currentThread().interrupt();
}
}
}
}

0 comments on commit 6a30c6f

Please sign in to comment.