From 29701d0bd9a4c3be65e022e4f637173a80b24efe Mon Sep 17 00:00:00 2001 From: Jeremy Michael Date: Wed, 18 Dec 2024 14:56:46 -0800 Subject: [PATCH] added unit tests and addressed comments in design doc Signed-off-by: Jeremy Michael --- data-prepper-plugins/sqs-source/README.md | 2 +- .../plugins/source/sqs/QueueConfig.java | 24 +- .../source/sqs/RawSqsMessageHandler.java | 1 - .../plugins/source/sqs/SqsService.java | 1 - .../plugins/source/sqs/SqsWorker.java | 40 ++- .../sqs/AwsAuthenticationAdapterTest.java | 90 ++++++ .../sqs/AwsAuthenticationOptionsTest.java | 72 +++++ .../plugins/source/sqs/SqsServiceTest.java | 80 +++++ .../plugins/source/sqs/SqsSourceTest.java | 63 ++++ .../plugins/source/sqs/SqsWorkerTest.java | 282 ++++++++++++++++++ 10 files changed, 616 insertions(+), 39 deletions(-) create mode 100644 data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationAdapterTest.java create mode 100644 data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationOptionsTest.java create mode 100644 data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java create mode 100644 data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsSourceTest.java create mode 100644 data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java diff --git a/data-prepper-plugins/sqs-source/README.md b/data-prepper-plugins/sqs-source/README.md index 47944262ee..ff4313605f 100644 --- a/data-prepper-plugins/sqs-source/README.md +++ b/data-prepper-plugins/sqs-source/README.md @@ -19,4 +19,4 @@ sqs-pipeline: region: sts_role_arn: sink: - - stdout: + - stdout: diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java index 35d1ddc3b8..6b0bf4dfad 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java @@ -13,14 +13,13 @@ public class QueueConfig { - private static final int DEFAULT_MAXIMUM_MESSAGES = 10; + private static final Integer DEFAULT_MAXIMUM_MESSAGES = null; private static final Boolean DEFAULT_VISIBILITY_DUPLICATE_PROTECTION = false; - private static final Duration DEFAULT_VISIBILITY_TIMEOUT_SECONDS = Duration.ofSeconds(30); + private static final Duration DEFAULT_VISIBILITY_TIMEOUT_SECONDS = null; private static final Duration DEFAULT_VISIBILITY_DUPLICATE_PROTECTION_TIMEOUT = Duration.ofHours(2); private static final Duration DEFAULT_WAIT_TIME_SECONDS = Duration.ofSeconds(20); private static final Duration DEFAULT_POLL_DELAY_SECONDS = Duration.ofSeconds(0); static final int DEFAULT_NUMBER_OF_WORKERS = 1; - private static final int DEFAULT_BATCH_SIZE = 10; @JsonProperty("url") @NotNull @@ -33,14 +32,7 @@ public class QueueConfig { @JsonProperty("maximum_messages") @Min(1) @Max(10) - private int maximumMessages = DEFAULT_MAXIMUM_MESSAGES; - - @JsonProperty("batch_size") - @Max(10) - private Integer batchSize = DEFAULT_BATCH_SIZE; - - @JsonProperty("polling_frequency") - private Duration pollingFrequency = Duration.ZERO; + private Integer maximumMessages = DEFAULT_MAXIMUM_MESSAGES; @JsonProperty("poll_delay") @DurationMin(seconds = 0) @@ -68,15 +60,7 @@ public String getUrl() { return url; } - public Duration getPollingFrequency() { - return pollingFrequency; - } - - public Integer getBatchSize() { - return batchSize; - } - - public int getMaximumMessages() { + public Integer getMaximumMessages() { return maximumMessages; } diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/RawSqsMessageHandler.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/RawSqsMessageHandler.java index d3af0285bb..a756a5e3b6 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/RawSqsMessageHandler.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/RawSqsMessageHandler.java @@ -16,7 +16,6 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.sqs.model.Message; import com.fasterxml.jackson.databind.node.ObjectNode; -import java.util.Map; import java.time.Instant; import java.util.Objects; diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java index 457696810d..a13b81dd4a 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java @@ -15,7 +15,6 @@ import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.retry.RetryPolicy; import software.amazon.awssdk.services.sqs.SqsClient; - import org.opensearch.dataprepper.buffer.common.BufferAccumulator; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java index 27cfd54f84..b5b9f14deb 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java @@ -7,7 +7,6 @@ import com.linecorp.armeria.client.retry.Backoff; import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Timer; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; @@ -21,7 +20,6 @@ import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; -import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResultEntry; import software.amazon.awssdk.services.sqs.model.Message; import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; import software.amazon.awssdk.services.sqs.model.SqsException; @@ -44,7 +42,6 @@ public class SqsWorker implements Runnable { static final String SQS_MESSAGES_DELETED_METRIC_NAME = "sqsMessagesDeleted"; static final String SQS_MESSAGES_FAILED_METRIC_NAME = "sqsMessagesFailed"; static final String SQS_MESSAGES_DELETE_FAILED_METRIC_NAME = "sqsMessagesDeleteFailed"; - static final String SQS_MESSAGE_DELAY_METRIC_NAME = "sqsMessageDelay"; static final String SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangedCount"; static final String SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangeFailedCount"; static final String ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME = "acknowledgementSetCallbackCounter"; @@ -58,7 +55,6 @@ public class SqsWorker implements Runnable { private final Counter acknowledgementSetCallbackCounter; private final Counter sqsVisibilityTimeoutChangedCount; private final Counter sqsVisibilityTimeoutChangeFailedCount; - private final Timer sqsMessageDelayTimer; private final Backoff standardBackoff; private final QueueConfig queueConfig; private int failedAttemptCount; @@ -93,7 +89,6 @@ public SqsWorker(final Buffer> buffer, sqsMessagesDeletedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETED_METRIC_NAME); sqsMessagesFailedCounter = pluginMetrics.counter(SQS_MESSAGES_FAILED_METRIC_NAME); sqsMessagesDeleteFailedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETE_FAILED_METRIC_NAME); - sqsMessageDelayTimer = pluginMetrics.timer(SQS_MESSAGE_DELAY_METRIC_NAME); acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME); sqsVisibilityTimeoutChangedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME); sqsVisibilityTimeoutChangeFailedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME); @@ -108,7 +103,6 @@ public void run() { } catch (final Exception e) { LOG.error("Unable to process SQS messages. Processing error due to: {}", e.getMessage()); - // There shouldn't be any exceptions caught here, but added backoff just to control the amount of logging in case of an exception is thrown. applyBackoff(); } @@ -163,14 +157,20 @@ private void applyBackoff() { } } - private ReceiveMessageRequest createReceiveMessageRequest() { - return ReceiveMessageRequest.builder() + ReceiveMessageRequest.Builder requestBuilder = ReceiveMessageRequest.builder() .queueUrl(queueConfig.getUrl()) - .maxNumberOfMessages(queueConfig.getMaximumMessages()) - .visibilityTimeout((int) queueConfig.getVisibilityTimeout().getSeconds()) - .waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds()) - .build(); + .waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds()); + + if (queueConfig.getMaximumMessages() != null) { + requestBuilder.maxNumberOfMessages(queueConfig.getMaximumMessages()); + } + + if (queueConfig.getVisibilityTimeout() != null) { + requestBuilder.visibilityTimeout((int) queueConfig.getVisibilityTimeout().getSeconds()); + } + + return requestBuilder.build(); } private List processSqsEvents(final List messages) { @@ -181,7 +181,15 @@ private List processSqsEvents(final List waitingForAcknowledgements = new ArrayList<>(); AcknowledgementSet acknowledgementSet = null; - final int visibilityTimeout = (int)queueConfig.getVisibilityTimeout().getSeconds(); + + final int visibilityTimeout; + if (queueConfig.getVisibilityTimeout() != null) { + visibilityTimeout = (int) queueConfig.getVisibilityTimeout().getSeconds(); + } else { + visibilityTimeout = (int) Duration.ofSeconds(30).getSeconds(); + + } + final int maxVisibilityTimeout = (int)queueConfig.getVisibilityDuplicateProtectionTimeout().getSeconds(); final int progressCheckInterval = visibilityTimeout/2 - 1; if (endToEndAcknowledgementsEnabled) { @@ -197,7 +205,7 @@ private List processSqsEvents(final List processSqsEvents(final List { - final int newVisibilityTimeoutSeconds = visibilityTimeout; int newValue = messageVisibilityTimesMap.getOrDefault(message, visibilityTimeout) + progressCheckInterval; if (newValue >= maxVisibilityTimeout) { return; } messageVisibilityTimesMap.put(message, newValue); + final int newVisibilityTimeoutSeconds = visibilityTimeout; increaseVisibilityTimeout(message, newVisibilityTimeoutSeconds); }, Duration.ofSeconds(progressCheckInterval)); @@ -245,9 +253,9 @@ private Optional processSqsObject( final AcknowledgementSet acknowledgementSet) { try { sqsEventProcessor.addSqsObject(message, queueConfig.getUrl(), bufferAccumulator, acknowledgementSet); - // TODO: see implementation in s3 return Optional.of(buildDeleteMessageBatchRequestEntry(message)); } catch (final Exception e) { + sqsMessagesFailedCounter.increment(); LOG.error("Error processing from SQS: {}. Retrying with exponential backoff.", e.getMessage()); applyBackoff(); return Optional.empty(); diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationAdapterTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationAdapterTest.java new file mode 100644 index 0000000000..04806ff4d3 --- /dev/null +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationAdapterTest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.sqs; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import java.util.Collections; +import java.util.Map; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.mock; + + +@ExtendWith(MockitoExtension.class) +class AwsAuthenticationAdapterTest { + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private SqsSourceConfig sqsSourceConfig; + + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + private String stsRoleArn; + + @BeforeEach + void setUp() { + when(sqsSourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + + stsRoleArn = UUID.randomUUID().toString(); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn); + } + + private AwsAuthenticationAdapter createObjectUnderTest() { + return new AwsAuthenticationAdapter(awsCredentialsSupplier, sqsSourceConfig); + } + + @Test + void getCredentialsProvider_returns_AwsCredentialsProvider_from_AwsCredentialsSupplier() { + final AwsCredentialsProvider expectedProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))) + .thenReturn(expectedProvider); + + assertThat(createObjectUnderTest().getCredentialsProvider(), equalTo(expectedProvider)); + } + + @ParameterizedTest + @ValueSource(strings = {"us-east-1", "eu-west-1"}) + void getCredentialsProvider_creates_expected_AwsCredentialsOptions(final String regionString) { + final String externalId = UUID.randomUUID().toString(); + final Region region = Region.of(regionString); + + final Map headerOverrides = Collections.singletonMap(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(externalId); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(headerOverrides); + + createObjectUnderTest().getCredentialsProvider(); + + final ArgumentCaptor credentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + verify(awsCredentialsSupplier).getProvider(credentialsOptionsArgumentCaptor.capture()); + + final AwsCredentialsOptions actualOptions = credentialsOptionsArgumentCaptor.getValue(); + + assertThat(actualOptions, notNullValue()); + assertThat(actualOptions.getStsRoleArn(), equalTo(stsRoleArn)); + assertThat(actualOptions.getStsExternalId(), equalTo(externalId)); + assertThat(actualOptions.getRegion(), equalTo(region)); + assertThat(actualOptions.getStsHeaderOverrides(), equalTo(headerOverrides)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationOptionsTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationOptionsTest.java new file mode 100644 index 0000000000..0edf0a4c5b --- /dev/null +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AwsAuthenticationOptionsTest.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.sqs; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; +import software.amazon.awssdk.regions.Region; + +import java.lang.reflect.Field; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; + +class AwsAuthenticationOptionsTest { + + private AwsAuthenticationOptions awsAuthenticationOptions; + + @BeforeEach + void setUp() { + awsAuthenticationOptions = new AwsAuthenticationOptions(); + } + + @Test + void getAwsRegion_returns_Region_of() throws NoSuchFieldException, IllegalAccessException { + final String regionString = UUID.randomUUID().toString(); + final Region expectedRegionObject = mock(Region.class); + reflectivelySetField(awsAuthenticationOptions, "awsRegion", regionString); + final Region actualRegion; + try(final MockedStatic regionMockedStatic = mockStatic(Region.class)) { + regionMockedStatic.when(() -> Region.of(regionString)).thenReturn(expectedRegionObject); + actualRegion = awsAuthenticationOptions.getAwsRegion(); + } + assertThat(actualRegion, equalTo(expectedRegionObject)); + } + + @Test + void getAwsRegion_returns_null_when_region_is_null() throws NoSuchFieldException, IllegalAccessException { + reflectivelySetField(awsAuthenticationOptions, "awsRegion", null); + assertThat(awsAuthenticationOptions.getAwsRegion(), nullValue()); + } + + @Test + void getStsExternalId_notNull() throws NoSuchFieldException, IllegalAccessException { + final String externalId = UUID.randomUUID().toString(); + reflectivelySetField(awsAuthenticationOptions, "awsStsExternalId", externalId); + assertThat(awsAuthenticationOptions.getAwsStsExternalId(), equalTo(externalId)); + } + + @Test + void getStsExternalId_Null() throws NoSuchFieldException, IllegalAccessException { + reflectivelySetField(awsAuthenticationOptions, "awsStsExternalId", null); + assertThat(awsAuthenticationOptions.getAwsStsExternalId(), nullValue()); + } + + private void reflectivelySetField(final AwsAuthenticationOptions awsAuthenticationOptions, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException { + final Field field = AwsAuthenticationOptions.class.getDeclaredField(fieldName); + try { + field.setAccessible(true); + field.set(awsAuthenticationOptions, value); + } finally { + field.setAccessible(false); + } + } +} \ No newline at end of file diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java new file mode 100644 index 0000000000..1ecced9d0a --- /dev/null +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java @@ -0,0 +1,80 @@ +package org.opensearch.dataprepper.plugins.source.sqs; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsClient; + +import java.util.List; + +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + + +class SqsServiceTest { + private SqsSourceConfig sqsSourceConfig; + private SqsEventProcessor sqsEventProcessor; + private SqsClient sqsClient; + private PluginMetrics pluginMetrics; + private AcknowledgementSetManager acknowledgementSetManager; + private Buffer> buffer; + private AwsCredentialsProvider credentialsProvider; + + @BeforeEach + void setUp() { + sqsSourceConfig = mock(SqsSourceConfig.class); + sqsEventProcessor = mock(SqsEventProcessor.class); + sqsClient = mock(SqsClient.class, withSettings()); + pluginMetrics = mock(PluginMetrics.class); + acknowledgementSetManager = mock(AcknowledgementSetManager.class); + buffer = mock(Buffer.class); + credentialsProvider = mock(AwsCredentialsProvider.class); + + AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); + when(sqsSourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + } + + @Test + void start_with_single_queue_starts_workers() { + QueueConfig queueConfig = mock(QueueConfig.class); + when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); + when(queueConfig.getNumWorkers()).thenReturn(2); + when(sqsSourceConfig.getQueues()).thenReturn(List.of(queueConfig)); + when(sqsSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(100); + SqsService sqsService = spy(new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, sqsEventProcessor, pluginMetrics, credentialsProvider)); + doReturn(sqsClient).when(sqsService).createSqsClient(credentialsProvider); + sqsService.start(); // if no exception is thrown here, then workers have been started + } + + @Test + void stop_should_shutdown_executors_and_workers_and_close_client() throws InterruptedException { + QueueConfig queueConfig = mock(QueueConfig.class); + when(queueConfig.getUrl()).thenReturn("MyQueue"); + when(queueConfig.getNumWorkers()).thenReturn(1); + when(sqsSourceConfig.getQueues()).thenReturn(List.of(queueConfig)); + when(sqsSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(100); + SqsClient sqsClient = mock(SqsClient.class); + SqsService sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, sqsEventProcessor, pluginMetrics, credentialsProvider) { + @Override + SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) { + return sqsClient; + } + }; + sqsService.start(); + sqsService.stop(); + verify(sqsClient, times(1)).close(); + } + +} \ No newline at end of file diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsSourceTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsSourceTest.java new file mode 100644 index 0000000000..bc53b7fff5 --- /dev/null +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsSourceTest.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.sqs; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class SqsSourceTest { + private final String PLUGIN_NAME = "sqs"; + private final String TEST_PIPELINE_NAME = "test_pipeline"; + private SqsSource sqsSource; + private PluginMetrics pluginMetrics; + private SqsSourceConfig sqsSourceConfig; + private AcknowledgementSetManager acknowledgementSetManager; + private AwsCredentialsSupplier awsCredentialsSupplier; + private Buffer> buffer; + + + @BeforeEach + void setUp() { + pluginMetrics = PluginMetrics.fromNames(PLUGIN_NAME, TEST_PIPELINE_NAME); + sqsSourceConfig = mock(SqsSourceConfig.class); + acknowledgementSetManager = mock(AcknowledgementSetManager.class); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + sqsSource = new SqsSource(pluginMetrics, sqsSourceConfig, acknowledgementSetManager, awsCredentialsSupplier); + } + + @Test + void start_should_throw_IllegalStateException_when_buffer_is_null() { + assertThrows(IllegalStateException.class, () -> sqsSource.start(null)); + } + + @Test + void start_should_not_throw_when_buffer_is_not_null() { + Buffer> buffer = mock(Buffer.class); + AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn("arn:aws:iam::123456789012:role/example-role"); + when(sqsSourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(mock(AwsCredentialsProvider.class)); + assertDoesNotThrow(() -> sqsSource.start(buffer)); + } + + @Test + void stop_should_not_throw_when_sqsService_is_null() { + sqsSource.stop(); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java new file mode 100644 index 0000000000..a7e38b6d0d --- /dev/null +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java @@ -0,0 +1,282 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.sqs; + +import com.linecorp.armeria.client.retry.Backoff; +import io.micrometer.core.instrument.Counter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.acknowledgements.ProgressCheck; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResultEntry; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; +import software.amazon.awssdk.services.sqs.model.SqsException; + +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.UUID; +import java.util.function.Consumer; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class SqsWorkerTest { + + @Mock + private Buffer> buffer; + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + @Mock + private SqsClient sqsClient; + @Mock + private SqsEventProcessor sqsEventProcessor; + @Mock + private SqsSourceConfig sqsSourceConfig; + @Mock + private QueueConfig queueConfig; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private Backoff backoff; + @Mock + private Counter sqsMessagesReceivedCounter; + @Mock + private Counter sqsMessagesDeletedCounter; + @Mock + private Counter sqsMessagesFailedCounter; + @Mock + private Counter sqsMessagesDeleteFailedCounter; + @Mock + private Counter acknowledgementSetCallbackCounter; + @Mock + private Counter sqsVisibilityTimeoutChangedCount; + @Mock + private Counter sqsVisibilityTimeoutChangeFailedCount; + + private SqsWorker createObjectUnderTest() { + return new SqsWorker( + buffer, + acknowledgementSetManager, + sqsClient, + sqsEventProcessor, + sqsSourceConfig, + queueConfig, + pluginMetrics, + backoff); + } + + @BeforeEach + void setUp() { + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_RECEIVED_METRIC_NAME)).thenReturn(sqsMessagesReceivedCounter); + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETED_METRIC_NAME)).thenReturn(sqsMessagesDeletedCounter); + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_FAILED_METRIC_NAME)).thenReturn(sqsMessagesFailedCounter); + when(pluginMetrics.counter(SqsWorker.SQS_MESSAGES_DELETE_FAILED_METRIC_NAME)).thenReturn(sqsMessagesDeleteFailedCounter); + when(pluginMetrics.counter(SqsWorker.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)).thenReturn(acknowledgementSetCallbackCounter); + when(pluginMetrics.counter(SqsWorker.SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME)).thenReturn(sqsVisibilityTimeoutChangedCount); + when(pluginMetrics.counter(SqsWorker.SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME)).thenReturn(sqsVisibilityTimeoutChangeFailedCount); + + when(sqsSourceConfig.getAcknowledgements()).thenReturn(false); + when(sqsSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(100); + when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); + when(queueConfig.getWaitTime()).thenReturn(Duration.ofSeconds(1)); + } + + @Test + void processSqsMessages_should_return_number_of_messages_processed_and_increment_counters() throws IOException { + final Message message = Message.builder() + .messageId(UUID.randomUUID().toString()) + .receiptHandle(UUID.randomUUID().toString()) + .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") + .build(); + + final ReceiveMessageResponse response = ReceiveMessageResponse.builder().messages(message).build(); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); + + final DeleteMessageBatchResultEntry successfulDelete = DeleteMessageBatchResultEntry.builder().id(message.messageId()).build(); + final DeleteMessageBatchResponse deleteResponse = DeleteMessageBatchResponse.builder().successful(successfulDelete).build(); + when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(deleteResponse); + + int messagesProcessed = createObjectUnderTest().processSqsMessages(); + assertThat(messagesProcessed, equalTo(1)); + + verify(sqsMessagesReceivedCounter).increment(1); + verify(sqsMessagesDeletedCounter).increment(1); + verify(sqsMessagesDeleteFailedCounter, never()).increment(anyDouble()); + } + + @Test + void processSqsMessages_should_invoke_processSqsEvent_and_deleteSqsMessages_when_entries_non_empty() throws IOException { + final Message message = Message.builder() + .messageId(UUID.randomUUID().toString()) + .receiptHandle(UUID.randomUUID().toString()) + .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") + .build(); + + final ReceiveMessageResponse response = ReceiveMessageResponse.builder() + .messages(message) + .build(); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); + + final DeleteMessageBatchResultEntry successfulDelete = DeleteMessageBatchResultEntry.builder() + .id(message.messageId()) + .build(); + final DeleteMessageBatchResponse deleteResponse = DeleteMessageBatchResponse.builder() + .successful(successfulDelete) + .build(); + when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(deleteResponse); + + SqsWorker sqsWorker = createObjectUnderTest(); + int messagesProcessed = sqsWorker.processSqsMessages(); + + assertThat(messagesProcessed, equalTo(1)); + verify(sqsEventProcessor, times(1)).addSqsObject(eq(message), eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), any(), isNull()); + verify(sqsClient, times(1)).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); + verify(sqsMessagesReceivedCounter).increment(1); + verify(sqsMessagesDeletedCounter).increment(1); + verify(sqsMessagesDeleteFailedCounter, never()).increment(anyDouble()); + } + + @Test + void processSqsMessages_should_not_invoke_processSqsEvent_and_deleteSqsMessages_when_entries_are_empty() throws IOException { + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(ReceiveMessageResponse.builder().messages(Collections.emptyList()).build()); + SqsWorker sqsWorker = createObjectUnderTest(); + int messagesProcessed = sqsWorker.processSqsMessages(); + assertThat(messagesProcessed, equalTo(0)); + verify(sqsEventProcessor, never()).addSqsObject(any(), anyString(), any(), any()); + verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); + verify(sqsMessagesReceivedCounter, never()).increment(anyDouble()); + verify(sqsMessagesDeletedCounter, never()).increment(anyDouble()); + } + + + @Test + void processSqsMessages_should_not_delete_messages_if_acknowledgements_enabled_until_acknowledged() throws IOException { + when(sqsSourceConfig.getAcknowledgements()).thenReturn(true); + AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); + when(acknowledgementSetManager.create(any(), any())).thenReturn(acknowledgementSet); + when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); + + final Message message = Message.builder() + .messageId("msg-1") + .receiptHandle("rh-1") + .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") + .build(); + + final ReceiveMessageResponse response = ReceiveMessageResponse.builder().messages(message).build(); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); + int messagesProcessed = createObjectUnderTest().processSqsMessages(); + assertThat(messagesProcessed, equalTo(1)); + verify(sqsEventProcessor).addSqsObject(eq(message), + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + any(), + eq(acknowledgementSet)); + verify(sqsMessagesReceivedCounter).increment(1); + verifyNoInteractions(sqsMessagesDeletedCounter); + } + + @Test + void acknowledgementsEnabled_and_visibilityDuplicateProtectionEnabled_should_create_ack_sets_and_progress_check() { + when(sqsSourceConfig.getAcknowledgements()).thenReturn(true); + when(queueConfig.getVisibilityDuplicateProtection()).thenReturn(true); + + SqsWorker worker = new SqsWorker(buffer, acknowledgementSetManager, sqsClient, sqsEventProcessor, sqsSourceConfig, queueConfig, pluginMetrics, backoff); + Message message = Message.builder().messageId("msg-dup").receiptHandle("handle-dup").build(); + ReceiveMessageResponse response = ReceiveMessageResponse.builder().messages(message).build(); + when(sqsClient.receiveMessage((ReceiveMessageRequest) any())).thenReturn(response); + + AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); + when(acknowledgementSetManager.create(any(), any())).thenReturn(acknowledgementSet); + + int processed = worker.processSqsMessages(); + assertThat(processed, equalTo(1)); + + verify(acknowledgementSetManager).create(any(), any()); + verify(acknowledgementSet).addProgressCheck(any(), any()); + } + + @Test + void processSqsMessages_should_return_zero_messages_with_backoff_when_a_SqsException_is_thrown() { + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); + verify(backoff).nextDelayMillis(1); + assertThat(messagesProcessed, equalTo(0)); + } + + @Test + void processSqsMessages_should_throw_when_a_SqsException_is_thrown_with_max_retries() { + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); + when(backoff.nextDelayMillis(anyInt())).thenReturn((long) -1); + SqsWorker objectUnderTest = createObjectUnderTest(); + assertThrows(SqsRetriesExhaustedException.class, objectUnderTest::processSqsMessages); + } + + @Test + void processSqsMessages_should_update_visibility_timeout_when_progress_changes() throws IOException { + AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); + when(queueConfig.getVisibilityDuplicateProtection()).thenReturn(true); + when(queueConfig.getVisibilityTimeout()).thenReturn(Duration.ofMillis(1)); + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + when(sqsSourceConfig.getAcknowledgements()).thenReturn(true); + final Message message = mock(Message.class); + final String testReceiptHandle = UUID.randomUUID().toString(); + when(message.messageId()).thenReturn(testReceiptHandle); + when(message.receiptHandle()).thenReturn(testReceiptHandle); + + final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); + when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); + + final int messagesProcessed = createObjectUnderTest().processSqsMessages(); + + assertThat(messagesProcessed, equalTo(1)); + verify(sqsEventProcessor).addSqsObject(any(), anyString(), any(), any()); + verify(acknowledgementSetManager).create(any(), any(Duration.class)); + + ArgumentCaptor> progressConsumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); + verify(acknowledgementSet).addProgressCheck(progressConsumerArgumentCaptor.capture(), any(Duration.class)); + final Consumer actualConsumer = progressConsumerArgumentCaptor.getValue(); + final ProgressCheck progressCheck = mock(ProgressCheck.class); + actualConsumer.accept(progressCheck); + + ArgumentCaptor changeMessageVisibilityRequestArgumentCaptor = ArgumentCaptor.forClass(ChangeMessageVisibilityRequest.class); + verify(sqsClient).changeMessageVisibility(changeMessageVisibilityRequestArgumentCaptor.capture()); + ChangeMessageVisibilityRequest actualChangeVisibilityRequest = changeMessageVisibilityRequestArgumentCaptor.getValue(); + assertThat(actualChangeVisibilityRequest.queueUrl(), equalTo(queueConfig.getUrl())); + assertThat(actualChangeVisibilityRequest.receiptHandle(), equalTo(testReceiptHandle)); + verify(sqsMessagesReceivedCounter).increment(1); + } +} \ No newline at end of file