Skip to content

Commit

Permalink
added unit tests and addressed comments in design doc
Browse files Browse the repository at this point in the history
Signed-off-by: Jeremy Michael <[email protected]>
  • Loading branch information
Jeremy Michael committed Dec 18, 2024
1 parent b79f0d8 commit 29701d0
Show file tree
Hide file tree
Showing 10 changed files with 616 additions and 39 deletions.
2 changes: 1 addition & 1 deletion data-prepper-plugins/sqs-source/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ sqs-pipeline:
region: <AWS_REGION>
sts_role_arn: <IAM_ROLE_ARN>
sink:
- stdout:
- stdout:
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@

public class QueueConfig {

private static final int DEFAULT_MAXIMUM_MESSAGES = 10;
private static final Integer DEFAULT_MAXIMUM_MESSAGES = null;
private static final Boolean DEFAULT_VISIBILITY_DUPLICATE_PROTECTION = false;
private static final Duration DEFAULT_VISIBILITY_TIMEOUT_SECONDS = Duration.ofSeconds(30);
private static final Duration DEFAULT_VISIBILITY_TIMEOUT_SECONDS = null;
private static final Duration DEFAULT_VISIBILITY_DUPLICATE_PROTECTION_TIMEOUT = Duration.ofHours(2);
private static final Duration DEFAULT_WAIT_TIME_SECONDS = Duration.ofSeconds(20);
private static final Duration DEFAULT_POLL_DELAY_SECONDS = Duration.ofSeconds(0);
static final int DEFAULT_NUMBER_OF_WORKERS = 1;
private static final int DEFAULT_BATCH_SIZE = 10;

@JsonProperty("url")
@NotNull
Expand All @@ -33,14 +32,7 @@ public class QueueConfig {
@JsonProperty("maximum_messages")
@Min(1)
@Max(10)
private int maximumMessages = DEFAULT_MAXIMUM_MESSAGES;

@JsonProperty("batch_size")
@Max(10)
private Integer batchSize = DEFAULT_BATCH_SIZE;

@JsonProperty("polling_frequency")
private Duration pollingFrequency = Duration.ZERO;
private Integer maximumMessages = DEFAULT_MAXIMUM_MESSAGES;

@JsonProperty("poll_delay")
@DurationMin(seconds = 0)
Expand Down Expand Up @@ -68,15 +60,7 @@ public String getUrl() {
return url;
}

public Duration getPollingFrequency() {
return pollingFrequency;
}

public Integer getBatchSize() {
return batchSize;
}

public int getMaximumMessages() {
public Integer getMaximumMessages() {
return maximumMessages;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.sqs.model.Message;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.Map;
import java.time.Instant;
import java.util.Objects;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.sqs.SqsClient;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import com.linecorp.armeria.client.retry.Backoff;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
Expand All @@ -21,7 +20,6 @@
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResultEntry;
import software.amazon.awssdk.services.sqs.model.Message;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
import software.amazon.awssdk.services.sqs.model.SqsException;
Expand All @@ -44,7 +42,6 @@ public class SqsWorker implements Runnable {
static final String SQS_MESSAGES_DELETED_METRIC_NAME = "sqsMessagesDeleted";
static final String SQS_MESSAGES_FAILED_METRIC_NAME = "sqsMessagesFailed";
static final String SQS_MESSAGES_DELETE_FAILED_METRIC_NAME = "sqsMessagesDeleteFailed";
static final String SQS_MESSAGE_DELAY_METRIC_NAME = "sqsMessageDelay";
static final String SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangedCount";
static final String SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangeFailedCount";
static final String ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME = "acknowledgementSetCallbackCounter";
Expand All @@ -58,7 +55,6 @@ public class SqsWorker implements Runnable {
private final Counter acknowledgementSetCallbackCounter;
private final Counter sqsVisibilityTimeoutChangedCount;
private final Counter sqsVisibilityTimeoutChangeFailedCount;
private final Timer sqsMessageDelayTimer;
private final Backoff standardBackoff;
private final QueueConfig queueConfig;
private int failedAttemptCount;
Expand Down Expand Up @@ -93,7 +89,6 @@ public SqsWorker(final Buffer<Record<Event>> buffer,
sqsMessagesDeletedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETED_METRIC_NAME);
sqsMessagesFailedCounter = pluginMetrics.counter(SQS_MESSAGES_FAILED_METRIC_NAME);
sqsMessagesDeleteFailedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETE_FAILED_METRIC_NAME);
sqsMessageDelayTimer = pluginMetrics.timer(SQS_MESSAGE_DELAY_METRIC_NAME);
acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME);
sqsVisibilityTimeoutChangedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME);
sqsVisibilityTimeoutChangeFailedCount = pluginMetrics.counter(SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME);
Expand All @@ -108,7 +103,6 @@ public void run() {

} catch (final Exception e) {
LOG.error("Unable to process SQS messages. Processing error due to: {}", e.getMessage());
// There shouldn't be any exceptions caught here, but added backoff just to control the amount of logging in case of an exception is thrown.
applyBackoff();
}

Expand Down Expand Up @@ -163,14 +157,20 @@ private void applyBackoff() {
}
}


private ReceiveMessageRequest createReceiveMessageRequest() {
return ReceiveMessageRequest.builder()
ReceiveMessageRequest.Builder requestBuilder = ReceiveMessageRequest.builder()
.queueUrl(queueConfig.getUrl())
.maxNumberOfMessages(queueConfig.getMaximumMessages())
.visibilityTimeout((int) queueConfig.getVisibilityTimeout().getSeconds())
.waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds())
.build();
.waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds());

if (queueConfig.getMaximumMessages() != null) {
requestBuilder.maxNumberOfMessages(queueConfig.getMaximumMessages());
}

if (queueConfig.getVisibilityTimeout() != null) {
requestBuilder.visibilityTimeout((int) queueConfig.getVisibilityTimeout().getSeconds());
}

return requestBuilder.build();
}

private List<DeleteMessageBatchRequestEntry> processSqsEvents(final List<Message> messages) {
Expand All @@ -181,7 +181,15 @@ private List<DeleteMessageBatchRequestEntry> processSqsEvents(final List<Message
for (Message message : messages) {
List<DeleteMessageBatchRequestEntry> waitingForAcknowledgements = new ArrayList<>();
AcknowledgementSet acknowledgementSet = null;
final int visibilityTimeout = (int)queueConfig.getVisibilityTimeout().getSeconds();

final int visibilityTimeout;
if (queueConfig.getVisibilityTimeout() != null) {
visibilityTimeout = (int) queueConfig.getVisibilityTimeout().getSeconds();
} else {
visibilityTimeout = (int) Duration.ofSeconds(30).getSeconds();

}

final int maxVisibilityTimeout = (int)queueConfig.getVisibilityDuplicateProtectionTimeout().getSeconds();
final int progressCheckInterval = visibilityTimeout/2 - 1;
if (endToEndAcknowledgementsEnabled) {
Expand All @@ -197,20 +205,20 @@ private List<DeleteMessageBatchRequestEntry> processSqsEvents(final List<Message
if (visibilityDuplicateProtectionEnabled) {
messageVisibilityTimesMap.remove(message);
}
if (result == true) {
if (result) {
deleteSqsMessages(waitingForAcknowledgements);
}
},
Duration.ofSeconds(expiryTimeout));
if (visibilityDuplicateProtectionEnabled) {
acknowledgementSet.addProgressCheck(
(ratio) -> {
final int newVisibilityTimeoutSeconds = visibilityTimeout;
int newValue = messageVisibilityTimesMap.getOrDefault(message, visibilityTimeout) + progressCheckInterval;
if (newValue >= maxVisibilityTimeout) {
return;
}
messageVisibilityTimesMap.put(message, newValue);
final int newVisibilityTimeoutSeconds = visibilityTimeout;
increaseVisibilityTimeout(message, newVisibilityTimeoutSeconds);
},
Duration.ofSeconds(progressCheckInterval));
Expand Down Expand Up @@ -245,9 +253,9 @@ private Optional<DeleteMessageBatchRequestEntry> processSqsObject(
final AcknowledgementSet acknowledgementSet) {
try {
sqsEventProcessor.addSqsObject(message, queueConfig.getUrl(), bufferAccumulator, acknowledgementSet);
// TODO: see implementation in s3
return Optional.of(buildDeleteMessageBatchRequestEntry(message));
} catch (final Exception e) {
sqsMessagesFailedCounter.increment();
LOG.error("Error processing from SQS: {}. Retrying with exponential backoff.", e.getMessage());
applyBackoff();
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;

import java.util.Collections;
import java.util.Map;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.mock;


@ExtendWith(MockitoExtension.class)
class AwsAuthenticationAdapterTest {
@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;
@Mock
private SqsSourceConfig sqsSourceConfig;

@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;
private String stsRoleArn;

@BeforeEach
void setUp() {
when(sqsSourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);

stsRoleArn = UUID.randomUUID().toString();
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
}

private AwsAuthenticationAdapter createObjectUnderTest() {
return new AwsAuthenticationAdapter(awsCredentialsSupplier, sqsSourceConfig);
}

@Test
void getCredentialsProvider_returns_AwsCredentialsProvider_from_AwsCredentialsSupplier() {
final AwsCredentialsProvider expectedProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class)))
.thenReturn(expectedProvider);

assertThat(createObjectUnderTest().getCredentialsProvider(), equalTo(expectedProvider));
}

@ParameterizedTest
@ValueSource(strings = {"us-east-1", "eu-west-1"})
void getCredentialsProvider_creates_expected_AwsCredentialsOptions(final String regionString) {
final String externalId = UUID.randomUUID().toString();
final Region region = Region.of(regionString);

final Map<String, String> headerOverrides = Collections.singletonMap(UUID.randomUUID().toString(), UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(externalId);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(headerOverrides);

createObjectUnderTest().getCredentialsProvider();

final ArgumentCaptor<AwsCredentialsOptions> credentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(credentialsOptionsArgumentCaptor.capture());

final AwsCredentialsOptions actualOptions = credentialsOptionsArgumentCaptor.getValue();

assertThat(actualOptions, notNullValue());
assertThat(actualOptions.getStsRoleArn(), equalTo(stsRoleArn));
assertThat(actualOptions.getStsExternalId(), equalTo(externalId));
assertThat(actualOptions.getRegion(), equalTo(region));
assertThat(actualOptions.getStsHeaderOverrides(), equalTo(headerOverrides));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import software.amazon.awssdk.regions.Region;

import java.lang.reflect.Field;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;

class AwsAuthenticationOptionsTest {

private AwsAuthenticationOptions awsAuthenticationOptions;

@BeforeEach
void setUp() {
awsAuthenticationOptions = new AwsAuthenticationOptions();
}

@Test
void getAwsRegion_returns_Region_of() throws NoSuchFieldException, IllegalAccessException {
final String regionString = UUID.randomUUID().toString();
final Region expectedRegionObject = mock(Region.class);
reflectivelySetField(awsAuthenticationOptions, "awsRegion", regionString);
final Region actualRegion;
try(final MockedStatic<Region> regionMockedStatic = mockStatic(Region.class)) {
regionMockedStatic.when(() -> Region.of(regionString)).thenReturn(expectedRegionObject);
actualRegion = awsAuthenticationOptions.getAwsRegion();
}
assertThat(actualRegion, equalTo(expectedRegionObject));
}

@Test
void getAwsRegion_returns_null_when_region_is_null() throws NoSuchFieldException, IllegalAccessException {
reflectivelySetField(awsAuthenticationOptions, "awsRegion", null);
assertThat(awsAuthenticationOptions.getAwsRegion(), nullValue());
}

@Test
void getStsExternalId_notNull() throws NoSuchFieldException, IllegalAccessException {
final String externalId = UUID.randomUUID().toString();
reflectivelySetField(awsAuthenticationOptions, "awsStsExternalId", externalId);
assertThat(awsAuthenticationOptions.getAwsStsExternalId(), equalTo(externalId));
}

@Test
void getStsExternalId_Null() throws NoSuchFieldException, IllegalAccessException {
reflectivelySetField(awsAuthenticationOptions, "awsStsExternalId", null);
assertThat(awsAuthenticationOptions.getAwsStsExternalId(), nullValue());
}

private void reflectivelySetField(final AwsAuthenticationOptions awsAuthenticationOptions, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException {
final Field field = AwsAuthenticationOptions.class.getDeclaredField(fieldName);
try {
field.setAccessible(true);
field.set(awsAuthenticationOptions, value);
} finally {
field.setAccessible(false);
}
}
}
Loading

0 comments on commit 29701d0

Please sign in to comment.