From e0dee50ce96a8ebd8cdddf264462f28db64a3249 Mon Sep 17 00:00:00 2001 From: Krishna Kondaka Date: Mon, 18 Nov 2024 09:09:44 -0800 Subject: [PATCH] Refactor lambda code to share code between processor and sink (#5196) * Refactor lambda code to share code between processor and sink Signed-off-by: Kondaka * remove debug statements Signed-off-by: Kondaka * Fixed checkstyle errors, add new test file Signed-off-by: Kondaka * Added copyright headers Signed-off-by: Kondaka * Indentation changes Signed-off-by: Kondaka * introducing buffer batch Signed-off-by: Kondaka * Code cleanup Signed-off-by: Kondaka * applied my major refactoring Signed-off-by: Kondaka * Checkstyle cleanup Signed-off-by: Kondaka * InMemoryBuffer related test Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * Add multi-threading to integration tests Signed-off-by: Srikanth Govindarajan Signed-off-by: Kondaka * SinkTest fixes Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * Added metrics support Signed-off-by: Kondaka * added concurrency limit Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * Add integ test env option Signed-off-by: Srikanth Govindarajan Signed-off-by: Kondaka * Add UT for LambdaClientFactory Signed-off-by: Srikanth Govindarajan Signed-off-by: Kondaka * making use of output codec context Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * null pointer issue fix Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * Fix dlq return value in lambda sink Signed-off-by: Kondaka * making the event type default Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * Fix failing SinkConfigTest Signed-off-by: Kondaka * Add backoff retry and ClientOptions Signed-off-by: Kondaka * Fixed All UTs Signed-off-by: Srikanth Govindarajan Signed-off-by: Kondaka * Fix lambda common handler to handle futures correctly Signed-off-by: Kondaka * Added new IT to test both lambda and sink Signed-off-by: Kondaka * Fixed unit tests Signed-off-by: Kondaka * applied my major refactoring Signed-off-by: Kondaka * InMemoryBuffer related test Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * SinkTest fixes Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Kondaka * Add UT for LambdaClientFactory Signed-off-by: Srikanth Govindarajan Signed-off-by: Kondaka * Fixed checkstyle errors Signed-off-by: Kondaka * Fixed build error Signed-off-by: Kondaka * Fixed IT build failure Signed-off-by: Kondaka * Fixed IT build failure Signed-off-by: Kondaka * Fixed IT failure Signed-off-by: Kondaka --------- Signed-off-by: Kondaka Signed-off-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Signed-off-by: Srikanth Govindarajan Co-authored-by: Santhosh Gandhe Co-authored-by: Santhosh Gandhe <1909520+san81@users.noreply.github.com> Co-authored-by: Srikanth Govindarajan --- data-prepper-plugins/aws-lambda/README.md | 12 +- data-prepper-plugins/aws-lambda/build.gradle | 8 +- .../plugins/lambda/LambdaProcessorSinkIT.java | 315 ++++++++++++++++++ .../lambda/processor/LambdaProcessorIT.java | 282 ++++++++++++++++ .../processor/LambdaProcessorServiceIT.java | 165 --------- .../lambda/sink/LambdaSinkServiceIT.java | 236 ------------- .../lambda/common/LambdaCommonHandler.java | 159 ++++++--- .../ResponseEventHandlingStrategy.java | 7 +- .../lambda/common/accumlator/Buffer.java | 37 +- .../common/accumlator/BufferFactory.java | 14 - .../common/accumlator/InMemoryBuffer.java | 239 ++++++------- .../accumlator/InMemoryBufferFactory.java | 16 - .../common/client/LambdaClientFactory.java | 103 +++--- .../lambda/common/config/ClientOptions.java | 42 +++ .../common/config/LambdaCommonConfig.java | 55 ++- .../common/config/ThresholdOptions.java | 6 +- ...ggregateResponseEventHandlingStrategy.java | 8 +- .../lambda/processor/LambdaProcessor.java | 307 ++++++----------- .../processor/LambdaProcessorConfig.java | 138 +++----- .../lambda/processor/ResponseCardinality.java | 38 --- .../StrictResponseEventHandlingStrategy.java | 6 + .../plugins/lambda/sink/LambdaSink.java | 239 ++++++++----- .../plugins/lambda/sink/LambdaSinkConfig.java | 135 +++----- .../lambda/sink/LambdaSinkService.java | 286 ---------------- .../common/LambdaCommonHandlerTest.java | 230 +++++++------ .../InMemoryBufferFactoryTest.java | 32 -- .../accumulator/InMemoryBufferTest.java | 237 ++++++------- .../client/LambdaClientFactoryTest.java | 232 ++++--------- ...gateResponseEventHandlingStrategyTest.java | 5 + .../lambda/processor/InvocationTypeTest.java | 5 + .../processor/LambdaProcessorConfigTest.java | 82 ++++- .../lambda/processor/LambdaProcessorTest.java | 216 ++++++------ .../processor/ResponseCardinalityTest.java | 25 -- ...rictResponseEventHandlingStrategyTest.java | 5 + .../lambda/sink/LambdaSinkConfigTest.java | 84 ++++- .../lambda/sink/LambdaSinkServiceTest.java | 271 --------------- .../plugins/lambda/sink/LambdaSinkTest.java | 255 ++++++++++++++ 37 files changed, 2200 insertions(+), 2332 deletions(-) create mode 100644 data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/LambdaProcessorSinkIT.java create mode 100644 data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java delete mode 100644 data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java delete mode 100644 data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceIT.java rename data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/{processor => common}/ResponseEventHandlingStrategy.java (73%) delete mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/BufferFactory.java delete mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBufferFactory.java create mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ClientOptions.java delete mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinality.java delete mode 100644 data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java delete mode 100644 data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferFactoryTest.java delete mode 100644 data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinalityTest.java delete mode 100644 data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java create mode 100644 data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java diff --git a/data-prepper-plugins/aws-lambda/README.md b/data-prepper-plugins/aws-lambda/README.md index 89298f7715..099b390702 100644 --- a/data-prepper-plugins/aws-lambda/README.md +++ b/data-prepper-plugins/aws-lambda/README.md @@ -45,7 +45,11 @@ The integration tests for this plugin do not run as part of the Data Prepper bui The following command runs the integration tests: ``` -./gradlew :data-prepper-plugins:aws-lambda:integrationTest -Dtests.processor.lambda.region="us-east-1" -Dtests.processor.lambda.functionName="lambda_test_function" -Dtests.processor.lambda.sts_role_arn="arn:aws:iam::123456789012:role/dataprepper-role +./gradlew :data-prepper-plugins:aws-lambda:integrationTest \ +-Dtests.lambda.processor.region="us-east-1" \ +-Dtests.lambda.processor.functionName="test-lambda-processor" \ +-Dtests.lambda.processor.sts_role_arn="arn:aws:iam::<>:role/lambda-role" + ``` @@ -83,6 +87,10 @@ The integration tests for this plugin do not run as part of the Data Prepper bui The following command runs the integration tests: ``` -./gradlew :data-prepper-plugins:aws-lambda:integrationTest -Dtests.sink.lambda.region="us-east-1" -Dtests.sink.lambda.functionName="lambda_test_function" -Dtests.sink.lambda.sts_role_arn="arn:aws:iam::123456789012:role/dataprepper-role +./gradlew :data-prepper-plugins:aws-lambda:integrationTest \ +-Dtests.lambda.processor.region="us-east-1" \ +-Dtests.lambda.processor.functionName="test-lambda-processor" \ +-Dtests.lambda.processor.sts_role_arn="arn:aws:iam::<>>:role/lambda-role" + ``` diff --git a/data-prepper-plugins/aws-lambda/build.gradle b/data-prepper-plugins/aws-lambda/build.gradle index 6186ba05a4..a0319fabd4 100644 --- a/data-prepper-plugins/aws-lambda/build.gradle +++ b/data-prepper-plugins/aws-lambda/build.gradle @@ -14,10 +14,11 @@ dependencies { implementation 'com.fasterxml.jackson.core:jackson-databind' implementation 'software.amazon.awssdk:lambda:2.17.99' implementation 'software.amazon.awssdk:sdk-core:2.x.x' + implementation 'software.amazon.awssdk:netty-nio-client' implementation 'software.amazon.awssdk:sts' implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final' implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' - implementation'org.json:json' + implementation 'org.json:json' implementation libs.commons.lang3 implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' implementation 'org.projectlombok:lombok:1.18.22' @@ -62,6 +63,11 @@ task integrationTest(type: Test) { classpath = sourceSets.integrationTest.runtimeClasspath systemProperty 'log4j.configurationFile', 'src/test/resources/log4j2.properties' + + //Enable Multi-thread in tests + systemProperty 'junit.jupiter.execution.parallel.enabled', 'true' + systemProperty 'junit.jupiter.execution.parallel.mode.default', 'concurrent' + systemProperty 'tests.lambda.sink.region', System.getProperty('tests.lambda.sink.region') systemProperty 'tests.lambda.sink.functionName', System.getProperty('tests.lambda.sink.functionName') systemProperty 'tests.lambda.sink.sts_role_arn', System.getProperty('tests.lambda.sink.sts_role_arn') diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/LambdaProcessorSinkIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/LambdaProcessorSinkIT.java new file mode 100644 index 0000000000..1745605add --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/LambdaProcessorSinkIT.java @@ -0,0 +1,315 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.lambda; + +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.event.DefaultEventHandle; +import org.opensearch.dataprepper.model.event.EventMetadata; +import org.opensearch.dataprepper.model.event.DefaultEventMetadata; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.types.ByteCount; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; +import org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor; +import org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessorConfig; +import org.opensearch.dataprepper.plugins.lambda.sink.LambdaSink; +import org.opensearch.dataprepper.plugins.lambda.sink.LambdaSinkConfig; + +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.any; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; + +import java.util.ArrayList; +import java.util.Collection; +import java.lang.reflect.Field; +import java.util.List; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class LambdaProcessorSinkIT { + private AwsCredentialsProvider awsCredentialsProvider; + private LambdaProcessor lambdaProcessor; + private LambdaProcessorConfig lambdaProcessorConfig; + private String functionName; + private String lambdaRegion; + private String role; + + @Mock + private PluginSetting pluginSetting; + @Mock + private LambdaSinkConfig lambdaSinkConfig; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private PluginFactory pluginFactory; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private ExpressionEvaluator expressionEvaluator; + @Mock + private Counter testCounter; + @Mock + private Counter sinkSuccessCounter; + @Mock + private Timer testTimer; + @Mock + InvocationType invocationType; + + private AtomicLong successCount; + private AtomicLong numEventHandlesReleased; + + @Mock + private AcknowledgementSet acknowledgementSet; + + private LambdaProcessor createLambdaProcessor(LambdaProcessorConfig processorConfig) { + return new LambdaProcessor(pluginFactory, pluginMetrics, processorConfig, awsCredentialsSupplier, expressionEvaluator); + } + + private LambdaSink createLambdaSink(LambdaSinkConfig lambdaSinkConfig) { + return new LambdaSink(pluginSetting, lambdaSinkConfig, pluginFactory, null, awsCredentialsSupplier, expressionEvaluator); + + } + + @BeforeEach + public void setup() { + lambdaRegion = System.getProperty("tests.lambda.processor.region"); + functionName = System.getProperty("tests.lambda.processor.functionName"); + role = System.getProperty("tests.lambda.processor.sts_role_arn"); + successCount = new AtomicLong(); + numEventHandlesReleased = new AtomicLong(); + + acknowledgementSet = mock(AcknowledgementSet.class); + try { + lenient().doAnswer(args -> { + return null; + }).when(acknowledgementSet).acquire(any(EventHandle.class)); + } catch (Exception e){ } + try { + lenient().doAnswer(args -> { + numEventHandlesReleased.incrementAndGet(); + return null; + }).when(acknowledgementSet).release(any(EventHandle.class), any(Boolean.class)); + } catch (Exception e){ } + pluginMetrics = mock(PluginMetrics.class); + when(pluginMetrics.gauge(any(), any(AtomicLong.class))).thenReturn(new AtomicLong()); + sinkSuccessCounter = mock(Counter.class); + try { + lenient().doAnswer(args -> { + Double c = args.getArgument(0); + successCount.addAndGet(c.intValue()); + return null; + }).when(sinkSuccessCounter).increment(any(Double.class)); + } catch (Exception e){ } + testCounter = mock(Counter.class); + try { + lenient().doAnswer(args -> { + return null; + }).when(testCounter).increment(any(Double.class)); + } catch (Exception e){} + try { + lenient().doAnswer(args -> { + return null; + }).when(testCounter).increment(); + } catch (Exception e){} + try { + lenient().doAnswer(args -> { + return null; + }).when(testTimer).record(any(Long.class), any(TimeUnit.class)); + } catch (Exception e){} + when(pluginMetrics.counter(any())).thenReturn(testCounter); + + testTimer = mock(Timer.class); + when(pluginMetrics.timer(any())).thenReturn(testTimer); + lambdaProcessorConfig = mock(LambdaProcessorConfig.class); + expressionEvaluator = mock(ExpressionEvaluator.class); + awsCredentialsProvider = DefaultCredentialsProvider.create(); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider); + pluginFactory = mock(PluginFactory.class); + JsonInputCodecConfig jsonInputCodecConfig = mock(JsonInputCodecConfig.class); + when(jsonInputCodecConfig.getKeyName()).thenReturn(null); + when(jsonInputCodecConfig.getIncludeKeys()).thenReturn(null); + when(jsonInputCodecConfig.getIncludeKeysMetadata()).thenReturn(null); + InputCodec responseCodec = new JsonInputCodec(jsonInputCodecConfig); + when(pluginFactory.loadPlugin(eq(InputCodec.class), any(PluginSetting.class))).thenReturn(responseCodec); + + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(lambdaProcessorConfig.getWhenCondition()).thenReturn(null); + //when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); + BatchOptions batchOptions = mock(BatchOptions.class); + when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); + when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(null); + invocationType = mock(InvocationType.class); + when(lambdaProcessorConfig.getInvocationType()).thenReturn(invocationType); + when(lambdaProcessorConfig.getResponseCodecConfig()).thenReturn(null); + //when(lambdaProcessorConfig.getConnectionTimeout()).thenReturn(DEFAULT_CONNECTION_TIMEOUT); + ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + when(batchOptions.getKeyName()).thenReturn("osi_key"); + when(thresholdOptions.getEventCount()).thenReturn(ThresholdOptions.DEFAULT_EVENT_COUNT); + when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse(ThresholdOptions.DEFAULT_BYTE_CAPACITY)); + when(thresholdOptions.getEventCollectTimeOut()).thenReturn(ThresholdOptions.DEFAULT_EVENT_TIMEOUT); + AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(lambdaRegion)); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(role); + when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(null); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(null); + when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + + pluginSetting = mock(PluginSetting.class); + when(pluginSetting.getPipelineName()).thenReturn("pipeline"); + when(pluginSetting.getName()).thenReturn("name"); + lambdaSinkConfig = mock(LambdaSinkConfig.class); + when(lambdaSinkConfig.getFunctionName()).thenReturn(functionName); + //when(lambdaSinkConfig.getMaxConnectionRetries()).thenReturn(3); + + InvocationType sinkInvocationType = mock(InvocationType.class); + when(sinkInvocationType.getAwsLambdaValue()).thenReturn(InvocationType.EVENT.getAwsLambdaValue()); + when(lambdaSinkConfig.getInvocationType()).thenReturn(invocationType); + //when(lambdaSinkConfig.getConnectionTimeout()).thenReturn(DEFAULT_CONNECTION_TIMEOUT); + when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); + when(lambdaSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + + } + + private void setPrivateField(Object targetObject, String fieldName, Object value) + throws Exception { + Field field = targetObject.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(targetObject, value); + } + + @ParameterizedTest + @ValueSource(ints = {11}) + public void testLambdaProcessorAndLambdaSink(int numRecords) { + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); + lambdaProcessor = createLambdaProcessor(lambdaProcessorConfig); + List> records = createRecords(numRecords); + + Collection> results = lambdaProcessor.doExecute(records); + + assertThat(results.size(), equalTo(numRecords)); + validateStrictModeResults(records, results); + LambdaSink lambdaSink = createLambdaSink(lambdaSinkConfig); + try { + setPrivateField(lambdaSink, "numberOfRecordsSuccessCounter", sinkSuccessCounter); + } catch (Exception e){} + lambdaSink.output(results); + assertThat(successCount.get(), equalTo((long)numRecords)); + assertThat(numEventHandlesReleased.get(), equalTo((long)numRecords)); + } + + private void validateResultsForAggregateMode(List> records, Collection> results) { + List> resultRecords = new ArrayList<>(results); + Map eventHandlesMap = new HashMap<>(); + for (final Record record: records) { + eventHandlesMap.put((Integer)record.getData().toMap().get("id"), record.getData().getEventHandle()); + } + for (int i = 0; i < resultRecords.size(); i++) { + Event event = resultRecords.get(i).getData(); + Map eventData = event.toMap(); + + // Check if the event contains the expected data + assertThat(eventData.containsKey("id"), equalTo(true)); + int id = (Integer) eventData.get("id"); + assertThat(eventData.get("key" + id), equalTo(id)); + String stringValue = "value" + id; + assertThat(eventData.get("keys" + id), equalTo(stringValue.toUpperCase())); + assertThat(event.getEventHandle(), not(equalTo(eventHandlesMap.get(id)))); + + // Check that there's no metadata or it's empty + EventMetadata metadata = event.getMetadata(); + if (metadata != null) { + assertThat(metadata.getAttributes().isEmpty(), equalTo(true)); + assertThat(metadata.getTags().isEmpty(), equalTo(true)); + } + } + } + + private void validateStrictModeResults(List> records, Collection> results) { + List> resultRecords = new ArrayList<>(results); + Map eventHandlesMap = new HashMap<>(); + for (final Record record: records) { + eventHandlesMap.put((Integer)record.getData().toMap().get("id"), record.getData().getEventHandle()); + } + for (int i = 0; i < resultRecords.size(); i++) { + Event event = resultRecords.get(i).getData(); + Map eventData = event.toMap(); + Map attr = event.getMetadata().getAttributes(); + int id = (Integer)eventData.get("id"); + assertThat(event.getEventHandle(), equalTo(eventHandlesMap.get(id))); + assertThat(eventData.get("key"+id), equalTo(id)); + String stringValue = "value"+id; + assertThat(eventData.get("keys"+id), equalTo(stringValue.toUpperCase())); + assertThat(attr.get("attr"+id), equalTo(id)); + assertThat(attr.get("attrs"+id), equalTo("attrvalue"+id)); + } + } + + private List> createRecords(int numRecords) { + List> records = new ArrayList<>(); + for (int i = 0; i < numRecords; i++) { + Map map = new HashMap<>(); + map.put("id", i); + map.put("key"+i, i); + map.put("keys"+i, "value"+i); + Map attrs = new HashMap<>(); + attrs.put("attr"+i, i); + attrs.put("attrs"+i, "attrvalue"+i); + EventMetadata metadata = DefaultEventMetadata.builder() + .withEventType("event") + .withAttributes(attrs) + .build(); + final Event event = JacksonEvent.builder() + .withData(map) + .withEventType("event") + .withEventMetadata(metadata) + .build(); + ((DefaultEventHandle)event.getEventHandle()).addAcknowledgementSet(acknowledgementSet); + records.add(new Record<>(event)); + } + return records; + } + +} diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java new file mode 100644 index 0000000000..ae9efb9377 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorIT.java @@ -0,0 +1,282 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.dataprepper.plugins.lambda.processor; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.event.EventMetadata; +import org.opensearch.dataprepper.model.event.DefaultEventMetadata; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.types.ByteCount; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.any; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.List; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class LambdaProcessorIT { + private AwsCredentialsProvider awsCredentialsProvider; + private LambdaProcessor lambdaProcessor; + private LambdaProcessorConfig lambdaProcessorConfig; + private String functionName; + private String lambdaRegion; + private String role; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private PluginFactory pluginFactory; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private ExpressionEvaluator expressionEvaluator; + @Mock + private Counter testCounter; + @Mock + private Timer testTimer; + @Mock + InvocationType invocationType; + private LambdaProcessor createObjectUnderTest(LambdaProcessorConfig processorConfig) { + return new LambdaProcessor(pluginFactory, pluginMetrics, processorConfig, awsCredentialsSupplier, expressionEvaluator); + } + + @BeforeEach + public void setup() { + lambdaRegion = System.getProperty("tests.lambda.processor.region"); + functionName = System.getProperty("tests.lambda.processor.functionName"); + role = System.getProperty("tests.lambda.processor.sts_role_arn"); + pluginMetrics = mock(PluginMetrics.class); + //when(pluginMetrics.gauge(any(), any(AtomicLong.class))).thenReturn(new AtomicLong()); + //testCounter = mock(Counter.class); + try { + lenient().doAnswer(args -> { + return null; + }).when(testCounter).increment(any(Double.class)); + } catch (Exception e){} + try { + lenient().doAnswer(args -> { + return null; + }).when(testTimer).record(any(Long.class), any(TimeUnit.class)); + } catch (Exception e){} + when(pluginMetrics.counter(any())).thenReturn(testCounter); + testTimer = mock(Timer.class); + when(pluginMetrics.timer(any())).thenReturn(testTimer); + lambdaProcessorConfig = mock(LambdaProcessorConfig.class); + expressionEvaluator = mock(ExpressionEvaluator.class); + awsCredentialsProvider = DefaultCredentialsProvider.create(); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider); + pluginFactory = mock(PluginFactory.class); + JsonInputCodecConfig jsonInputCodecConfig = mock(JsonInputCodecConfig.class); + when(jsonInputCodecConfig.getKeyName()).thenReturn(null); + when(jsonInputCodecConfig.getIncludeKeys()).thenReturn(null); + when(jsonInputCodecConfig.getIncludeKeysMetadata()).thenReturn(null); + InputCodec responseCodec = new JsonInputCodec(jsonInputCodecConfig); + when(pluginFactory.loadPlugin(eq(InputCodec.class), any(PluginSetting.class))).thenReturn(responseCodec); + when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); + when(lambdaProcessorConfig.getWhenCondition()).thenReturn(null); + BatchOptions batchOptions = mock(BatchOptions.class); + when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); + when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(null); + invocationType = mock(InvocationType.class); + when(lambdaProcessorConfig.getInvocationType()).thenReturn(invocationType); + when(lambdaProcessorConfig.getResponseCodecConfig()).thenReturn(null); + ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + when(batchOptions.getKeyName()).thenReturn("osi_key"); + when(thresholdOptions.getEventCount()).thenReturn(ThresholdOptions.DEFAULT_EVENT_COUNT); + when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse(ThresholdOptions.DEFAULT_BYTE_CAPACITY)); + when(thresholdOptions.getEventCollectTimeOut()).thenReturn(ThresholdOptions.DEFAULT_EVENT_TIMEOUT); + AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(lambdaRegion)); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(role); + when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(null); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(null); + when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + } + + @ParameterizedTest + //@ValueSource(ints = {2, 5, 10, 100, 1000}) + @ValueSource(ints = {1000}) + public void testRequestResponseWithMatchingEventsStrictMode(int numRecords) { + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); + lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig); + List> records = createRecords(numRecords); + Collection> results = lambdaProcessor.doExecute(records); + assertThat(results.size(), equalTo(numRecords)); + validateStrictModeResults(results); + } + + @ParameterizedTest + //@ValueSource(ints = {2, 5, 10, 100, 1000}) + @ValueSource(ints = {1000}) + public void testRequestResponseWithMatchingEventsAggregateMode(int numRecords) { + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false); + lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig); + List> records = createRecords(numRecords); + Collection> results = lambdaProcessor.doExecute(records); + assertThat(results.size(), equalTo(numRecords)); + validateResultsForAggregateMode(results ); + } + + @ParameterizedTest + @ValueSource(ints = {1000}) + public void testRequestResponse_WithMatchingEvents_StrictMode_WithMultipleThreads(int numRecords) throws InterruptedException { + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); + lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig); + int numThreads = 5; + ExecutorService executorService = Executors.newFixedThreadPool(numThreads); + CountDownLatch latch = new CountDownLatch(numThreads); + List> records = createRecords(numRecords); + for (int i = 0; i < numThreads; i++) { + executorService.submit(() -> { + try { + Collection> results = lambdaProcessor.doExecute(records); + assertThat(results.size(), equalTo(numRecords)); + validateStrictModeResults(results); + } finally { + latch.countDown(); + } + }); + } + latch.await(5, TimeUnit.MINUTES); + executorService.shutdown(); + } + + @ParameterizedTest + @ValueSource(strings = {"RequestResponse", "Event"}) + public void testDifferentInvocationTypes(String invocationType) { + when(this.invocationType.getAwsLambdaValue()).thenReturn(invocationType); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); + lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig); + List> records = createRecords(10); + Collection> results = lambdaProcessor.doExecute(records); + if (invocationType.equals("RequestResponse")) { + assertThat(results.size(), equalTo(10)); + validateStrictModeResults(results); + } else { + // For "Event" invocation type + assertThat(results.size(), equalTo(0)); + } + } + + @Test + public void testWithFailureTags() { + when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue()); + when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false); + when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(Collections.singletonList("lambda_failure")); + LambdaProcessor spyLambdaProcessor = spy(createObjectUnderTest(lambdaProcessorConfig)); + doThrow(new RuntimeException("Simulated Lambda failure")) + .when(spyLambdaProcessor).convertLambdaResponseToEvent(any(Buffer.class), any(InvokeResponse.class)); + List> records = createRecords(5); + Collection> results = spyLambdaProcessor.doExecute(records); + assertThat(results.size(), equalTo(5)); + for (Record record : results) { + assertThat(record.getData().getMetadata().getTags().contains("lambda_failure"), equalTo(true)); + } + } + + private void validateResultsForAggregateMode(Collection> results) { + List> resultRecords = new ArrayList<>(results); + for (int i = 0; i < resultRecords.size(); i++) { + Event event = resultRecords.get(i).getData(); + Map eventData = event.toMap(); + // Check if the event contains the expected data + assertThat(eventData.containsKey("id"), equalTo(true)); + int id = (Integer) eventData.get("id"); + assertThat(eventData.get("key" + id), equalTo(id)); + String stringValue = "value" + id; + assertThat(eventData.get("keys" + id), equalTo(stringValue.toUpperCase())); + // Check that there's no metadata or it's empty + EventMetadata metadata = event.getMetadata(); + if (metadata != null) { + assertThat(metadata.getAttributes().isEmpty(), equalTo(true)); + assertThat(metadata.getTags().isEmpty(), equalTo(true)); + } + } + } + + private void validateStrictModeResults(Collection> results) { + List> resultRecords = new ArrayList<>(results); + for (int i = 0; i < resultRecords.size(); i++) { + Map eventData = resultRecords.get(i).getData().toMap(); + Map attr = resultRecords.get(i).getData().getMetadata().getAttributes(); + int id = (Integer)eventData.get("id"); + assertThat(eventData.get("key"+id), equalTo(id)); + String stringValue = "value"+id; + assertThat(eventData.get("keys"+id), equalTo(stringValue.toUpperCase())); + assertThat(attr.get("attr"+id), equalTo(id)); + assertThat(attr.get("attrs"+id), equalTo("attrvalue"+id)); + } + } + + private List> createRecords(int numRecords) { + List> records = new ArrayList<>(); + for (int i = 0; i < numRecords; i++) { + Map map = new HashMap<>(); + map.put("id", i); + map.put("key"+i, i); + map.put("keys"+i, "value"+i); + Map attrs = new HashMap<>(); + attrs.put("attr"+i, i); + attrs.put("attrs"+i, "attrvalue"+i); + EventMetadata metadata = DefaultEventMetadata.builder() + .withEventType("event") + .withAttributes(attrs) + .build(); + final Event event = JacksonEvent.builder() + .withData(map) + .withEventType("event") + .withEventMetadata(metadata) + .build(); + records.add(new Record<>(event)); + } + return records; + } +} diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java deleted file mode 100644 index 0db9626799..0000000000 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorServiceIT.java +++ /dev/null @@ -1,165 +0,0 @@ -package org.opensearch.dataprepper.plugins.lambda.processor; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; -import io.micrometer.core.instrument.Counter; -import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.Mock; -import static org.mockito.Mockito.when; -import org.mockito.MockitoAnnotations; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; -import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.configuration.PluginSetting; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.event.JacksonEvent; -import org.opensearch.dataprepper.model.log.JacksonLog; -import org.opensearch.dataprepper.model.plugin.PluginFactory; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; - -@ExtendWith(MockitoExtension.class) -public class LambdaProcessorServiceIT { - - private LambdaAsyncClient lambdaAsyncClient; - private String functionName; - private String lambdaRegion; - private String role; - private BufferFactory bufferFactory; - @Mock - private LambdaProcessorConfig lambdaProcessorConfig; - @Mock - private BatchOptions batchOptions; - @Mock - private ThresholdOptions thresholdOptions; - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - @Mock - private PluginMetrics pluginMetrics; - @Mock - private PluginFactory pluginFactory; - @Mock - private PluginSetting pluginSetting; - @Mock - private Counter numberOfRecordsSuccessCounter; - @Mock - private Counter numberOfRecordsFailedCounter; - @Mock - private ExpressionEvaluator expressionEvaluator; - - private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); - - - @BeforeEach - public void setUp() throws Exception { - MockitoAnnotations.openMocks(this); - lambdaRegion = System.getProperty("tests.lambda.processor.region"); - functionName = System.getProperty("tests.lambda.processor.functionName"); - role = System.getProperty("tests.lambda.processor.sts_role_arn"); - - final Region region = Region.of(lambdaRegion); - - lambdaAsyncClient = LambdaAsyncClient.builder() - .region(Region.of(lambdaRegion)) - .build(); - - bufferFactory = new InMemoryBufferFactory(); - - when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)). - thenReturn(numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)). - thenReturn(numberOfRecordsFailedCounter); - } - - - private static Record createRecord() { - final JacksonEvent event = JacksonLog.builder().withData("[{\"name\":\"test\"}]").build(); - return new Record<>(event); - } - - public LambdaProcessor createObjectUnderTest(final String config) throws JsonProcessingException { - - final LambdaProcessorConfig lambdaProcessorConfig = objectMapper.readValue(config, LambdaProcessorConfig.class); - return new LambdaProcessor(pluginFactory,pluginMetrics,lambdaProcessorConfig,awsCredentialsSupplier,expressionEvaluator); - } - - public LambdaProcessor createObjectUnderTest(LambdaProcessorConfig lambdaSinkConfig) throws JsonProcessingException { - return new LambdaProcessor(pluginFactory,pluginMetrics,lambdaSinkConfig,awsCredentialsSupplier,expressionEvaluator); - } - - - private static Collection> generateRecords(int numberOfRecords) { - List> recordList = new ArrayList<>(); - - for (int rows = 1; rows <= numberOfRecords; rows++) { - HashMap eventData = new HashMap<>(); - eventData.put("name", "Person" + rows); - eventData.put("age", Integer.toString(rows)); - - Record eventRecord = new Record<>(JacksonEvent.builder().withData(eventData).withEventType("event").build()); - recordList.add(eventRecord); - } - return recordList; - } - - @ParameterizedTest - @ValueSource(ints = {1,3}) - void verify_records_to_lambda_success(final int recordCount) throws Exception { - - when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); - when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); - when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - - LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); - - Collection> recordsData = generateRecords(recordCount); - List> recordsResult = (List>) objectUnderTest.doExecute(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - - assertEquals(recordsResult.size(),recordCount); - } - - @ParameterizedTest - @ValueSource(ints = {1,3}) - void verify_records_with_batching_to_lambda(final int recordCount) throws JsonProcessingException, InterruptedException { - - when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName); - when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); - when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); - when(thresholdOptions.getEventCount()).thenReturn(1); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT10s")); - when(batchOptions.getKeyName()).thenReturn("lambda_batch_key"); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); - - LambdaProcessor objectUnderTest = createObjectUnderTest(lambdaProcessorConfig); - Collection> records = generateRecords(recordCount); - Collection> recordsResult = objectUnderTest.doExecute(records); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - assertEquals(recordsResult.size(),recordCount); - } -} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceIT.java b/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceIT.java deleted file mode 100644 index 352430a02c..0000000000 --- a/data-prepper-plugins/aws-lambda/src/integrationTest/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceIT.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.lambda.sink; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Timer; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import org.mockito.Mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import org.mockito.MockitoAnnotations; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; -import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.configuration.PluginSetting; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.event.JacksonEvent; -import org.opensearch.dataprepper.model.log.JacksonLog; -import org.opensearch.dataprepper.model.plugin.PluginFactory; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.sink.OutputCodecContext; -import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.LAMBDA_LATENCY_METRIC; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.REQUEST_PAYLOAD_SIZE; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.RESPONSE_PAYLOAD_SIZE; -import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; - -@ExtendWith(MockitoExtension.class) -class LambdaSinkServiceIT { - - private LambdaAsyncClient lambdaAsyncClient; - private String functionName; - private String lambdaRegion; - private String role; - private BufferFactory bufferFactory; - @Mock - private LambdaSinkConfig lambdaSinkConfig; - @Mock - private BatchOptions batchOptions; - @Mock - private ThresholdOptions thresholdOptions; - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - @Mock - private PluginMetrics pluginMetrics; - @Mock - private DlqPushHandler dlqPushHandler; - @Mock - private PluginFactory pluginFactory; - @Mock - private PluginSetting pluginSetting; - @Mock - private Counter numberOfRecordsSuccessCounter; - @Mock - private Counter numberOfRecordsFailedCounter; - @Mock - private ExpressionEvaluator expressionEvaluator; - @Mock - private Timer lambdaLatencyMetric; - @Mock - private AtomicLong requestPayload; - @Mock - private AtomicLong responsePayload; - private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); - - - @BeforeEach - public void setUp() throws Exception { - MockitoAnnotations.openMocks(this); - lambdaRegion = System.getProperty("tests.sink.lambda.region"); - functionName = System.getProperty("tests.sink.lambda.functionName"); - role = System.getProperty("tests.sink.lambda.sts_role_arn"); - - final Region region = Region.of(lambdaRegion); - - lambdaAsyncClient = LambdaAsyncClient.builder() - .region(Region.of(lambdaRegion)) - .build(); - - bufferFactory = new InMemoryBufferFactory(); - - when(pluginMetrics.counter(LambdaSinkService.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)). - thenReturn(numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(LambdaSinkService.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)). - thenReturn(numberOfRecordsFailedCounter); - when(pluginMetrics.timer(LAMBDA_LATENCY_METRIC)).thenReturn(lambdaLatencyMetric); - when(pluginMetrics.gauge(eq(REQUEST_PAYLOAD_SIZE), any(AtomicLong.class))).thenReturn(requestPayload); - when(pluginMetrics.gauge(eq(RESPONSE_PAYLOAD_SIZE), any(AtomicLong.class))).thenReturn(responsePayload); - } - - - private static Record createRecord() { - final JacksonEvent event = JacksonLog.builder().withData("[{\"name\":\"test\"}]").build(); - return new Record<>(event); - } - - public LambdaSinkService createObjectUnderTest(final String config) throws JsonProcessingException { - - final LambdaSinkConfig lambdaSinkConfig = objectMapper.readValue(config, LambdaSinkConfig.class); - OutputCodecContext codecContext = new OutputCodecContext("Tag", Collections.emptyList(), Collections.emptyList()); - pluginFactory = null; - return new LambdaSinkService(lambdaAsyncClient, - lambdaSinkConfig, - pluginMetrics, - pluginFactory, - pluginSetting, - codecContext, - awsCredentialsSupplier, - dlqPushHandler, - bufferFactory, - expressionEvaluator); - } - - public LambdaSinkService createObjectUnderTest(LambdaSinkConfig lambdaSinkConfig) throws JsonProcessingException { - - OutputCodecContext codecContext = new OutputCodecContext("Tag", Collections.emptyList(), Collections.emptyList()); - pluginFactory = null; - return new LambdaSinkService(lambdaAsyncClient, - lambdaSinkConfig, - pluginMetrics, - pluginFactory, - pluginSetting, - codecContext, - awsCredentialsSupplier, - dlqPushHandler, - bufferFactory, - expressionEvaluator); - } - - - private static Collection> generateRecords(int numberOfRecords) { - List> recordList = new ArrayList<>(); - - for (int rows = 0; rows < numberOfRecords; rows++) { - HashMap eventData = new HashMap<>(); - eventData.put("name", "Person" + rows); - eventData.put("age", Integer.toString(rows)); - - Record eventRecord = new Record<>(JacksonEvent.builder().withData(eventData).withEventType("event").build()); - recordList.add(eventRecord); - } - return recordList; - } - - @ParameterizedTest - @ValueSource(ints = {1,5}) - void verify_flushed_records_to_lambda_success(final int recordCount) throws Exception { - - final String LAMBDA_SINK_CONFIG_YAML = - " function_name: " + functionName +"\n" + - " aws:\n" + - " region: us-east-1\n" + - " sts_role_arn: " + role + "\n" + - " max_retries: 3\n"; - LambdaSinkService objectUnderTest = createObjectUnderTest(LAMBDA_SINK_CONFIG_YAML); - - Collection> recordsData = generateRecords(recordCount); - objectUnderTest.output(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - - verify(numberOfRecordsSuccessCounter, times(recordCount)).increment(1); - } - - @ParameterizedTest - @ValueSource(ints = {1,5,10}) - void verify_flushed_records_to_lambda_failed_and_dlq_works(final int recordCount) throws Exception { - final String LAMBDA_SINK_CONFIG_INVALID_FUNCTION_NAME = - " function_name: $$$\n" + - " aws:\n" + - " region: us-east-1\n" + - " sts_role_arn: arn:aws:iam::176893235612:role/osis-s3-opensearch-role\n" + - " max_retries: 3\n" + - " dlq: #any failed even\n"+ - " s3:\n"+ - " bucket: test-bucket\n"+ - " key_path_prefix: dlq/\n"; - LambdaSinkService objectUnderTest = createObjectUnderTest(LAMBDA_SINK_CONFIG_INVALID_FUNCTION_NAME); - - Collection> recordsData = generateRecords(recordCount); - objectUnderTest.output(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - - verify( numberOfRecordsFailedCounter, times(recordCount)).increment(1); - } - - @ParameterizedTest - @ValueSource(ints = {2,5}) - void verify_flushed_records_with_batching_to_lambda(final int recordCount) throws JsonProcessingException, InterruptedException { - - int event_count = 2; - when(lambdaSinkConfig.getFunctionName()).thenReturn(functionName); - when(lambdaSinkConfig.getMaxConnectionRetries()).thenReturn(3); - when(thresholdOptions.getEventCount()).thenReturn(event_count); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT10s")); - when(batchOptions.getKeyName()).thenReturn("lambda_batch_key"); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); - - LambdaSinkService objectUnderTest = createObjectUnderTest(lambdaSinkConfig); - Collection> recordsData = generateRecords(recordCount); - objectUnderTest.output(recordsData); - Thread.sleep(Duration.ofSeconds(10).toMillis()); - } -} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java index 1d59ff9139..a0df8af20b 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandler.java @@ -1,61 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.dataprepper.plugins.lambda.common; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiFunction; +import java.util.function.Function; + +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.model.types.ByteCount; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; +import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck; import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.io.IOException; -import java.util.List; -import java.util.concurrent.CompletableFuture; - public class LambdaCommonHandler { - private final Logger LOG; - private final LambdaAsyncClient lambdaAsyncClient; - private final String functionName; - private final String invocationType; - BufferFactory bufferFactory; - - public LambdaCommonHandler( - final Logger log, - final LambdaAsyncClient lambdaAsyncClient, - final String functionName, - final String invocationType){ - this.LOG = log; - this.lambdaAsyncClient = lambdaAsyncClient; - this.functionName = functionName; - this.invocationType = invocationType; + + private static final Logger LOG = LoggerFactory.getLogger(LambdaCommonHandler.class); + private LambdaCommonHandler() { + } + + public static boolean isSuccess(InvokeResponse response) { + int statusCode = response.statusCode(); + if (statusCode < 200 || statusCode >= 300) { + return false; } + return true; + } - public Buffer createBuffer(BufferFactory bufferFactory) { - try { - LOG.debug("Resetting buffer"); - return bufferFactory.getBuffer(lambdaAsyncClient, functionName, invocationType); - } catch (IOException e) { - throw new RuntimeException("Failed to reset buffer", e); - } + public static void waitForFutures(List> futureList) { + + if (!futureList.isEmpty()) { + try { + CompletableFuture.allOf(futureList.toArray(new CompletableFuture[0])).join(); + } catch (Exception e) { + LOG.warn("Exception while waiting for Lambda invocations to complete", e); + } } + } - public boolean checkStatusCode(InvokeResponse response) { - int statusCode = response.statusCode(); - if (statusCode < 200 || statusCode >= 300) { - LOG.error("Lambda invocation returned with non-success status code: {}", statusCode); - return false; - } - return true; + private static List createBufferBatches(Collection> records, + BatchOptions batchOptions, final OutputCodecContext outputCodecContext) { + + int maxEvents = batchOptions.getThresholdOptions().getEventCount(); + ByteCount maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); + String keyName = batchOptions.getKeyName(); + Duration maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut(); + + Buffer currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); + List batchedBuffers = new ArrayList<>(); + + LOG.debug("Batch size received to lambda processor: {}", records.size()); + for (Record record : records) { + + currentBufferPerBatch.addRecord(record); + if (ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, + maxCollectionDuration)) { + batchedBuffers.add(currentBufferPerBatch); + currentBufferPerBatch = new InMemoryBuffer(keyName, outputCodecContext); + } + } + + if (currentBufferPerBatch.getEventCount() > 0) { + batchedBuffers.add(currentBufferPerBatch); } + return batchedBuffers; + } - public void waitForFutures(List> futureList) { - if (!futureList.isEmpty()) { - try { - CompletableFuture.allOf(futureList.toArray(new CompletableFuture[0])).join(); - LOG.info("All {} Lambda invocations have completed", futureList.size()); - } catch (Exception e) { - LOG.warn("Exception while waiting for Lambda invocations to complete", e); - } finally { - futureList.clear(); - } + public static List> sendRecords(Collection> records, + LambdaCommonConfig config, + LambdaAsyncClient lambdaAsyncClient, + final OutputCodecContext outputCodecContext, + BiFunction>> successHandler, + Function>> failureHandler) { + // Initialize here to void multi-threading issues + // Note: By default, one instance of processor is created across threads. + //List> resultRecords = Collections.synchronizedList(new ArrayList<>()); + List> resultRecords = new ArrayList<>(); + List> futureList = new ArrayList<>(); + int totalFlushedEvents = 0; + + List batchedBuffers = createBufferBatches(records, config.getBatchOptions(), + outputCodecContext); + + Map bufferToFutureMap = new HashMap<>(); + LOG.debug("Batch Chunks created after threshold check: {}", batchedBuffers.size()); + for (Buffer buffer : batchedBuffers) { + InvokeRequest requestPayload = buffer.getRequestPayload(config.getFunctionName(), + config.getInvocationType().getAwsLambdaValue()); + CompletableFuture future = lambdaAsyncClient.invoke(requestPayload); + futureList.add(future); + bufferToFutureMap.put(buffer, future); + } + waitForFutures(futureList); + for (Map.Entry entry : bufferToFutureMap.entrySet()) { + CompletableFuture future = entry.getValue(); + Buffer buffer = entry.getKey(); + try { + InvokeResponse response = (InvokeResponse) future.join(); + if (isSuccess(response)) { + resultRecords.addAll(successHandler.apply(buffer, response)); + } else { + LOG.error("Lambda invoke failed with error {} ", response.statusCode()); + resultRecords.addAll(failureHandler.apply(buffer)); } + } catch (Exception e) { + LOG.error("Exception from Lambda invocation ", e); + resultRecords.addAll(failureHandler.apply(buffer)); + } } + return resultRecords; + + } + } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/ResponseEventHandlingStrategy.java similarity index 73% rename from data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseEventHandlingStrategy.java rename to data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/ResponseEventHandlingStrategy.java index 46b5587157..e27f0e1b89 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/ResponseEventHandlingStrategy.java @@ -1,4 +1,9 @@ -package org.opensearch.dataprepper.plugins.lambda.processor; +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.lambda.common; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java index 878d5e9033..b6249008cd 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/Buffer.java @@ -5,45 +5,36 @@ package org.opensearch.dataprepper.plugins.lambda.common.accumlator; +import java.time.Duration; +import java.util.List; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.lambda.model.InvokeResponse; - -import java.io.OutputStream; -import java.time.Duration; -import java.util.List; -import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; /** * A buffer can hold data before flushing it. */ public interface Buffer { - long getSize(); - - int getEventCount(); - - Duration getDuration(); - - CompletableFuture flushToLambda(String invocationType); - - OutputStream getOutputStream(); + long getSize(); - SdkBytes getPayload(); + int getEventCount(); - void setEventCount(int eventCount); + Duration getDuration(); - public void addRecord(Record record); + InvokeRequest getRequestPayload(String functionName, String invocationType); - public List> getRecords(); + SdkBytes getPayload(); - public Duration getFlushLambdaLatencyMetric(); + void addRecord(Record record); - public Long getPayloadRequestSize(); + List> getRecords(); - public Duration stopLatencyWatch(); + Duration getFlushLambdaLatencyMetric(); - void reset(); + Long getPayloadRequestSize(); + Duration stopLatencyWatch(); + } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/BufferFactory.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/BufferFactory.java deleted file mode 100644 index 6836440206..0000000000 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/BufferFactory.java +++ /dev/null @@ -1,14 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.lambda.common.accumlator; - -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; - -import java.io.IOException; - -public interface BufferFactory { - Buffer getBuffer(LambdaAsyncClient lambdaAsyncClient, String functionName, String invocationType) throws IOException; -} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java index 109a141e09..f3e2ea1f8f 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java @@ -5,162 +5,145 @@ package org.opensearch.dataprepper.plugins.lambda.common.accumlator; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.commons.lang3.time.StopWatch; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.record.Record; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeRequest; -import software.amazon.awssdk.services.lambda.model.InvokeResponse; - import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import org.apache.commons.lang3.time.StopWatch; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; /** * A buffer can hold in memory data and flushing it. */ public class InMemoryBuffer implements Buffer { - private final ByteArrayOutputStream byteArrayOutputStream; - - private final LambdaAsyncClient lambdaAsyncClient; - private final String functionName; - private final String invocationType; - private int eventCount; - private StopWatch bufferWatch; - private StopWatch lambdaLatencyWatch; - private long payloadRequestSize; - private long payloadResponseSize; - private final List> records; - - - public InMemoryBuffer(LambdaAsyncClient lambdaAsyncClient, String functionName, String invocationType) { - this.lambdaAsyncClient = lambdaAsyncClient; - this.functionName = functionName; - this.invocationType = invocationType; - byteArrayOutputStream = new ByteArrayOutputStream(); - records = new ArrayList<>(); - bufferWatch = new StopWatch(); - bufferWatch.start(); - lambdaLatencyWatch = new StopWatch(); - eventCount = 0; - payloadRequestSize = 0; - payloadResponseSize = 0; + private final ByteArrayOutputStream byteArrayOutputStream; + + private final List> records; + private final StopWatch bufferWatch; + private final StopWatch lambdaLatencyWatch; + private final OutputCodec requestCodec; + private final OutputCodecContext outputCodecContext; + private final long payloadResponseSize; + private int eventCount; + private long payloadRequestSize; + + + public InMemoryBuffer(String batchOptionKeyName) { + this(batchOptionKeyName, new OutputCodecContext()); + } + + public InMemoryBuffer(String batchOptionKeyName, OutputCodecContext outputCodecContext) { + byteArrayOutputStream = new ByteArrayOutputStream(); + records = new ArrayList<>(); + bufferWatch = new StopWatch(); + bufferWatch.start(); + lambdaLatencyWatch = new StopWatch(); + eventCount = 0; + payloadRequestSize = 0; + payloadResponseSize = 0; + // Setup request codec + JsonOutputCodecConfig jsonOutputCodecConfig = new JsonOutputCodecConfig(); + jsonOutputCodecConfig.setKeyName(batchOptionKeyName); + requestCodec = new JsonOutputCodec(jsonOutputCodecConfig); + this.outputCodecContext = outputCodecContext; + } + + public void addRecord(Record record) { + records.add(record); + Event event = record.getData(); + try { + if (eventCount == 0) { + requestCodec.start(this.byteArrayOutputStream, event, this.outputCodecContext); + } + requestCodec.writeEvent(event, this.byteArrayOutputStream); + } catch (IOException e) { + throw new RuntimeException(e); } + eventCount++; + } - public void addRecord(Record record) { - records.add(record); - eventCount++; - } + public List> getRecords() { + return records; + } - public List> getRecords() { - return records; - } + @Override + public long getSize() { + return byteArrayOutputStream.size(); + } - @Override - public long getSize() { - return byteArrayOutputStream.size(); - } + @Override + public int getEventCount() { + return eventCount; + } - @Override - public int getEventCount() { - return eventCount; - } + public Duration getDuration() { + return Duration.ofMillis(bufferWatch.getTime(TimeUnit.MILLISECONDS)); + } - public Duration getDuration() { - return Duration.ofMillis(bufferWatch.getTime(TimeUnit.MILLISECONDS)); - } - - public void reset() { - byteArrayOutputStream.reset(); - eventCount = 0; - bufferWatch.reset(); - lambdaLatencyWatch.reset(); - payloadRequestSize = 0; - payloadResponseSize = 0; - } - - @Override - public CompletableFuture flushToLambda(String invocationType) { - SdkBytes payload = getPayload(); - payloadRequestSize = payload.asByteArray().length; - - // Setup an InvokeRequest. - InvokeRequest request = InvokeRequest.builder() - .functionName(functionName) - .payload(payload) - .invocationType(invocationType) - .build(); - - synchronized (this) { - if (lambdaLatencyWatch.isStarted()) { - lambdaLatencyWatch.reset(); - } - lambdaLatencyWatch.start(); - } - // Use the async client to invoke the Lambda function - CompletableFuture future = lambdaAsyncClient.invoke(request); - return future; - } - - public synchronized Duration stopLatencyWatch() { - if (lambdaLatencyWatch.isStarted()) { - lambdaLatencyWatch.stop(); - } - long timeInMillis = lambdaLatencyWatch.getTime(); - return Duration.ofMillis(timeInMillis); - } + @Override + public InvokeRequest getRequestPayload(String functionName, String invocationType) { - private SdkBytes validatePayload(String payload_string) { - ObjectMapper mapper = new ObjectMapper(); - try { - JsonNode jsonNode = mapper.readTree(byteArrayOutputStream.toByteArray()); - - // Convert the JsonNode back to a String to normalize it (removes extra spaces, trailing commas, etc.) - String normalizedJson = mapper.writeValueAsString(jsonNode); - return SdkBytes.fromUtf8String(normalizedJson); - } catch (IOException e) { - throw new RuntimeException(e); - } + if (eventCount == 0) { + //We never added any events so there is no payload + return null; } - @Override - public void setEventCount(int eventCount) { - this.eventCount = eventCount; + try { + requestCodec.complete(this.byteArrayOutputStream); + } catch (IOException e) { + throw new RuntimeException(e); } - @Override - public OutputStream getOutputStream() { - return byteArrayOutputStream; - } + SdkBytes payload = getPayload(); + payloadRequestSize = payload.asByteArray().length; - @Override - public SdkBytes getPayload() { - byte[] bytes = byteArrayOutputStream.toByteArray(); - SdkBytes sdkBytes = SdkBytes.fromByteArray(bytes); - return sdkBytes; - } + // Setup an InvokeRequest. + InvokeRequest request = InvokeRequest.builder() + .functionName(functionName) + .payload(payload) + .invocationType(invocationType) + .build(); - public Duration getFlushLambdaLatencyMetric (){ - return Duration.ofMillis(lambdaLatencyWatch.getTime(TimeUnit.MILLISECONDS)); + synchronized (this) { + if (lambdaLatencyWatch.isStarted()) { + lambdaLatencyWatch.reset(); + } + lambdaLatencyWatch.start(); } + return request; + } - public Long getPayloadRequestSize() { - return payloadRequestSize; + public synchronized Duration stopLatencyWatch() { + if (lambdaLatencyWatch.isStarted()) { + lambdaLatencyWatch.stop(); } - - public StopWatch getBufferWatch() {return bufferWatch;} - - public StopWatch getLambdaLatencyWatch(){return lambdaLatencyWatch;} - + long timeInMillis = lambdaLatencyWatch.getTime(); + return Duration.ofMillis(timeInMillis); + } + + @Override + public SdkBytes getPayload() { + byte[] bytes = byteArrayOutputStream.toByteArray(); + return SdkBytes.fromByteArray(bytes); + } + + public Duration getFlushLambdaLatencyMetric() { + return Duration.ofMillis(lambdaLatencyWatch.getTime(TimeUnit.MILLISECONDS)); + } + + public Long getPayloadRequestSize() { + return payloadRequestSize; + } } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBufferFactory.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBufferFactory.java deleted file mode 100644 index 91083620dd..0000000000 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBufferFactory.java +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.lambda.common.accumlator; - -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; - - -public class InMemoryBufferFactory implements BufferFactory { - @Override - public Buffer getBuffer(LambdaAsyncClient lambdaAsyncClient, String functionName, String invocationType){ - return new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - } -} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java index 1374b0ca07..87b7a4271b 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactory.java @@ -1,57 +1,72 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - package org.opensearch.dataprepper.plugins.lambda.common.client; import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.opensearch.dataprepper.plugins.metricpublisher.MicrometerMetricPublisher; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.core.retry.backoff.BackoffStrategy; +import software.amazon.awssdk.core.retry.backoff.EqualJitterBackoffStrategy; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import java.time.Duration; - public final class LambdaClientFactory { - private LambdaClientFactory() { } - - public static LambdaAsyncClient createAsyncLambdaClient(final AwsAuthenticationOptions awsAuthenticationOptions, - final int maxConnectionRetries, - final AwsCredentialsSupplier awsCredentialsSupplier, - final Duration sdkTimeout) { - final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(awsAuthenticationOptions); - final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions); - final PluginMetrics awsSdkMetrics = PluginMetrics.fromNames("sdk", "aws"); - - return LambdaAsyncClient.builder() - .region(awsAuthenticationOptions.getAwsRegion()) - .credentialsProvider(awsCredentialsProvider) - .overrideConfiguration(createOverrideConfiguration(maxConnectionRetries, awsSdkMetrics, sdkTimeout)) - .build(); - } - - private static ClientOverrideConfiguration createOverrideConfiguration(final int maxConnectionRetries, - final PluginMetrics awsSdkMetrics, - final Duration sdkTimeout) { - final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(maxConnectionRetries).build(); - return ClientOverrideConfiguration.builder() - .retryPolicy(retryPolicy) - .addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics)) - .apiCallTimeout(sdkTimeout) //default sdk limit is 60secs, requests to lambda might fail if lambda takes >60sec to process - .build(); - } - - private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) { - return AwsCredentialsOptions.builder() - .withRegion(awsAuthenticationOptions.getAwsRegion()) - .withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn()) - .withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId()) - .withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides()) - .build(); - } -} + + public static LambdaAsyncClient createAsyncLambdaClient( + final AwsAuthenticationOptions awsAuthenticationOptions, + final AwsCredentialsSupplier awsCredentialsSupplier, + ClientOptions clientOptions) { + final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions( + awsAuthenticationOptions); + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider( + awsCredentialsOptions); + final PluginMetrics awsSdkMetrics = PluginMetrics.fromNames("sdk", "aws"); + + return LambdaAsyncClient.builder() + .region(awsAuthenticationOptions.getAwsRegion()) + .credentialsProvider(awsCredentialsProvider) + .overrideConfiguration( + createOverrideConfiguration(clientOptions, awsSdkMetrics)) + .httpClient(NettyNioAsyncHttpClient.builder() + .maxConcurrency(clientOptions.getMaxConcurrency()) + .connectionTimeout(clientOptions.getConnectionTimeout()).build()) + .build(); + } + + private static ClientOverrideConfiguration createOverrideConfiguration( + ClientOptions clientOptions, + final PluginMetrics awsSdkMetrics) { + + //TODO - Add AdaptiveRetryStrategy + //https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/retry-strategy.html + BackoffStrategy backoffStrategy = EqualJitterBackoffStrategy.builder() + .baseDelay(clientOptions.getBaseDelay()) + .maxBackoffTime(clientOptions.getMaxBackoff()) + .build(); + + final RetryPolicy retryPolicy = RetryPolicy.builder() + .numRetries(clientOptions.getMaxConnectionRetries()) + .backoffStrategy(backoffStrategy) + .build(); + + return ClientOverrideConfiguration.builder() + .retryPolicy(retryPolicy) + .addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics)) + .apiCallTimeout(clientOptions.getApiCallTimeout()) + .build(); + } + + private static AwsCredentialsOptions convertToCredentialsOptions( + final AwsAuthenticationOptions awsAuthenticationOptions) { + return AwsCredentialsOptions.builder() + .withRegion(awsAuthenticationOptions.getAwsRegion()) + .withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn()) + .withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId()) + .withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides()) + .build(); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ClientOptions.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ClientOptions.java new file mode 100644 index 0000000000..bab1c16c91 --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ClientOptions.java @@ -0,0 +1,42 @@ +package org.opensearch.dataprepper.plugins.lambda.common.config; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import lombok.Getter; + +import java.time.Duration; + +@Getter +public class ClientOptions { + public static final int DEFAULT_CONNECTION_RETRIES = 3; + public static final int DEFAULT_MAXIMUM_CONCURRENCY = 200; + public static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(60); + public static final Duration DEFAULT_API_TIMEOUT = Duration.ofSeconds(60); + public static final Duration DEFAULT_BASE_DELAY = Duration.ofMillis(100); + public static final Duration DEFAULT_MAX_BACKOFF = Duration.ofSeconds(20); + + @JsonPropertyDescription("Total retries we want before failing") + @JsonProperty("max_retries") + private int maxConnectionRetries = DEFAULT_CONNECTION_RETRIES; + + @JsonPropertyDescription("api call timeout defines the time sdk maintains the api call to complete before timing out") + @JsonProperty("api_call_timeout") + private Duration apiCallTimeout = DEFAULT_API_TIMEOUT; + + @JsonPropertyDescription("sdk timeout defines the time sdk maintains the connection to the client before timing out") + @JsonProperty("connection_timeout") + private Duration connectionTimeout = DEFAULT_CONNECTION_TIMEOUT; + + @JsonPropertyDescription("max concurrency defined from the client side") + @JsonProperty("max_concurrency") + private int maxConcurrency = DEFAULT_MAXIMUM_CONCURRENCY; + + @JsonPropertyDescription("Base delay for exponential backoff") + @JsonProperty("base_delay") + private Duration baseDelay = DEFAULT_BASE_DELAY; + + @JsonPropertyDescription("Maximum backoff time for exponential backoff") + @JsonProperty("max_backoff") + private Duration maxBackoff = DEFAULT_MAX_BACKOFF; + +} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/LambdaCommonConfig.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/LambdaCommonConfig.java index c3d58e1f39..eb95a35148 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/LambdaCommonConfig.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/LambdaCommonConfig.java @@ -1,12 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.dataprepper.plugins.lambda.common.config; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; import java.time.Duration; +import lombok.Getter; +import org.opensearch.dataprepper.model.configuration.PluginModel; + +@Getter + +public abstract class LambdaCommonConfig { + + public static final int DEFAULT_CONNECTION_RETRIES = 3; + public static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(60); + public static final String STS_REGION = "region"; + public static final String STS_ROLE_ARN = "sts_role_arn"; + + @JsonProperty("aws") + @NotNull + @Valid + private AwsAuthenticationOptions awsAuthenticationOptions; + + @JsonPropertyDescription("Lambda Function Name") + @JsonProperty("function_name") + @NotEmpty + @Size(min = 3, max = 500, message = "function name length should be at least 3 characters") + private String functionName; + + @JsonPropertyDescription("invocation type defines the way we want to call lambda function") + @JsonProperty("invocation_type") + private InvocationType invocationType = InvocationType.REQUEST_RESPONSE; + + @JsonPropertyDescription("Client options") + @JsonProperty("client") + private ClientOptions clientOptions = new ClientOptions(); + + @JsonPropertyDescription("Batch options") + @JsonProperty("batch") + private BatchOptions batchOptions = new BatchOptions(); -public class LambdaCommonConfig { - public static final int DEFAULT_CONNECTION_RETRIES = 3; - public static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(60); + @JsonPropertyDescription("Codec configuration for parsing Lambda responses") + @JsonProperty("response_codec") + @Valid + private PluginModel responseCodecConfig; - public static final String STS_REGION = "region"; - public static final String STS_ROLE_ARN = "sts_role_arn"; + public abstract InvocationType getInvocationType(); } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ThresholdOptions.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ThresholdOptions.java index 2242ef900d..1ba31eeb50 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ThresholdOptions.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/config/ThresholdOptions.java @@ -16,9 +16,9 @@ public class ThresholdOptions { - private static final int DEFAULT_EVENT_COUNT = 100; - private static final String DEFAULT_BYTE_CAPACITY = "5mb"; - private static final Duration DEFAULT_EVENT_TIMEOUT = Duration.ofSeconds(10); + public static final int DEFAULT_EVENT_COUNT = 100; + public static final String DEFAULT_BYTE_CAPACITY = "5mb"; + public static final Duration DEFAULT_EVENT_TIMEOUT = Duration.ofSeconds(10); @JsonProperty("event_count") @Size(min = 0, max = 10000000, message = "event_count size should be between 0 and 10000000") diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java index 7d32a4f380..b19dd8d156 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategy.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.dataprepper.plugins.lambda.processor; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; @@ -5,6 +10,7 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,4 +39,4 @@ public void handleEvents(List parsedEvents, List> originalR } } } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java index 793baad813..2352678419 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessor.java @@ -5,16 +5,26 @@ package org.opensearch.dataprepper.plugins.lambda.processor; +import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; + import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Timer; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; import org.opensearch.dataprepper.model.codec.InputCodec; -import org.opensearch.dataprepper.model.codec.OutputCodec; import org.opensearch.dataprepper.model.configuration.PluginModel; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; @@ -24,83 +34,70 @@ import org.opensearch.dataprepper.model.processor.Processor; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.sink.OutputCodecContext; -import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; +import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; - @DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class) public class LambdaProcessor extends AbstractProcessor, Record> { public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaProcessorObjectsEventsSucceeded"; public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaProcessorObjectsEventsFailed"; + public static final String NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA = "lambdaProcessorNumberOfRequestsSucceeded"; + public static final String NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA = "lambdaProcessorNumberOfRequestsFailed"; public static final String LAMBDA_LATENCY_METRIC = "lambdaProcessorLatency"; public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize"; public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize"; private static final Logger LOG = LoggerFactory.getLogger(LambdaProcessor.class); - - private final String functionName; + final PluginSetting codecPluginSetting; + final PluginFactory pluginFactory; + final LambdaProcessorConfig lambdaProcessorConfig; private final String whenCondition; private final ExpressionEvaluator expressionEvaluator; private final Counter numberOfRecordsSuccessCounter; private final Counter numberOfRecordsFailedCounter; + private final Counter numberOfRequestsSuccessCounter; + private final Counter numberOfRequestsFailedCounter; private final Timer lambdaLatencyMetric; - private final String invocationType; private final List tagsOnMatchFailure; - private final BatchOptions batchOptions; private final LambdaAsyncClient lambdaAsyncClient; - private final AtomicLong requestPayloadMetric; - private final AtomicLong responsePayloadMetric; - LambdaCommonHandler lambdaCommonHandler; - private final int maxEvents; - private final ByteCount maxBytes; - private final Duration maxCollectionDuration; - private int maxRetries = 0; - private int totalFlushedEvents; - final PluginSetting codecPluginSetting; - final PluginFactory pluginFactory; + private final DistributionSummary requestPayloadMetric; + private final DistributionSummary responsePayloadMetric; private final ResponseEventHandlingStrategy responseStrategy; + private final JsonOutputCodecConfig jsonOutputCodecConfig; + LambdaCommonHandler lambdaCommonHandler; @DataPrepperPluginConstructor - public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pluginMetrics, final LambdaProcessorConfig lambdaProcessorConfig, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) { + public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pluginMetrics, + final LambdaProcessorConfig lambdaProcessorConfig, + final AwsCredentialsSupplier awsCredentialsSupplier, + final ExpressionEvaluator expressionEvaluator) { super(pluginMetrics); this.expressionEvaluator = expressionEvaluator; this.pluginFactory = pluginFactory; - this.numberOfRecordsSuccessCounter = pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); - this.numberOfRecordsFailedCounter = pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED); + this.lambdaProcessorConfig = lambdaProcessorConfig; + this.numberOfRecordsSuccessCounter = pluginMetrics.counter( + NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); + this.numberOfRecordsFailedCounter = pluginMetrics.counter( + NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED); + this.numberOfRequestsSuccessCounter = pluginMetrics.counter( + NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA); + this.numberOfRequestsFailedCounter = pluginMetrics.counter( + NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA); this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC); - this.requestPayloadMetric = pluginMetrics.gauge(REQUEST_PAYLOAD_SIZE, new AtomicLong()); - this.responsePayloadMetric = pluginMetrics.gauge(RESPONSE_PAYLOAD_SIZE, new AtomicLong()); - - functionName = lambdaProcessorConfig.getFunctionName(); - whenCondition = lambdaProcessorConfig.getWhenCondition(); - maxRetries = lambdaProcessorConfig.getMaxConnectionRetries(); - batchOptions = lambdaProcessorConfig.getBatchOptions(); - tagsOnMatchFailure = lambdaProcessorConfig.getTagsOnMatchFailure(); + this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE); + this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE); + this.whenCondition = lambdaProcessorConfig.getWhenCondition(); + this.tagsOnMatchFailure = lambdaProcessorConfig.getTagsOnFailure(); PluginModel responseCodecConfig = lambdaProcessorConfig.getResponseCodecConfig(); @@ -108,16 +105,22 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl // Default to JsonInputCodec with default settings codecPluginSetting = new PluginSetting("json", Collections.emptyMap()); } else { - codecPluginSetting = new PluginSetting(responseCodecConfig.getPluginName(), responseCodecConfig.getPluginSettings()); + codecPluginSetting = new PluginSetting(responseCodecConfig.getPluginName(), + responseCodecConfig.getPluginSettings()); } - maxEvents = batchOptions.getThresholdOptions().getEventCount(); - maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); - maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut(); - invocationType = lambdaProcessorConfig.getInvocationType().getAwsLambdaValue(); + jsonOutputCodecConfig = new JsonOutputCodecConfig(); + jsonOutputCodecConfig.setKeyName(lambdaProcessorConfig.getBatchOptions().getKeyName()); - lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(lambdaProcessorConfig.getAwsAuthenticationOptions(), - lambdaProcessorConfig.getMaxConnectionRetries(), awsCredentialsSupplier, lambdaProcessorConfig.getConnectionTimeout()); + ClientOptions clientOptions = lambdaProcessorConfig.getClientOptions(); + if(clientOptions == null){ + clientOptions = new ClientOptions(); + } + lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient( + lambdaProcessorConfig.getAwsAuthenticationOptions(), + awsCredentialsSupplier, + clientOptions + ); // Select the correct strategy based on the configuration if (lambdaProcessorConfig.getResponseEventsMatch()) { @@ -126,8 +129,6 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl this.responseStrategy = new AggregateResponseEventHandlingStrategy(); } - LOG.info("LambdaFunctionName:{} , responseEventsMatch:{}, invocationType:{}", functionName, - lambdaProcessorConfig.getResponseEventsMatch(), invocationType); } @Override @@ -136,164 +137,63 @@ public Collection> doExecute(Collection> records) { return records; } - // Initialize here to void multi-threading issues - // Note: By default, one instance of processor is created across threads. - BufferFactory bufferFactory = new InMemoryBufferFactory(); - lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, functionName, invocationType); - Buffer currentBufferPerBatch = lambdaCommonHandler.createBuffer(bufferFactory); - List futureList = new ArrayList<>(); - totalFlushedEvents = 0; - - // Setup request codec - JsonOutputCodecConfig jsonOutputCodecConfig = new JsonOutputCodecConfig(); - jsonOutputCodecConfig.setKeyName(batchOptions.getKeyName()); - OutputCodec requestCodec = new JsonOutputCodec(jsonOutputCodecConfig); - - //Setup response codec - InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting); - - List> resultRecords = new ArrayList<>(); - - LOG.info("Batch size received to lambda processor: {}", records.size()); + List> resultRecords = Collections.synchronizedList(new ArrayList()); + List> recordsToLambda = new ArrayList<>(); for (Record record : records) { final Event event = record.getData(); - // If the condition is false, add the event to resultRecords as-is if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { resultRecords.add(record); continue; } - - try { - if (currentBufferPerBatch.getEventCount() == 0) { - requestCodec.start(currentBufferPerBatch.getOutputStream(), event, new OutputCodecContext()); - } - requestCodec.writeEvent(event, currentBufferPerBatch.getOutputStream()); - currentBufferPerBatch.addRecord(record); - - boolean wasFlushed = flushToLambdaIfNeeded(resultRecords, currentBufferPerBatch, - requestCodec, responseCodec,futureList,false); - - // After flushing, create a new buffer for the next batch - if (wasFlushed) { - currentBufferPerBatch = lambdaCommonHandler.createBuffer(bufferFactory); - requestCodec = new JsonOutputCodec(jsonOutputCodecConfig); - } - } catch (Exception e) { - LOG.error(NOISY, "Exception while processing event {}", event, e); - handleFailure(e, currentBufferPerBatch, resultRecords); - currentBufferPerBatch = lambdaCommonHandler.createBuffer(bufferFactory); - requestCodec = new JsonOutputCodec(jsonOutputCodecConfig); - } + recordsToLambda.add(record); } - - // Flush any remaining events in the buffer after processing all records - if (currentBufferPerBatch.getEventCount() > 0) { - LOG.info("Force Flushing the remaining {} events in the buffer", currentBufferPerBatch.getEventCount()); - try { - flushToLambdaIfNeeded(resultRecords, currentBufferPerBatch, - requestCodec, responseCodec, futureList,true); - currentBufferPerBatch.reset(); - } catch (Exception e) { - LOG.error("Exception while flushing remaining events", e); - handleFailure(e, currentBufferPerBatch, resultRecords); - } + try { + resultRecords.addAll( + lambdaCommonHandler.sendRecords(recordsToLambda, lambdaProcessorConfig, lambdaAsyncClient, + new OutputCodecContext(), + (inputBuffer, response) -> { + Duration latency = inputBuffer.stopLatencyWatch(); + lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); + numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsSuccessCounter.increment(); + return convertLambdaResponseToEvent(inputBuffer, response); + }, + (inputBuffer) -> { + Duration latency = inputBuffer.stopLatencyWatch(); + lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); + numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsFailedCounter.increment(); + return addFailureTags(inputBuffer.getRecords()); + }) + + ); + } catch (Exception e) { + LOG.info("Exception in doExecute"); + numberOfRecordsFailedCounter.increment(recordsToLambda.size()); + resultRecords.addAll(addFailureTags(recordsToLambda)); } - - lambdaCommonHandler.waitForFutures(futureList); - LOG.info("Total events flushed to lambda successfully: {}", totalFlushedEvents); - return resultRecords; } - - boolean flushToLambdaIfNeeded(List> resultRecords, Buffer currentBufferPerBatch, - OutputCodec requestCodec, InputCodec responseCodec, List futureList, - boolean forceFlush) { - - LOG.debug("currentBufferPerBatchEventCount:{}, maxEvents:{}, maxBytes:{}, " + - "maxCollectionDuration:{}, forceFlush:{} ", currentBufferPerBatch.getEventCount(), - maxEvents, maxBytes, maxCollectionDuration, forceFlush); - if (forceFlush || ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, maxCollectionDuration)) { - try { - requestCodec.complete(currentBufferPerBatch.getOutputStream()); - - // Capture buffer before resetting - final int eventCount = currentBufferPerBatch.getEventCount(); - - CompletableFuture future = currentBufferPerBatch.flushToLambda(invocationType); - - // Handle future - CompletableFuture processingFuture = future.thenAccept(response -> { - //Success handler - handleLambdaResponse(resultRecords, currentBufferPerBatch, eventCount, response, responseCodec); - }).exceptionally(throwable -> { - //Failure handler - List> bufferRecords = currentBufferPerBatch.getRecords(); - Record eventRecord = bufferRecords.isEmpty() ? null : bufferRecords.get(0); - LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {} , Event: {}", - functionName, eventRecord == null? "null":eventRecord.getData(), throwable); - requestPayloadMetric.set(currentBufferPerBatch.getPayloadRequestSize()); - responsePayloadMetric.set(0); - Duration latency = currentBufferPerBatch.stopLatencyWatch(); - lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); - handleFailure(throwable, currentBufferPerBatch, resultRecords); - return null; - }); - - futureList.add(processingFuture); - } catch (IOException e) { - LOG.error(NOISY, "Exception while flushing to lambda", e); - handleFailure(e, currentBufferPerBatch, resultRecords); - } - return true; - } - return false; - } - - private void handleLambdaResponse(List> resultRecords, Buffer flushedBuffer, - int eventCount, InvokeResponse response, InputCodec responseCodec) { - boolean success = lambdaCommonHandler.checkStatusCode(response); - if (success) { - LOG.info("Successfully flushed {} events", eventCount); - - //metrics - numberOfRecordsSuccessCounter.increment(eventCount); - Duration latency = flushedBuffer.stopLatencyWatch(); - lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); - totalFlushedEvents += eventCount; - - convertLambdaResponseToEvent(resultRecords, response, flushedBuffer, responseCodec); - } else { - // Non-2xx status code treated as failure - handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()), flushedBuffer, resultRecords); - } - } - /* * Assumption: Lambda always returns json array. * 1. If response has an array, we assume that we split the individual events. * 2. If it is not an array, then create one event per response. */ - void convertLambdaResponseToEvent(final List> resultRecords, final InvokeResponse lambdaResponse, - Buffer flushedBuffer, InputCodec responseCodec) { + List> convertLambdaResponseToEvent(Buffer flushedBuffer, + final InvokeResponse lambdaResponse) { + InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting); + List> originalRecords = flushedBuffer.getRecords(); try { List parsedEvents = new ArrayList<>(); - List> originalRecords = flushedBuffer.getRecords(); + List> resultRecords = new ArrayList<>(); SdkBytes payload = lambdaResponse.payload(); // Handle null or empty payload if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events"); - // Set metrics - requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - responsePayloadMetric.set(0); } else { - // Set metrics - requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - responsePayloadMetric.set(payload.asByteArray().length); - - LOG.debug("Response payload:{}", payload.asUtf8String()); InputStream inputStream = new ByteArrayInputStream(payload.asByteArray()); //Convert to response codec try { @@ -306,18 +206,15 @@ void convertLambdaResponseToEvent(final List> resultRecords, final } LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " + - "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), + "FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(), flushedBuffer.getSize()); responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer); } - - + return resultRecords; } catch (Exception e) { LOG.error(NOISY, "Error converting Lambda response to Event"); - // Metrics update - requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - responsePayloadMetric.set(0); - handleFailure(e, flushedBuffer, resultRecords); + addFailureTags(flushedBuffer.getRecords()); + return originalRecords; } } @@ -325,22 +222,9 @@ void convertLambdaResponseToEvent(final List> resultRecords, final * If one event in the Buffer fails, we consider that the entire * Batch fails and tag each event in that Batch. */ - void handleFailure(Throwable e, Buffer flushedBuffer, List> resultRecords) { - try { - if (flushedBuffer.getEventCount() > 0) { - numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); - } - - addFailureTags(flushedBuffer, resultRecords); - LOG.error(NOISY, "Failed to process batch due to error: ", e); - } catch(Exception ex){ - LOG.error(NOISY, "Exception in handleFailure while processing failure for buffer: ", ex); - } - } - - private void addFailureTags(Buffer flushedBuffer, List> resultRecords) { + private List> addFailureTags(List> records) { // Add failure tags to each event in the batch - for (Record record : flushedBuffer.getRecords()) { + for (Record record : records) { Event event = record.getData(); EventMetadata metadata = event.getMetadata(); if (metadata != null) { @@ -348,14 +232,13 @@ private void addFailureTags(Buffer flushedBuffer, List> resultReco } else { LOG.warn("Event metadata is null, cannot add failure tags."); } - resultRecords.add(record); } + return records; } @Override public void prepareForShutdown() { - } @Override @@ -367,4 +250,4 @@ public boolean isReadyForShutdown() { public void shutdown() { } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java index 7c0bf52754..5d7e05d7c0 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java @@ -2,104 +2,52 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ + package org.opensearch.dataprepper.plugins.lambda.processor; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import jakarta.validation.Valid; -import jakarta.validation.constraints.NotEmpty; -import jakarta.validation.constraints.NotNull; -import jakarta.validation.constraints.Size; -import org.opensearch.dataprepper.model.configuration.PluginModel; -import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import static org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig.DEFAULT_CONNECTION_RETRIES; -import static org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig.DEFAULT_CONNECTION_TIMEOUT; - -import java.time.Duration; import java.util.Collections; import java.util.List; - -public class LambdaProcessorConfig { - - @JsonProperty("aws") - @NotNull - @Valid - private AwsAuthenticationOptions awsAuthenticationOptions; - - @JsonPropertyDescription("Lambda Function Name") - @JsonProperty("function_name") - @NotEmpty - @Size(min = 3, max = 500, message = "function name length should be at least 3 characters") - private String functionName; - - @JsonPropertyDescription("Total retries we want before failing") - @JsonProperty("max_retries") - private int maxConnectionRetries = DEFAULT_CONNECTION_RETRIES; - - @JsonPropertyDescription("invocation type defines the way we want to call lambda function") - @JsonProperty("invocation_type") - private InvocationType invocationType = InvocationType.REQUEST_RESPONSE; - - @JsonPropertyDescription("Defines the way Data Prepper treats the response from Lambda") - @JsonProperty("response_events_match") - private boolean responseEventsMatch = false; - - @JsonPropertyDescription("sdk timeout defines the time sdk maintains the connection to the client before timing out") - @JsonProperty("connection_timeout") - private Duration connectionTimeout = DEFAULT_CONNECTION_TIMEOUT; - - @JsonProperty("batch") - private BatchOptions batchOptions; - - @JsonPropertyDescription("defines a condition for event to use this processor") - @JsonProperty("lambda_when") - private String whenCondition; - - @JsonPropertyDescription("Codec configuration for parsing Lambda responses") - @JsonProperty("response_codec") - @Valid - private PluginModel responseCodecConfig; - - @JsonProperty("tags_on_match_failure") - @JsonPropertyDescription("A List of Strings that specifies the tags to be set in the event when lambda fails to " + - "or exception occurs. This tag may be used in conditional expressions in " + - "other parts of the configuration") - private List tagsOnMatchFailure = Collections.emptyList(); - - public PluginModel getResponseCodecConfig() { - return responseCodecConfig; - } - - public AwsAuthenticationOptions getAwsAuthenticationOptions() { - return awsAuthenticationOptions; - } - - public BatchOptions getBatchOptions(){return batchOptions;} - - public List getTagsOnMatchFailure(){ - return tagsOnMatchFailure; - } - public String getFunctionName() { - return functionName; - } - - public int getMaxConnectionRetries() { - return maxConnectionRetries; - } - - public String getWhenCondition() { - return whenCondition; - } - - public Duration getConnectionTimeout() { return connectionTimeout;} - - public InvocationType getInvocationType() { - return invocationType; - } - - public Boolean getResponseEventsMatch() { - return responseEventsMatch; - } -} \ No newline at end of file +import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; + +public class LambdaProcessorConfig extends LambdaCommonConfig { + + @JsonPropertyDescription("invocation type defines the way we want to call lambda function") + @JsonProperty("invocation_type") + private InvocationType invocationType = InvocationType.REQUEST_RESPONSE; + + @JsonPropertyDescription("Defines the way Data Prepper treats the response from Lambda") + @JsonProperty("response_events_match") + private boolean responseEventsMatch = false; + + @JsonPropertyDescription("defines a condition for event to use this processor") + @JsonProperty("lambda_when") + private String whenCondition; + + @JsonProperty("tags_on_failure") + @JsonPropertyDescription( + "A List of Strings that specifies the tags to be set in the event when lambda fails to " + + + "or exception occurs. This tag may be used in conditional expressions in " + + "other parts of the configuration") + private List tagsOnFailure = Collections.emptyList(); + + public List getTagsOnFailure() { + return tagsOnFailure; + } + + public String getWhenCondition() { + return whenCondition; + } + + public Boolean getResponseEventsMatch() { + return responseEventsMatch; + } + + @Override + public InvocationType getInvocationType() { + return invocationType; + } +} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinality.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinality.java deleted file mode 100644 index bc7f13489e..0000000000 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinality.java +++ /dev/null @@ -1,38 +0,0 @@ -package org.opensearch.dataprepper.plugins.lambda.processor; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonValue; - -import java.util.HashMap; -import java.util.Map; - -@Deprecated -public enum ResponseCardinality { - STRICT("strict"), - AGGREGATE("aggregate"); - - private final String value; - - ResponseCardinality(String value) { - this.value = value; - } - - @JsonValue - public String getValue() { - return value; - } - - private static final Map RESPONSE_CARDINALITY_MAP = new HashMap<>(); - - static { - for (ResponseCardinality type : ResponseCardinality.values()) { - RESPONSE_CARDINALITY_MAP.put(type.getValue(), type); - } - } - - @JsonCreator - public static ResponseCardinality fromString(String value) { - return RESPONSE_CARDINALITY_MAP.get(value); - } -} - diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java index 128efe2b46..4d6a8e9f28 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategy.java @@ -1,8 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.dataprepper.plugins.lambda.processor; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy; import java.util.List; import java.util.Map; diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java index 2d4213ccf3..7f840c4cf5 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java @@ -5,12 +5,21 @@ package org.opensearch.dataprepper.plugins.lambda.sink; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Timer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.TimeUnit; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventHandle; import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException; import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; @@ -18,100 +27,180 @@ import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.model.sink.Sink; import org.opensearch.dataprepper.model.sink.SinkContext; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; +import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import java.util.Collection; - @DataPrepperPlugin(name = "aws_lambda", pluginType = Sink.class, pluginConfigurationType = LambdaSinkConfig.class) public class LambdaSink extends AbstractSink> { - private static final Logger LOG = LoggerFactory.getLogger(LambdaSink.class); - private volatile boolean sinkInitialized; - private final LambdaSinkService lambdaSinkService; - private final BufferFactory bufferFactory; - private static final String BUCKET = "bucket"; - private static final String KEY_PATH = "key_path_prefix"; - private DlqPushHandler dlqPushHandler = null; - - @DataPrepperPluginConstructor - public LambdaSink(final PluginSetting pluginSetting, - final LambdaSinkConfig lambdaSinkConfig, - final PluginFactory pluginFactory, - final SinkContext sinkContext, - final AwsCredentialsSupplier awsCredentialsSupplier, - final ExpressionEvaluator expressionEvaluator - ) { - super(pluginSetting); - sinkInitialized = Boolean.FALSE; - OutputCodecContext outputCodecContext = OutputCodecContext.fromSinkContext(sinkContext); - LambdaAsyncClient lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient( - lambdaSinkConfig.getAwsAuthenticationOptions(), - lambdaSinkConfig.getMaxConnectionRetries(), - awsCredentialsSupplier, - lambdaSinkConfig.getConnectionTimeout() - ); - if(lambdaSinkConfig.getDlqPluginSetting() != null) { - this.dlqPushHandler = new DlqPushHandler(pluginFactory, - String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(BUCKET)), - lambdaSinkConfig.getDlqStsRoleARN() - , lambdaSinkConfig.getDlqStsRegion(), - String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(KEY_PATH))); - } + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaSinkObjectsEventsSucceeded"; + public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaSinkObjectsEventsFailed"; + public static final String NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA = "lambdaSinkNumberOfRequestsSucceeded"; + public static final String NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA = "lambdaSinkNumberOfRequestsFailed"; + public static final String LAMBDA_LATENCY_METRIC = "lambdaSinkLatency"; + public static final String REQUEST_PAYLOAD_SIZE = "lambdaSinkRequestPayloadSize"; + public static final String RESPONSE_PAYLOAD_SIZE = "lambdaSinkResponsePayloadSize"; - this.bufferFactory = new InMemoryBufferFactory(); + private static final Logger LOG = LoggerFactory.getLogger(LambdaSink.class); + private static final String BUCKET = "bucket"; + private static final String KEY_PATH = "key_path_prefix"; + private final Counter numberOfRecordsSuccessCounter; + private final Counter numberOfRecordsFailedCounter; + private final Counter numberOfRequestsSuccessCounter; + private final Counter numberOfRequestsFailedCounter; + private final LambdaSinkConfig lambdaSinkConfig; + private final ExpressionEvaluator expressionEvaluator; + private final LambdaAsyncClient lambdaAsyncClient; + private final DistributionSummary responsePayloadMetric; + private final Timer lambdaLatencyMetric; + private final DistributionSummary requestPayloadMetric; + private final PluginSetting pluginSetting; + private final OutputCodecContext outputCodecContext; + private volatile boolean sinkInitialized; + private DlqPushHandler dlqPushHandler = null; - lambdaSinkService = new LambdaSinkService(lambdaAsyncClient, - lambdaSinkConfig, - pluginMetrics, - pluginFactory, - pluginSetting, - outputCodecContext, - awsCredentialsSupplier, - dlqPushHandler, - bufferFactory, - expressionEvaluator); + @DataPrepperPluginConstructor + public LambdaSink(final PluginSetting pluginSetting, + final LambdaSinkConfig lambdaSinkConfig, + final PluginFactory pluginFactory, + final SinkContext sinkContext, + final AwsCredentialsSupplier awsCredentialsSupplier, + final ExpressionEvaluator expressionEvaluator + ) { + super(pluginSetting); + this.pluginSetting = pluginSetting; + sinkInitialized = Boolean.FALSE; + this.lambdaSinkConfig = lambdaSinkConfig; + this.expressionEvaluator = expressionEvaluator; + this.outputCodecContext = OutputCodecContext.fromSinkContext(sinkContext); + this.numberOfRecordsSuccessCounter = pluginMetrics.counter( + NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); + this.numberOfRecordsFailedCounter = pluginMetrics.counter( + NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED); + this.numberOfRequestsSuccessCounter = pluginMetrics.counter( + NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA); + this.numberOfRequestsFailedCounter = pluginMetrics.counter( + NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA); + this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC); + this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE); + this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE); + ClientOptions clientOptions = lambdaSinkConfig.getClientOptions(); + if(clientOptions == null){ + clientOptions = new ClientOptions(); } - - @Override - public boolean isReady() { - return sinkInitialized; + this.lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient( + lambdaSinkConfig.getAwsAuthenticationOptions(), + awsCredentialsSupplier, + clientOptions + ); + if (lambdaSinkConfig.getDlqPluginSetting() != null) { + this.dlqPushHandler = new DlqPushHandler(pluginFactory, + String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(BUCKET)), + lambdaSinkConfig.getDlqStsRoleARN() + , lambdaSinkConfig.getDlqStsRegion(), + String.valueOf(lambdaSinkConfig.getDlqPluginSetting().get(KEY_PATH))); } - @Override - public void doInitialize() { - try { - doInitializeInternal(); - } catch (InvalidPluginConfigurationException e) { - LOG.error("Invalid plugin configuration, Hence failed to initialize s3-sink plugin."); - this.shutdown(); - throw e; - } catch (Exception e) { - LOG.error("Failed to initialize lambda plugin."); - this.shutdown(); - throw e; - } + } + + @Override + public boolean isReady() { + return sinkInitialized; + } + + @Override + public void doInitialize() { + try { + doInitializeInternal(); + } catch (InvalidPluginConfigurationException e) { + LOG.error("Invalid plugin configuration, Hence failed to initialize s3-sink plugin."); + this.shutdown(); + throw e; + } catch (Exception e) { + LOG.error("Failed to initialize lambda plugin."); + this.shutdown(); + throw e; } + } + + private void doInitializeInternal() { + sinkInitialized = Boolean.TRUE; + } + + /** + * @param records Records to be output + */ + @Override + public void doOutput(final Collection> records) { - private void doInitializeInternal() { - sinkInitialized = Boolean.TRUE; + if (records.isEmpty()) { + return; } - /** - * @param records Records to be output - */ - @Override - public void doOutput(final Collection> records) { + //Result from lambda is not currently processes. + LambdaCommonHandler.sendRecords(records, + lambdaSinkConfig, + lambdaAsyncClient, + outputCodecContext, + (inputBuffer, invokeResponse) -> { + Duration latency = inputBuffer.stopLatencyWatch(); + lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); + numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsSuccessCounter.increment(); + releaseEventHandlesPerBatch(true, inputBuffer); + return new ArrayList<>(); + }, + (inputBuffer) -> { + Duration latency = inputBuffer.stopLatencyWatch(); + lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); + numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsFailedCounter.increment(); + handleFailure(new RuntimeException("failed"), inputBuffer); + return new ArrayList<>(); + }); + } + + void handleFailure(Throwable throwable, Buffer flushedBuffer) { + try { + if (flushedBuffer.getEventCount() > 0) { + numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); + } + + SdkBytes payload = flushedBuffer.getPayload(); + if (dlqPushHandler != null) { + dlqPushHandler.perform(pluginSetting, + new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0)); + releaseEventHandlesPerBatch(true, flushedBuffer); + } else { + releaseEventHandlesPerBatch(false, flushedBuffer); + } + } catch (Exception ex) { + LOG.error("Exception occured during error handling"); + } + } - if (records.isEmpty()) { - return; + /* + * Release events per batch + */ + private void releaseEventHandlesPerBatch(boolean success, Buffer flushedBuffer) { + List> records = flushedBuffer.getRecords(); + for (Record record : records) { + Event event = record.getData(); + if (event != null) { + EventHandle eventHandle = event.getEventHandle(); + if (eventHandle != null) { + eventHandle.release(success); } - lambdaSinkService.output(records); + } } -} \ No newline at end of file + } +} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfig.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfig.java index e901d1fa03..f03b77b896 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfig.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfig.java @@ -2,105 +2,48 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ + package org.opensearch.dataprepper.plugins.lambda.sink; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; -import jakarta.validation.Valid; -import jakarta.validation.constraints.NotEmpty; -import jakarta.validation.constraints.NotNull; -import jakarta.validation.constraints.Size; -import org.opensearch.dataprepper.model.configuration.PluginModel; -import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import static org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig.DEFAULT_CONNECTION_RETRIES; -import static org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig.DEFAULT_CONNECTION_TIMEOUT; -import static org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig.STS_REGION; -import static org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig.STS_ROLE_ARN; - -import java.time.Duration; import java.util.Map; import java.util.Objects; - -public class LambdaSinkConfig { - - @JsonProperty("aws") - @NotNull - @Valid - private AwsAuthenticationOptions awsAuthenticationOptions; - - @JsonPropertyDescription("Lambda Function Name") - @JsonProperty("function_name") - @NotEmpty - @NotNull - @Size(min = 3, max = 500, message = "function name length should be at least 3 characters") - private String functionName; - - @JsonPropertyDescription("Total retries we want before failing") - @JsonProperty("max_retries") - private int maxConnectionRetries = DEFAULT_CONNECTION_RETRIES; - - @JsonPropertyDescription("invocation type defines the way we want to call lambda function") - @JsonProperty("invocation_type") - private InvocationType invocationType = InvocationType.EVENT; - - @JsonProperty("dlq") - private PluginModel dlq; - - @JsonProperty("batch") - private BatchOptions batchOptions; - - @JsonPropertyDescription("defines a condition for event to use this processor") - @JsonProperty("lambda_when") - private String whenCondition; - - @JsonPropertyDescription("sdk timeout defines the time sdk maintains the connection to the client before timing out") - @JsonProperty("connection_timeout") - private Duration connectionTimeout = DEFAULT_CONNECTION_TIMEOUT; - - public Duration getConnectionTimeout(){return connectionTimeout;} - - public AwsAuthenticationOptions getAwsAuthenticationOptions() { - return awsAuthenticationOptions; - } - - public BatchOptions getBatchOptions(){return batchOptions;} - - public String getFunctionName() { - return functionName; - } - - public int getMaxConnectionRetries() { - return maxConnectionRetries; - } - - public PluginModel getDlq() { - return dlq; - } - - public String getDlqStsRoleARN(){ - return Objects.nonNull(getDlqPluginSetting().get(STS_ROLE_ARN)) ? - String.valueOf(getDlqPluginSetting().get(STS_ROLE_ARN)) : - awsAuthenticationOptions.getAwsStsRoleArn(); - } - - public String getDlqStsRegion(){ - return Objects.nonNull(getDlqPluginSetting().get(STS_REGION)) ? - String.valueOf(getDlqPluginSetting().get(STS_REGION)) : - awsAuthenticationOptions.getAwsRegion().toString(); - } - - public Map getDlqPluginSetting(){ - return dlq != null ? dlq.getPluginSettings() : Map.of(); - } - - public InvocationType getInvocationType() { - return invocationType; - } - - public String getWhenCondition() { - return whenCondition; - } - -} \ No newline at end of file +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; + +public class LambdaSinkConfig extends LambdaCommonConfig { + + @JsonPropertyDescription("invocation type defines the way we want to call lambda function") + @JsonProperty("invocation_type") + private InvocationType invocationType = InvocationType.EVENT; + + @JsonProperty("dlq") + private PluginModel dlq; + + public PluginModel getDlq() { + return dlq; + } + + public String getDlqStsRoleARN() { + return dlq != null ? (Objects.nonNull(getDlqPluginSetting().get(STS_ROLE_ARN)) ? + String.valueOf(getDlqPluginSetting().get(STS_ROLE_ARN)) : + getAwsAuthenticationOptions().getAwsStsRoleArn()) : null; + } + + public String getDlqStsRegion() { + return dlq != null ? (Objects.nonNull(getDlqPluginSetting().get(STS_REGION)) ? + String.valueOf(getDlqPluginSetting().get(STS_REGION)) : + getAwsAuthenticationOptions().getAwsRegion().toString()) : null; + } + + public Map getDlqPluginSetting() { + return dlq != null ? dlq.getPluginSettings() : null; + } + + @Override + public InvocationType getInvocationType() { + return invocationType; + } +} diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java deleted file mode 100644 index 595a488c55..0000000000 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkService.java +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.lambda.sink; - -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Timer; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; -import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; -import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.codec.OutputCodec; -import org.opensearch.dataprepper.model.configuration.PluginSetting; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.event.EventHandle; -import org.opensearch.dataprepper.model.plugin.PluginFactory; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.sink.OutputCodecContext; -import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; -import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig; -import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck; -import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; -import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeResponse; - -import java.io.IOException; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; - -public class LambdaSinkService { - - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS = "lambdaSinkObjectsEventsSucceeded"; - public static final String NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED = "lambdaSinkObjectsEventsFailed"; - public static final String LAMBDA_LATENCY_METRIC = "lambdaSinkLatency"; - public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize"; - public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize"; - private static final Logger LOG = LoggerFactory.getLogger(LambdaSinkService.class); - private final AtomicLong requestPayloadMetric; - private final AtomicLong responsePayloadMetric; - private final PluginSetting pluginSetting; - private final Lock reentrantLock; - private final LambdaSinkConfig lambdaSinkConfig; - private final LambdaAsyncClient lambdaAsyncClient; - private final String functionName; - private final String whenCondition; - private final ExpressionEvaluator expressionEvaluator; - private final Counter numberOfRecordsSuccessCounter; - private final Counter numberOfRecordsFailedCounter; - private final Timer lambdaLatencyMetric; - private final String invocationType; - private final BufferFactory bufferFactory; - private final DlqPushHandler dlqPushHandler; - private final BatchOptions batchOptions; - private int maxEvents = 0; - private ByteCount maxBytes = null; - private Duration maxCollectionDuration = null; - private int maxRetries = 0; - private OutputCodec requestCodec = null; - private OutputCodecContext codecContext = null; - private final LambdaCommonHandler lambdaCommonHandler; - private Buffer currentBufferPerBatch = null; - List> futureList; - - - public LambdaSinkService(final LambdaAsyncClient lambdaAsyncClient, final LambdaSinkConfig lambdaSinkConfig, final PluginMetrics pluginMetrics, final PluginFactory pluginFactory, final PluginSetting pluginSetting, final OutputCodecContext codecContext, final AwsCredentialsSupplier awsCredentialsSupplier, final DlqPushHandler dlqPushHandler, final BufferFactory bufferFactory, final ExpressionEvaluator expressionEvaluator) { - this.lambdaSinkConfig = lambdaSinkConfig; - this.pluginSetting = pluginSetting; - this.expressionEvaluator = expressionEvaluator; - this.dlqPushHandler = dlqPushHandler; - this.lambdaAsyncClient = lambdaAsyncClient; - this.numberOfRecordsSuccessCounter = pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); - this.numberOfRecordsFailedCounter = pluginMetrics.counter(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED); - this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC); - this.requestPayloadMetric = pluginMetrics.gauge(REQUEST_PAYLOAD_SIZE, new AtomicLong()); - this.responsePayloadMetric = pluginMetrics.gauge(RESPONSE_PAYLOAD_SIZE, new AtomicLong()); - this.codecContext = codecContext; - - reentrantLock = new ReentrantLock(); - functionName = lambdaSinkConfig.getFunctionName(); - maxRetries = lambdaSinkConfig.getMaxConnectionRetries(); - batchOptions = lambdaSinkConfig.getBatchOptions(); - whenCondition = lambdaSinkConfig.getWhenCondition(); - - JsonOutputCodecConfig jsonOutputCodecConfig = new JsonOutputCodecConfig(); - jsonOutputCodecConfig.setKeyName(batchOptions.getKeyName()); - requestCodec = new JsonOutputCodec(jsonOutputCodecConfig); - - maxEvents = batchOptions.getThresholdOptions().getEventCount(); - maxBytes = batchOptions.getThresholdOptions().getMaximumSize(); - maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut(); - invocationType = lambdaSinkConfig.getInvocationType().getAwsLambdaValue(); - futureList = Collections.synchronizedList(new ArrayList<>()); - - this.bufferFactory = bufferFactory; - - LOG.info("LambdaFunctionName:{} , invocationType:{}", functionName, invocationType); - // Initialize LambdaCommonHandler - lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, functionName, invocationType); - } - - - public void output(Collection> records) { - if (records.isEmpty()) { - return; - } - - //Result from lambda is not currently processes. - List> resultRecords = null; - - reentrantLock.lock(); - try { - for (Record record : records) { - final Event event = record.getData(); - - if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) { - releaseEventHandle(event, true); - continue; - } - try { - if (currentBufferPerBatch.getEventCount() == 0) { - requestCodec.start(currentBufferPerBatch.getOutputStream(), event, codecContext); - } - requestCodec.writeEvent(event, currentBufferPerBatch.getOutputStream()); - currentBufferPerBatch.addRecord(record); - - flushToLambdaIfNeeded(resultRecords, false); - } catch (IOException e) { - LOG.error("Exception while writing to codec {}", event, e); - handleFailure(e, currentBufferPerBatch); - } catch (Exception e) { - LOG.error("Exception while processing event {}", event, e); - handleFailure(e, currentBufferPerBatch); - currentBufferPerBatch.reset(); - } - } - // Flush any remaining events after processing all records - if (currentBufferPerBatch.getEventCount() > 0) { - LOG.info("Force Flushing the remaining {} events in the buffer", currentBufferPerBatch.getEventCount()); - try { - flushToLambdaIfNeeded(resultRecords, true); // Force flush remaining events - } catch (Exception e) { - LOG.error("Exception while flushing remaining events", e); - handleFailure(e, currentBufferPerBatch); - } - } - } finally { - reentrantLock.unlock(); - } - - // Wait for all futures to complete - lambdaCommonHandler.waitForFutures(futureList); - - // Release event handles for records not sent to Lambda - for (Record record : records) { - Event event = record.getData(); - releaseEventHandle(event, true); - } - - } - - void flushToLambdaIfNeeded(List> resultRecords, boolean forceFlush) { - if (forceFlush || ThresholdCheck.checkThresholdExceed(currentBufferPerBatch, maxEvents, maxBytes, maxCollectionDuration)) { - try { - requestCodec.complete(currentBufferPerBatch.getOutputStream()); - - // Capture buffer before resetting - final Buffer flushedBuffer = currentBufferPerBatch; - final int eventCount = currentBufferPerBatch.getEventCount(); - - CompletableFuture future = flushedBuffer.flushToLambda(invocationType); - - // Handle future - CompletableFuture processingFuture = future.thenAccept(response -> { - handleLambdaResponse(flushedBuffer, eventCount, response); - }).exceptionally(throwable -> { - // Failure handler - List> bufferRecords = flushedBuffer.getRecords(); - Record eventRecord = bufferRecords.isEmpty() ? null : bufferRecords.get(0); - LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {} , Event: {}", - functionName, eventRecord == null? "null":eventRecord.getData(), throwable); - requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - responsePayloadMetric.set(0); - Duration latency = flushedBuffer.stopLatencyWatch(); - lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); - handleFailure(throwable, flushedBuffer); - return null; - }); - - futureList.add(processingFuture); - - // Create a new buffer for the next batch - currentBufferPerBatch = lambdaCommonHandler.createBuffer(bufferFactory); - } catch (IOException e) { - LOG.error("Exception while flushing to lambda", e); - handleFailure(e, currentBufferPerBatch); - currentBufferPerBatch = lambdaCommonHandler.createBuffer(bufferFactory); - } - } - } - - void handleFailure(Throwable throwable, Buffer flushedBuffer) { - try { - if (flushedBuffer.getEventCount() > 0) { - numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount()); - } - - SdkBytes payload = flushedBuffer.getPayload(); - if (dlqPushHandler != null) { - dlqPushHandler.perform(pluginSetting, new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0)); - releaseEventHandlesPerBatch(true, flushedBuffer); - } else { - releaseEventHandlesPerBatch(false, flushedBuffer); - } - } catch (Exception ex){ - LOG.error("Exception occured during error handling"); - } - } - - /* - * Release events per batch - */ - private void releaseEventHandlesPerBatch(boolean success, Buffer flushedBuffer) { - List> records = flushedBuffer.getRecords(); - for (Record record : records) { - Event event = record.getData(); - releaseEventHandle(event, success); - } - } - - /** - * Releases the event handle based on processing success. - * - * @param event the event to release - * @param success indicates if processing was successful - */ - private void releaseEventHandle(Event event, boolean success) { - if (event != null) { - EventHandle eventHandle = event.getEventHandle(); - if (eventHandle != null) { - eventHandle.release(success); - } - } - } - - private void handleLambdaResponse(Buffer flushedBuffer, int eventCount, InvokeResponse response) { - boolean success = lambdaCommonHandler.checkStatusCode(response); - if (success) { - LOG.info("Successfully flushed {} events", eventCount); - SdkBytes payload = response.payload(); - if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) { - responsePayloadMetric.set(0); - } else { - responsePayloadMetric.set(payload.asByteArray().length); - } - //metrics - requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize()); - numberOfRecordsSuccessCounter.increment(eventCount); - Duration latency = flushedBuffer.stopLatencyWatch(); - lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); - } - else { - // Non-2xx status code treated as failure - handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()), flushedBuffer); - } - } - -} diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java index 1d4e67316e..2c7f27654a 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/LambdaCommonHandlerTest.java @@ -1,122 +1,144 @@ package org.opensearch.dataprepper.plugins.lambda.common; -import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.jupiter.api.BeforeEach; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.Test; -import static org.mockito.ArgumentMatchers.any; -import org.mockito.InjectMocks; +import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; -import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import org.mockito.MockitoAnnotations; -import org.opensearch.dataprepper.model.event.EventHandle; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import org.slf4j.Logger; +import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; +import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.function.BiFunction; +import java.util.function.Function; -public class LambdaCommonHandlerTest { - - @Mock - private Logger mockLogger; - - @Mock - private LambdaAsyncClient mockLambdaAsyncClient; - - @Mock - private BufferFactory mockBufferFactory; - - @Mock - private Buffer mockBuffer; - - @Mock - private InvokeResponse mockInvokeResponse; - - @InjectMocks - private LambdaCommonHandler lambdaCommonHandler; - - private String functionName = "test-function"; - - private String invocationType = InvocationType.REQUEST_RESPONSE.getAwsLambdaValue(); - - @BeforeEach - public void setUp() { - MockitoAnnotations.openMocks(this); - lambdaCommonHandler = new LambdaCommonHandler(mockLogger, mockLambdaAsyncClient, functionName, invocationType); - } - - @Test - public void testCreateBuffer_success() throws IOException { - // Arrange - when(mockBufferFactory.getBuffer(any(), anyString(), any())).thenReturn(mockBuffer); - - // Act - Buffer result = lambdaCommonHandler.createBuffer(mockBufferFactory); - - // Assert - verify(mockBufferFactory, times(1)).getBuffer(mockLambdaAsyncClient, functionName, invocationType); - verify(mockLogger, times(1)).debug("Resetting buffer"); - assertEquals(result, mockBuffer); - } - - @Test - public void testCreateBuffer_throwsException() throws IOException { - // Arrange - when(mockBufferFactory.getBuffer(any(), anyString(), any())).thenThrow(new IOException("Test Exception")); - - // Act & Assert - try { - lambdaCommonHandler.createBuffer(mockBufferFactory); - } catch (RuntimeException e) { - assert e.getMessage().contains("Failed to reset buffer"); - } - verify(mockBufferFactory, times(1)).getBuffer(mockLambdaAsyncClient, functionName, invocationType); - } - - @Test - public void testWaitForFutures_allComplete() { - // Arrange - List> futureList = new ArrayList<>(); - futureList.add(CompletableFuture.completedFuture(null)); - futureList.add(CompletableFuture.completedFuture(null)); - - // Act - lambdaCommonHandler.waitForFutures(futureList); - - // Assert - assert futureList.isEmpty(); - } - - @Test - public void testWaitForFutures_withException() { - // Arrange - List> futureList = new ArrayList<>(); - futureList.add(CompletableFuture.failedFuture(new RuntimeException("Test Exception"))); - - // Act - lambdaCommonHandler.waitForFutures(futureList); - - // Assert - assert futureList.isEmpty(); - } - - private List mockEventHandleList(int size) { - List eventHandleList = new ArrayList<>(); - for (int i = 0; i < size; i++) { - EventHandle eventHandle = mock(EventHandle.class); - eventHandleList.add(eventHandle); - } - return eventHandleList; - } +import static org.mockito.ArgumentMatchers.any; +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +class LambdaCommonHandlerTest { + + @Mock + private LambdaAsyncClient lambdaAsyncClient; + + @Mock + private LambdaCommonConfig config; + + @Mock + private BatchOptions batchOptions; + + @Mock + private OutputCodecContext outputCodecContext; + + @Test + void testCheckStatusCode() { + InvokeResponse successResponse = InvokeResponse.builder().statusCode(200).build(); + InvokeResponse failureResponse = InvokeResponse.builder().statusCode(400).build(); + + assertTrue(LambdaCommonHandler.isSuccess(successResponse)); + assertFalse(LambdaCommonHandler.isSuccess(failureResponse)); + } + + @Test + void testWaitForFutures() { + List> futureList = new ArrayList<>(); + CompletableFuture future1 = new CompletableFuture<>(); + CompletableFuture future2 = new CompletableFuture<>(); + futureList.add(future1); + futureList.add(future2); + + // Simulate completion of futures + future1.complete(InvokeResponse.builder().build()); + future2.complete(InvokeResponse.builder().build()); + + LambdaCommonHandler.waitForFutures(futureList); + + assertFalse(futureList.isEmpty()); + } + + @Test + void testSendRecords() { + when(config.getBatchOptions()).thenReturn(batchOptions); + when(batchOptions.getThresholdOptions()).thenReturn(mock(ThresholdOptions.class)); + when(batchOptions.getKeyName()).thenReturn("testKey"); + when(config.getFunctionName()).thenReturn("testFunction"); + when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.completedFuture(InvokeResponse.builder().statusCode(200).build())); + + Event mockEvent = mock(Event.class); + when(mockEvent.toMap()).thenReturn(Collections.singletonMap("testKey", "testValue")); + List> records = Collections.singletonList(new Record<>(mockEvent)); + + BiFunction>> successHandler = (buffer, response) -> new ArrayList<>(); + Function>> failureHandler = (buffer) -> new ArrayList<>(); + + List> result = LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext, successHandler, failureHandler); + + assertNotNull(result); + verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class)); + } + + @Test + void testSendRecordsWithNullKeyName() { + when(config.getBatchOptions()).thenReturn(batchOptions); + when(batchOptions.getThresholdOptions()).thenReturn(mock(ThresholdOptions.class)); + when(batchOptions.getKeyName()).thenReturn(null); + when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); + when(config.getFunctionName()).thenReturn("testFunction"); + + Event mockEvent = mock(Event.class); + when(mockEvent.toMap()).thenReturn(Collections.singletonMap("testKey", "testValue")); + List> records = Collections.singletonList(new Record<>(mockEvent)); + + BiFunction>> successHandler = (buffer, response) -> new ArrayList<>(); + Function>> failureHandler = (buffer) -> new ArrayList<>(); + + assertThrows(NullPointerException.class, () -> + LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext, successHandler, failureHandler) + ); + } + + @Test + void testSendRecordsWithFailure() { + when(config.getBatchOptions()).thenReturn(batchOptions); + when(batchOptions.getThresholdOptions()).thenReturn(mock(ThresholdOptions.class)); + when(batchOptions.getKeyName()).thenReturn("testKey"); + when(config.getFunctionName()).thenReturn("testFunction"); + when(config.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException("Test exception"))); + + List> records = new ArrayList<>(); + records.add(new Record<>(mock(Event.class))); + + BiFunction>> successHandler = (buffer, response) -> new ArrayList<>(); + Function>> failureHandler = (buffer) -> new ArrayList<>(); + + List> result = LambdaCommonHandler.sendRecords(records, config, lambdaAsyncClient, outputCodecContext, successHandler, failureHandler); + + assertNotNull(result); + verify(lambdaAsyncClient, atLeastOnce()).invoke(any(InvokeRequest.class)); + } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferFactoryTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferFactoryTest.java deleted file mode 100644 index 37276db819..0000000000 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferFactoryTest.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.dataprepper.plugins.lambda.common.accumulator; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferFactory; - -import static org.hamcrest.CoreMatchers.instanceOf; -import static org.hamcrest.MatcherAssert.assertThat; - -class InMemoryBufferFactoryTest { - - @Test - void test_inMemoryBufferFactory_notNull(){ - InMemoryBufferFactory inMemoryBufferFactory = new InMemoryBufferFactory(); - Assertions.assertNotNull(inMemoryBufferFactory); - } - - @Test - void test_buffer_notNull(){ - InMemoryBufferFactory inMemoryBufferFactory = new InMemoryBufferFactory(); - Assertions.assertNotNull(inMemoryBufferFactory); - Buffer buffer = inMemoryBufferFactory.getBuffer(null, null, null); - Assertions.assertNotNull(buffer); - assertThat(buffer, instanceOf(Buffer.class)); - } -} \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java index 9a9bb1eef6..89489e7ab1 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/accumulator/InMemoryBufferTest.java @@ -12,13 +12,22 @@ import static org.hamcrest.Matchers.instanceOf; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import static org.mockito.ArgumentMatchers.any; import org.mockito.Mock; -import static org.mockito.Mockito.when; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; import software.amazon.awssdk.core.SdkBytes; @@ -27,121 +36,121 @@ import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.io.IOException; -import java.io.OutputStream; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; - @ExtendWith(MockitoExtension.class) class InMemoryBufferTest { - public static final int MAX_EVENTS = 55; - @Mock - private LambdaAsyncClient lambdaAsyncClient; - - private final String invocationType = InvocationType.REQUEST_RESPONSE.getAwsLambdaValue(); - - private final String functionName = "testFunction"; - - private InMemoryBuffer inMemoryBuffer; - - @Test - void test_with_write_event_into_buffer() throws IOException { - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - - while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { - OutputStream outputStream = inMemoryBuffer.getOutputStream(); - outputStream.write(generateByteArray()); - int eventCount = inMemoryBuffer.getEventCount() +1; - inMemoryBuffer.setEventCount(eventCount); - } - assertThat(inMemoryBuffer.getSize(), greaterThanOrEqualTo(54110L)); - assertThat(inMemoryBuffer.getEventCount(), equalTo(MAX_EVENTS)); - assertThat(inMemoryBuffer.getDuration(), notNullValue()); - assertThat(inMemoryBuffer.getDuration(), greaterThanOrEqualTo(Duration.ZERO)); - } - - @Test - void test_with_write_event_into_buffer_and_flush_toLambda() throws IOException { - - // Mock the response of the invoke method - InvokeResponse mockResponse = InvokeResponse.builder() - .statusCode(200) // HTTP 200 for successful invocation - .payload(SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) - .build(); - CompletableFuture future = CompletableFuture.completedFuture(mockResponse); - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); - - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { - OutputStream outputStream = inMemoryBuffer.getOutputStream(); - outputStream.write(generateByteArray()); - int eventCount = inMemoryBuffer.getEventCount() +1; - inMemoryBuffer.setEventCount(eventCount); - } - assertDoesNotThrow(() -> { - CompletableFuture responseFuture = inMemoryBuffer.flushToLambda(invocationType); - InvokeResponse response = responseFuture.join(); - assertThat(response.statusCode(), equalTo(200)); - }); + public static final int MAX_EVENTS = 55; + private final String invocationType = InvocationType.REQUEST_RESPONSE.getAwsLambdaValue(); + private final String functionName = "testFunction"; + private final String batchOptionKeyName = "bathOption"; + @Mock + private LambdaAsyncClient lambdaAsyncClient; + + + @Test + void test_with_write_event_into_buffer() { + InMemoryBuffer inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + //UUID based random event created. Each UUID string is of 36 characters long + int eachEventSize = 36; + long sizeToAssert = eachEventSize * MAX_EVENTS; + while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { + inMemoryBuffer.addRecord(getSampleRecord()); } - - @Test - void test_uploadedToLambda_success() throws IOException { - // Mock the response of the invoke method - InvokeResponse mockResponse = InvokeResponse.builder() - .statusCode(200) // HTTP 200 for successful invocation - .payload(SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) - .build(); - - CompletableFuture future = CompletableFuture.completedFuture(mockResponse); - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); - - - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - assertNotNull(inMemoryBuffer); - OutputStream outputStream = inMemoryBuffer.getOutputStream(); - outputStream.write(generateByteArray()); - inMemoryBuffer.setEventCount(1); - - assertDoesNotThrow(() -> { - CompletableFuture responseFuture = inMemoryBuffer.flushToLambda(invocationType); - InvokeResponse response = responseFuture.join(); - assertThat(response.statusCode(), equalTo(200)); - }); - } - - @Test - void test_uploadedToLambda_fails() { - // Mock an exception when invoking lambda - SdkClientException sdkClientException = SdkClientException.create("Mock exception"); - - CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(sdkClientException); - - when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); - - inMemoryBuffer = new InMemoryBuffer(lambdaAsyncClient, functionName, invocationType); - assertNotNull(inMemoryBuffer); - - // Execute and assert exception - CompletionException exception = assertThrows(CompletionException.class, () -> { - CompletableFuture responseFuture = inMemoryBuffer.flushToLambda(invocationType); - responseFuture.join(); // This will throw CompletionException - }); - - // Verify that the cause of the CompletionException is the SdkClientException we threw - assertThat(exception.getCause(), instanceOf(SdkClientException.class)); - assertThat(exception.getCause().getMessage(), equalTo("Mock exception")); - + assertThat(inMemoryBuffer.getSize(), greaterThanOrEqualTo(sizeToAssert)); + assertThat(inMemoryBuffer.getEventCount(), equalTo(MAX_EVENTS)); + assertThat(inMemoryBuffer.getDuration(), notNullValue()); + assertThat(inMemoryBuffer.getDuration(), greaterThanOrEqualTo(Duration.ZERO)); + } + + @Test + void test_with_write_event_into_buffer_and_flush_toLambda() { + + // Mock the response of the invoke method + InvokeResponse mockResponse = InvokeResponse.builder() + .statusCode(200) // HTTP 200 for successful invocation + .payload( + SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) + .build(); + CompletableFuture future = CompletableFuture.completedFuture(mockResponse); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); + + InMemoryBuffer inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { + inMemoryBuffer.addRecord(getSampleRecord()); } - - private byte[] generateByteArray() { - byte[] bytes = new byte[1000]; - for (int i = 0; i < 1000; i++) { - bytes[i] = (byte) i; - } - return bytes; + assertDoesNotThrow(() -> { + InvokeRequest requestPayload = inMemoryBuffer.getRequestPayload( + functionName, invocationType); + CompletableFuture responseFuture = lambdaAsyncClient.invoke(requestPayload); + InvokeResponse response = responseFuture.join(); + assertThat(response.statusCode(), equalTo(200)); + }); + } + + private Record getSampleRecord() { + Event event = JacksonEvent.fromMessage(String.valueOf(UUID.randomUUID())); + return new Record<>(event); + } + + @Test + void test_uploadedToLambda_success() { + // Mock the response of the invoke method + InvokeResponse mockResponse = InvokeResponse.builder() + .statusCode(200) // HTTP 200 for successful invocation + .payload( + SdkBytes.fromString("{\"key\": \"value\"}", java.nio.charset.StandardCharsets.UTF_8)) + .build(); + + CompletableFuture future = CompletableFuture.completedFuture(mockResponse); + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); + + InMemoryBuffer inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + assertNotNull(inMemoryBuffer); + inMemoryBuffer.addRecord(getSampleRecord()); + + assertDoesNotThrow(() -> { + InvokeRequest requestPayload = inMemoryBuffer.getRequestPayload( + functionName, invocationType); + CompletableFuture responseFuture = lambdaAsyncClient.invoke(requestPayload); + InvokeResponse response = responseFuture.join(); + assertThat(response.statusCode(), equalTo(200)); + }); + } + + @Test + void test_uploadedToLambda_fails() { + // Mock an exception when invoking lambda + SdkClientException sdkClientException = SdkClientException.create("Mock exception"); + + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(sdkClientException); + + when(lambdaAsyncClient.invoke(any(InvokeRequest.class))).thenReturn(future); + + InMemoryBuffer inMemoryBuffer = new InMemoryBuffer(batchOptionKeyName); + assertNotNull(inMemoryBuffer); + + assertNull(inMemoryBuffer.getRequestPayload(functionName, invocationType)); + inMemoryBuffer.addRecord(getSampleRecord()); + // Execute and assert exception + CompletionException exception = assertThrows(CompletionException.class, () -> { + InvokeRequest requestPayload = inMemoryBuffer.getRequestPayload( + functionName, invocationType); + CompletableFuture responseFuture = lambdaAsyncClient.invoke(requestPayload); + responseFuture.join();// This will throw CompletionException + }); + + // Verify that the cause of the CompletionException is the SdkClientException we threw + assertThat(exception.getCause(), instanceOf(SdkClientException.class)); + assertThat(exception.getCause().getMessage(), equalTo("Mock exception")); + + } + + private byte[] generateByteArray() { + byte[] bytes = new byte[1000]; + for (int i = 0; i < 1000; i++) { + bytes[i] = (byte) i; } -} \ No newline at end of file + return bytes; + } +} diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java index 7e3d160d58..cd68d73362 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/common/client/LambdaClientFactoryTest.java @@ -1,192 +1,78 @@ package org.opensearch.dataprepper.plugins.lambda.common.client; -import static org.hamcrest.CoreMatchers.notNullValue; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.ArgumentCaptor; -import static org.mockito.ArgumentMatchers.any; import org.mockito.Mock; -import org.mockito.MockedStatic; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; -import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.plugins.metricpublisher.MicrometerMetricPublisher; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; -import software.amazon.awssdk.core.retry.RetryPolicy; -import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.LambdaAsyncClientBuilder; -import java.time.Duration; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; +import java.util.HashMap; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class LambdaClientFactoryTest { - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - - @Mock - private AwsCredentialsProvider awsCredentialsProvider; - - private Duration sdkTimeout = Duration.ofSeconds(60); - - @Test - void createLambdaAsyncClient_with_real_LambdaAsyncClient() { - try (MockedStatic mockedStaticLambdaAsyncClient = mockStatic(LambdaAsyncClient.class); - MockedStatic mockedPluginMetrics = mockStatic(PluginMetrics.class)) { - - PluginMetrics pluginMetricsMock = mock(PluginMetrics.class); - mockedPluginMetrics.when(() -> PluginMetrics.fromNames("sdk", "aws")).thenReturn(pluginMetricsMock); - - LambdaAsyncClientBuilder lambdaAsyncClientBuilder = mock(LambdaAsyncClientBuilder.class); - LambdaAsyncClient lambdaAsyncClientMock = mock(LambdaAsyncClient.class); - - mockedStaticLambdaAsyncClient.when(LambdaAsyncClient::builder).thenReturn(lambdaAsyncClientBuilder); - - when(lambdaAsyncClientBuilder.region(any(Region.class))).thenReturn(lambdaAsyncClientBuilder); - when(lambdaAsyncClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(lambdaAsyncClientBuilder); - when(lambdaAsyncClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(lambdaAsyncClientBuilder); - when(lambdaAsyncClientBuilder.build()).thenReturn(lambdaAsyncClientMock); - - Region region = Region.US_WEST_2; - when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region); - String stsRoleArn = UUID.randomUUID().toString(); - String stsExternalId = UUID.randomUUID().toString(); - Map stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); - when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn); - when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(stsExternalId); - when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides); - when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(awsCredentialsProvider); - - // Act - int maxConnectionRetries = 3; - LambdaAsyncClient lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient( - awsAuthenticationOptions, - maxConnectionRetries, - awsCredentialsSupplier, - sdkTimeout - ); - - // Verify - assertThat(lambdaAsyncClient, notNullValue()); - verify(lambdaAsyncClientBuilder).region(region); - verify(lambdaAsyncClientBuilder).credentialsProvider(awsCredentialsProvider); - // Capture and verify ClientOverrideConfiguration - ArgumentCaptor configCaptor = ArgumentCaptor.forClass(ClientOverrideConfiguration.class); - verify(lambdaAsyncClientBuilder).overrideConfiguration(configCaptor.capture()); - ClientOverrideConfiguration config = configCaptor.getValue(); - - assertThat(config.apiCallTimeout(), equalTo(Optional.of(sdkTimeout))); - - // Verify RetryPolicy - assertThat(config.retryPolicy().isPresent(), equalTo(true)); - RetryPolicy retryPolicy = config.retryPolicy().get(); - assertThat(retryPolicy.numRetries(), equalTo(maxConnectionRetries)); - - // Verify MetricPublisher - assertThat(config.metricPublishers(), notNullValue()); - assertThat(config.metricPublishers().size(), equalTo(1)); - MetricPublisher metricPublisher = config.metricPublishers().get(0); - assertThat(metricPublisher, instanceOf(MicrometerMetricPublisher.class)); - - // Verify that awsCredentialsSupplier.getProvider was called with correct AwsCredentialsOptions - ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); - verify(awsCredentialsSupplier).getProvider(optionsCaptor.capture()); - AwsCredentialsOptions credentialsOptions = optionsCaptor.getValue(); - assertThat(credentialsOptions.getRegion(), equalTo(region)); - assertThat(credentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); - assertThat(credentialsOptions.getStsExternalId(), equalTo(stsExternalId)); - assertThat(credentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides)); - } - } - - @Test - void createAsyncLambdaClient_with_correct_configuration() { - try (MockedStatic mockedStaticLambdaAsyncClient = mockStatic(LambdaAsyncClient.class); - MockedStatic mockedPluginMetrics = mockStatic(PluginMetrics.class)) { - - PluginMetrics pluginMetricsMock = mock(PluginMetrics.class); - mockedPluginMetrics.when(() -> PluginMetrics.fromNames("sdk", "aws")).thenReturn(pluginMetricsMock); - - LambdaAsyncClientBuilder lambdaAsyncClientBuilder = mock(LambdaAsyncClientBuilder.class); - LambdaAsyncClient lambdaAsyncClientMock = mock(LambdaAsyncClient.class); - - mockedStaticLambdaAsyncClient.when(LambdaAsyncClient::builder).thenReturn(lambdaAsyncClientBuilder); - - when(lambdaAsyncClientBuilder.region(any(Region.class))).thenReturn(lambdaAsyncClientBuilder); - when(lambdaAsyncClientBuilder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(lambdaAsyncClientBuilder); - when(lambdaAsyncClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(lambdaAsyncClientBuilder); - when(lambdaAsyncClientBuilder.build()).thenReturn(lambdaAsyncClientMock); - - Region region = Region.US_WEST_2; - when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region); - String stsRoleArn = UUID.randomUUID().toString(); - String stsExternalId = UUID.randomUUID().toString(); - Map stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); - when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn); - when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(stsExternalId); - when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides); - when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(awsCredentialsProvider); - - // Act - int maxConnectionRetries = 3; - LambdaAsyncClient lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient( - awsAuthenticationOptions, - maxConnectionRetries, - awsCredentialsSupplier, - sdkTimeout - ); - - // Verify - assertThat(lambdaAsyncClient, notNullValue()); - - // Verify builder methods - verify(lambdaAsyncClientBuilder).region(region); - verify(lambdaAsyncClientBuilder).credentialsProvider(awsCredentialsProvider); - - // Capture and verify ClientOverrideConfiguration - ArgumentCaptor configCaptor = ArgumentCaptor.forClass(ClientOverrideConfiguration.class); - verify(lambdaAsyncClientBuilder).overrideConfiguration(configCaptor.capture()); - ClientOverrideConfiguration config = configCaptor.getValue(); - - // Verify apiCallTimeout - assertThat(config.apiCallTimeout(), equalTo(Optional.of(sdkTimeout))); - - // Verify RetryPolicy - assertThat(config.retryPolicy().isPresent(), equalTo(true)); - RetryPolicy retryPolicy = config.retryPolicy().get(); - assertThat(retryPolicy.numRetries(), equalTo(maxConnectionRetries)); - - // Verify MetricPublisher - assertThat(config.metricPublishers(), notNullValue()); - assertThat(config.metricPublishers().size(), equalTo(1)); - MetricPublisher metricPublisher = config.metricPublishers().get(0); - assertThat(metricPublisher, instanceOf(MicrometerMetricPublisher.class)); - // Verify that awsCredentialsSupplier.getProvider was called with correct AwsCredentialsOptions - ArgumentCaptor optionsCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); - verify(awsCredentialsSupplier).getProvider(optionsCaptor.capture()); - AwsCredentialsOptions credentialsOptions = optionsCaptor.getValue(); - assertThat(credentialsOptions.getRegion(), equalTo(region)); - assertThat(credentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); - assertThat(credentialsOptions.getStsExternalId(), equalTo(stsExternalId)); - assertThat(credentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides)); - } - } + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private AwsCredentialsProvider awsCredentialsProvider; + + @BeforeEach + void setUp() { + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_WEST_2); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn("arn:aws:iam::123456789012:role/example-role"); + when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn("externalId"); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(new HashMap<>()); + + when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(awsCredentialsProvider); + } + + @Test + void testCreateAsyncLambdaClient() { + ClientOptions clientOptions = new ClientOptions(); + + LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient( + awsAuthenticationOptions, + awsCredentialsSupplier, + clientOptions + ); + + assertNotNull(client); + assertEquals(Region.US_WEST_2, client.serviceClientConfiguration().region()); + } + @Test + void testCreateAsyncLambdaClientOverrideConfiguration() { + ClientOptions clientOptions = new ClientOptions(); + + LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient( + awsAuthenticationOptions, + awsCredentialsSupplier, + clientOptions + ); + + assertNotNull(client); + ClientOverrideConfiguration overrideConfig = client.serviceClientConfiguration().overrideConfiguration(); + + assertEquals(clientOptions.getApiCallTimeout(), overrideConfig.apiCallTimeout().get()); + assertNotNull(overrideConfig.retryPolicy()); + assertNotNull(overrideConfig.metricPublishers()); + assertFalse(overrideConfig.metricPublishers().isEmpty()); + } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java index b5a4a088e5..691e72ee5c 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/AggregateResponseEventHandlingStrategyTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.dataprepper.plugins.lambda.processor; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/InvocationTypeTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/InvocationTypeTest.java index ff3e95e705..ace2145db7 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/InvocationTypeTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/InvocationTypeTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.dataprepper.plugins.lambda.processor; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfigTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfigTest.java index 9254a4d9dc..95e43361c6 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfigTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfigTest.java @@ -4,37 +4,93 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import software.amazon.awssdk.regions.Region; import java.time.Duration; +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; public class LambdaProcessorConfigTest { - public static final int DEFAULT_MAX_RETRIES = 3; - public static final Duration DEFAULT_SDK_TIMEOUT = Duration.ofSeconds(60); - private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); + private ObjectMapper objectMapper; - @Test - void lambda_processor_default_max_connection_retries_test() { - assertThat(new LambdaProcessorConfig().getMaxConnectionRetries(), equalTo(DEFAULT_MAX_RETRIES)); + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); + objectMapper.registerModule(new JavaTimeModule()); } @Test - void lambda_processor_default_sdk_timeout_test() { - assertThat(new LambdaProcessorConfig().getConnectionTimeout(), equalTo(DEFAULT_SDK_TIMEOUT)); + void test_defaults() { + final LambdaProcessorConfig lambdaProcessorConfig = new LambdaProcessorConfig(); + assertThat(lambdaProcessorConfig.getTagsOnFailure(), equalTo(List.of())); + assertThat(lambdaProcessorConfig.getWhenCondition(), equalTo(null)); + assertThat(lambdaProcessorConfig.getResponseEventsMatch(), equalTo(false)); + + // Test default client options + assertThat(lambdaProcessorConfig.getClientOptions(), notNullValue()); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxConnectionRetries(), equalTo(ClientOptions.DEFAULT_CONNECTION_RETRIES)); + assertThat(lambdaProcessorConfig.getClientOptions().getApiCallTimeout(), equalTo(ClientOptions.DEFAULT_API_TIMEOUT)); + assertThat(lambdaProcessorConfig.getClientOptions().getConnectionTimeout(), equalTo(ClientOptions.DEFAULT_CONNECTION_TIMEOUT)); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxConcurrency(), equalTo(ClientOptions.DEFAULT_MAXIMUM_CONCURRENCY)); + assertThat(lambdaProcessorConfig.getClientOptions().getBaseDelay(), equalTo(ClientOptions.DEFAULT_BASE_DELAY)); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxBackoff(), equalTo(ClientOptions.DEFAULT_MAX_BACKOFF)); } @Test - public void testAwsAuthenticationOptionsNotNull() throws JsonProcessingException { - final String config = " function_name: test_function\n" + " aws:\n" + " region: ap-south-1\n" + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + " sts_header_overrides: {\"test\":\"test\"}\n" + " max_retries: 10\n"; + public void testAwsAuthenticationOptionsAndClientOptions() throws JsonProcessingException { + final String config = "function_name: test_function\n" + + "aws:\n" + + " region: ap-south-1\n" + + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + + " sts_header_overrides: {\"test\":\"test\"}\n" + + "client:\n" + + " max_retries: 10\n" + + " api_call_timeout: 120\n" + + " connection_timeout: 30\n" + + " max_concurrency: 150\n" + + " base_delay: 0.2\n" + + " max_backoff: 30\n"; final LambdaProcessorConfig lambdaProcessorConfig = objectMapper.readValue(config, LambdaProcessorConfig.class); - assertThat(lambdaProcessorConfig.getMaxConnectionRetries(), equalTo(10)); + assertThat(lambdaProcessorConfig.getAwsAuthenticationOptions(), notNullValue()); assertThat(lambdaProcessorConfig.getAwsAuthenticationOptions().getAwsRegion(), equalTo(Region.AP_SOUTH_1)); assertThat(lambdaProcessorConfig.getAwsAuthenticationOptions().getAwsStsRoleArn(), equalTo("arn:aws:iam::524239988912:role/app-test")); assertThat(lambdaProcessorConfig.getAwsAuthenticationOptions().getAwsStsHeaderOverrides().get("test"), equalTo("test")); + + assertThat(lambdaProcessorConfig.getClientOptions(), notNullValue()); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxConnectionRetries(), equalTo(10)); + assertThat(lambdaProcessorConfig.getClientOptions().getApiCallTimeout(), equalTo(Duration.ofSeconds(120))); + assertThat(lambdaProcessorConfig.getClientOptions().getConnectionTimeout(), equalTo(Duration.ofSeconds(30))); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxConcurrency(), equalTo(150)); + assertThat(lambdaProcessorConfig.getClientOptions().getBaseDelay(), equalTo(Duration.ofMillis(200))); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxBackoff(), equalTo(Duration.ofSeconds(30))); + } + + @Test + public void testPartialClientOptions() throws JsonProcessingException { + final String config = "function_name: test_function\n" + + "aws:\n" + + " region: us-west-2\n" + + "client:\n" + + " max_retries: 5\n" + + " connection_timeout: 45\n"; + final LambdaProcessorConfig lambdaProcessorConfig = objectMapper.readValue(config, LambdaProcessorConfig.class); + + assertThat(lambdaProcessorConfig.getClientOptions(), notNullValue()); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxConnectionRetries(), equalTo(5)); + assertThat(lambdaProcessorConfig.getClientOptions().getConnectionTimeout(), equalTo(Duration.ofSeconds(45))); + // Assert defaults for unspecified options + assertThat(lambdaProcessorConfig.getClientOptions().getApiCallTimeout(), equalTo(ClientOptions.DEFAULT_API_TIMEOUT)); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxConcurrency(), equalTo(ClientOptions.DEFAULT_MAXIMUM_CONCURRENCY)); + assertThat(lambdaProcessorConfig.getClientOptions().getBaseDelay(), equalTo(ClientOptions.DEFAULT_BASE_DELAY)); + assertThat(lambdaProcessorConfig.getClientOptions().getMaxBackoff(), equalTo(ClientOptions.DEFAULT_MAX_BACKOFF)); } } \ No newline at end of file diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java index ced8020b7c..15749b853e 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorTest.java @@ -5,18 +5,15 @@ package org.opensearch.dataprepper.plugins.lambda.processor; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Timer; import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyDouble; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import org.mockito.Captor; -import org.mockito.Mock; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -24,6 +21,30 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA; +import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA; + + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; @@ -43,34 +64,18 @@ import org.opensearch.dataprepper.model.types.ByteCount; import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED; -import static org.opensearch.dataprepper.plugins.lambda.processor.LambdaProcessor.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import java.io.InputStream; -import java.lang.reflect.Field; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Consumer; - @MockitoSettings(strictness = Strictness.LENIENT) public class LambdaProcessorTest { - // Mock dependencies @Mock private AwsAuthenticationOptions awsAuthenticationOptions; @@ -105,8 +110,13 @@ public class LambdaProcessorTest { @Mock private Counter numberOfRecordsSuccessCounter; + @Mock + private Counter numberOfRequestsSuccessCounter; + @Mock private Counter numberOfRecordsFailedCounter; + @Mock + private Counter numberOfRequestsFailedCounter; @Mock private InvokeResponse invokeResponse; @@ -125,28 +135,34 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); // Mock PluginMetrics counters and timers - when(pluginMetrics.counter(eq(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS))).thenReturn(numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(eq(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED))).thenReturn(numberOfRecordsFailedCounter); + when(pluginMetrics.counter(eq(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS))).thenReturn( + numberOfRecordsSuccessCounter); + when(pluginMetrics.counter(eq(NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED))).thenReturn( + numberOfRecordsFailedCounter); + when(pluginMetrics.counter(eq(NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA))).thenReturn( + numberOfRecordsSuccessCounter); + when(pluginMetrics.counter(eq(NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA))).thenReturn( + numberOfRecordsFailedCounter); when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); - when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenAnswer(invocation -> invocation.getArgument(1)); + when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenAnswer( + invocation -> invocation.getArgument(1)); - // Mock LambdaProcessorConfig + ClientOptions clientOptions = new ClientOptions(); + when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions); when(lambdaProcessorConfig.getFunctionName()).thenReturn("test-function"); - when(lambdaProcessorConfig.getWhenCondition()).thenReturn(null); + when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); when(lambdaProcessorConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE); + BatchOptions batchOptions = mock(BatchOptions.class); + ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); + when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); + when(lambdaProcessorConfig.getWhenCondition()).thenReturn(null); when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false); - when(lambdaProcessorConfig.getMaxConnectionRetries()).thenReturn(3); - when(lambdaProcessorConfig.getConnectionTimeout()).thenReturn(Duration.ofSeconds(5)); // Mock AWS Authentication Options - when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn("testRole"); // Mock BatchOptions and ThresholdOptions - BatchOptions batchOptions = mock(BatchOptions.class); - ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); - when(lambdaProcessorConfig.getBatchOptions()).thenReturn(batchOptions); when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); when(thresholdOptions.getEventCount()).thenReturn(10); when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("6mb")); @@ -158,26 +174,20 @@ public void setUp() throws Exception { PluginSetting responseCodecPluginSetting; if (responseCodecConfig == null) { - // Default to JsonInputCodec with default settings - responseCodecPluginSetting = new PluginSetting("json", Collections.emptyMap()); + // Default to JsonInputCodec with default settings + responseCodecPluginSetting = new PluginSetting("json", Collections.emptyMap()); } else { - responseCodecPluginSetting = new PluginSetting(responseCodecConfig.getPluginName(), responseCodecConfig.getPluginSettings()); + responseCodecPluginSetting = new PluginSetting(responseCodecConfig.getPluginName(), + responseCodecConfig.getPluginSettings()); } // Mock PluginFactory to return the mocked responseCodec - when(pluginFactory.loadPlugin(eq(InputCodec.class), any(PluginSetting.class))).thenReturn(responseCodec); + when(pluginFactory.loadPlugin(eq(InputCodec.class), any(PluginSetting.class))).thenReturn( + responseCodec); // Instantiate the LambdaProcessor manually - lambdaProcessor = new LambdaProcessor(pluginFactory, pluginMetrics, lambdaProcessorConfig, awsCredentialsSupplier, expressionEvaluator); - - // Inject mocks into the LambdaProcessor using reflection - populatePrivateFields(); - - // Mock LambdaCommonHandler behavior - when(lambdaCommonHandler.createBuffer(any(BufferFactory.class))).thenReturn(bufferMock); - - // Mock Buffer behavior for flushToLambda - when(bufferMock.flushToLambda(anyString())).thenReturn(CompletableFuture.completedFuture(invokeResponse)); + lambdaProcessor = new LambdaProcessor(pluginFactory, pluginMetrics, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); // Mock InvokeResponse when(invokeResponse.payload()).thenReturn(SdkBytes.fromUtf8String("[{\"key\":\"value\"}]")); @@ -188,12 +198,10 @@ public void setUp() throws Exception { setPrivateField(lambdaProcessor, "lambdaAsyncClient", lambdaAsyncClientMock); // Mock the invoke method to return a completed future - CompletableFuture invokeFuture = CompletableFuture.completedFuture(invokeResponse); + CompletableFuture invokeFuture = CompletableFuture.completedFuture( + invokeResponse); when(lambdaAsyncClientMock.invoke(any(InvokeRequest.class))).thenReturn(invokeFuture); - // Mock the checkStatusCode method - when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); - // Mock Response Codec parse method doNothing().when(responseCodec).parse(any(InputStream.class), any(Consumer.class)); @@ -202,19 +210,53 @@ public void setUp() throws Exception { private void populatePrivateFields() throws Exception { List tagsOnMatchFailure = Collections.singletonList("failure_tag"); // Use reflection to set the private fields - setPrivateField(lambdaProcessor, "numberOfRecordsSuccessCounter", numberOfRecordsSuccessCounter); + setPrivateField(lambdaProcessor, "numberOfRecordsSuccessCounter", + numberOfRecordsSuccessCounter); + setPrivateField(lambdaProcessor, "numberOfRequestsSuccessCounter", + numberOfRequestsSuccessCounter); setPrivateField(lambdaProcessor, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); + setPrivateField(lambdaProcessor, "numberOfRequestsFailedCounter", numberOfRequestsFailedCounter); setPrivateField(lambdaProcessor, "tagsOnMatchFailure", tagsOnMatchFailure); setPrivateField(lambdaProcessor, "lambdaCommonHandler", lambdaCommonHandler); } // Helper method to set private fields via reflection - private void setPrivateField(Object targetObject, String fieldName, Object value) throws Exception { + private void setPrivateField(Object targetObject, String fieldName, Object value) + throws Exception { Field field = targetObject.getClass().getDeclaredField(fieldName); field.setAccessible(true); field.set(targetObject, value); } + @Test + public void testProcessorDefaults() { + // Create a new LambdaProcessorConfig with default values + LambdaProcessorConfig defaultConfig = new LambdaProcessorConfig(); + + // Test default values + assertNull(defaultConfig.getFunctionName()); + assertNull(defaultConfig.getAwsAuthenticationOptions()); + assertNull(defaultConfig.getResponseCodecConfig()); + assertEquals(InvocationType.REQUEST_RESPONSE, defaultConfig.getInvocationType()); + assertFalse(defaultConfig.getResponseEventsMatch()); + assertNull(defaultConfig.getWhenCondition()); + assertTrue(defaultConfig.getTagsOnFailure().isEmpty()); + + // Test ClientOptions defaults + ClientOptions clientOptions = defaultConfig.getClientOptions(); + assertNotNull(clientOptions); + assertEquals(ClientOptions.DEFAULT_CONNECTION_RETRIES, clientOptions.getMaxConnectionRetries()); + assertEquals(ClientOptions.DEFAULT_API_TIMEOUT, clientOptions.getApiCallTimeout()); + assertEquals(ClientOptions.DEFAULT_CONNECTION_TIMEOUT, clientOptions.getConnectionTimeout()); + assertEquals(ClientOptions.DEFAULT_MAXIMUM_CONCURRENCY, clientOptions.getMaxConcurrency()); + assertEquals(ClientOptions.DEFAULT_BASE_DELAY, clientOptions.getBaseDelay()); + assertEquals(ClientOptions.DEFAULT_MAX_BACKOFF, clientOptions.getMaxBackoff()); + + // Test BatchOptions defaults + BatchOptions batchOptions = defaultConfig.getBatchOptions(); + assertNotNull(batchOptions); + } + @Test public void testDoExecute_WithExceptionDuringProcessing() throws Exception { // Arrange @@ -222,16 +264,8 @@ public void testDoExecute_WithExceptionDuringProcessing() throws Exception { Record record = new Record<>(event); List> records = Collections.singletonList(record); - // Mock Buffer - Buffer bufferMock = mock(Buffer.class); - when(lambdaProcessor.lambdaCommonHandler.createBuffer(any(BufferFactory.class))).thenReturn(bufferMock); - when(bufferMock.getEventCount()).thenReturn(0, 1); - when(bufferMock.getRecords()).thenReturn(records); - doNothing().when(bufferMock).reset(); - - // Mock exception during flush - when(bufferMock.flushToLambda(any())).thenThrow(new RuntimeException("Test exception")); - + // make batch options null to generate exception + when(lambdaProcessorConfig.getBatchOptions()).thenReturn(null); // Act Collection> result = lambdaProcessor.doExecute(records); @@ -306,15 +340,14 @@ public void testDoExecute_WhenConditionFalse() { when(lambdaProcessorConfig.getWhenCondition()).thenReturn("some_condition"); // Instantiate the LambdaProcessor manually - lambdaProcessor = new LambdaProcessor(pluginFactory, pluginMetrics, lambdaProcessorConfig, awsCredentialsSupplier, expressionEvaluator); + lambdaProcessor = new LambdaProcessor(pluginFactory, pluginMetrics, lambdaProcessorConfig, + awsCredentialsSupplier, expressionEvaluator); // Act Collection> result = lambdaProcessor.doExecute(records); // Assert assertEquals(1, result.size(), "Result should contain one record as the condition is false."); - verify(lambdaCommonHandler, never()).createBuffer(any(BufferFactory.class)); - verify(bufferMock, never()).flushToLambda(anyString()); verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble()); verify(numberOfRecordsFailedCounter, never()).increment(anyDouble()); } @@ -335,26 +368,25 @@ public void testDoExecute_SuccessfulProcessing() throws Exception { setPrivateField(lambdaProcessor, "lambdaAsyncClient", lambdaAsyncClientMock); // Mock the invoke method to return a completed future - CompletableFuture invokeFuture = CompletableFuture.completedFuture(invokeResponse); + CompletableFuture invokeFuture = CompletableFuture.completedFuture( + invokeResponse); when(lambdaAsyncClientMock.invoke(any(InvokeRequest.class))).thenReturn(invokeFuture); - // Mock Buffer behavior when(bufferMock.getEventCount()).thenReturn(0).thenReturn(1).thenReturn(0); when(bufferMock.getRecords()).thenReturn(Collections.singletonList(record)); - doNothing().when(bufferMock).reset(); doAnswer(invocation -> { - InputStream inputStream = invocation.getArgument(0); - @SuppressWarnings("unchecked") - Consumer> consumer = invocation.getArgument(1); + InputStream inputStream = invocation.getArgument(0); + @SuppressWarnings("unchecked") + Consumer> consumer = invocation.getArgument(1); - // Simulate parsing by providing a mocked event - Event parsedEvent = mock(Event.class); - Record parsedRecord = new Record<>(parsedEvent); - consumer.accept(parsedRecord); + // Simulate parsing by providing a mocked event + Event parsedEvent = mock(Event.class); + Record parsedRecord = new Record<>(parsedEvent); + consumer.accept(parsedRecord); - return null; + return null; }).when(responseCodec).parse(any(InputStream.class), any(Consumer.class)); // Act @@ -363,25 +395,6 @@ public void testDoExecute_SuccessfulProcessing() throws Exception { // Assert assertEquals(1, result.size(), "Result should contain one record."); verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); - }; - - - @Test - public void testHandleFailure() { - // Arrange - Event event = mock(Event.class); - Buffer bufferMock = mock(Buffer.class); - List> records = List.of(new Record<>(event)); - when(bufferMock.getEventCount()).thenReturn(1); - when(bufferMock.getRecords()).thenReturn(records); - - // Act - lambdaProcessor.handleFailure(new RuntimeException("Test Exception"), bufferMock, records); - - // Assert - verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); - // Ensure failure tags are added; assuming addFailureTags is implemented correctly - // You might need to verify interactions with event metadata if it's mocked } @Test @@ -423,8 +436,7 @@ public void testConvertLambdaResponseToEvent_WithEqualEventCounts_SuccessfulProc when(bufferMock.getEventCount()).thenReturn(2); // Act - List> resultRecords = new ArrayList<>(); - lambdaProcessor.convertLambdaResponseToEvent(resultRecords, invokeResponse, bufferMock, responseCodec); + List> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock, invokeResponse); // Assert assertEquals(2, resultRecords.size(), "ResultRecords should contain two records."); @@ -485,12 +497,10 @@ public void testConvertLambdaResponseToEvent_WithUnequalEventCounts_SuccessfulPr when(bufferMock.getEventCount()).thenReturn(2); // Act - List> resultRecords = new ArrayList<>(); - lambdaProcessor.convertLambdaResponseToEvent(resultRecords, invokeResponse, bufferMock, responseCodec); - + List> resultRecords = lambdaProcessor.convertLambdaResponseToEvent(bufferMock, invokeResponse); // Assert // Verify that three records are added to the result assertEquals(3, resultRecords.size(), "ResultRecords should contain three records."); } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinalityTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinalityTest.java deleted file mode 100644 index 9179c7832c..0000000000 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/ResponseCardinalityTest.java +++ /dev/null @@ -1,25 +0,0 @@ -package org.opensearch.dataprepper.plugins.lambda.processor; - -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class ResponseCardinalityTest { - - @Test - public void testFromStringWithValidValue() { - assertEquals(ResponseCardinality.STRICT, ResponseCardinality.fromString("strict")); - assertEquals(ResponseCardinality.AGGREGATE, ResponseCardinality.fromString("aggregate")); - } - - @Test - public void testFromStringWithInvalidValue() { - assertEquals(null, ResponseCardinality.fromString("invalid-value")); - } - - @Test - public void testGetValue() { - assertEquals("strict", ResponseCardinality.STRICT.getValue()); - assertEquals("aggregate", ResponseCardinality.AGGREGATE.getValue()); - } -} - diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java index 4da3b91c5d..9b3fc4b35b 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/processor/StrictResponseEventHandlingStrategyTest.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.dataprepper.plugins.lambda.processor; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfigTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfigTest.java index 2a6dad3a69..a2476bce47 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfigTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkConfigTest.java @@ -2,29 +2,71 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ + package org.opensearch.dataprepper.plugins.lambda.sink; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; import software.amazon.awssdk.regions.Region; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import java.time.Duration; + class LambdaSinkConfigTest { - public static final int DEFAULT_MAX_RETRIES = 3; - private ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory().enable(YAMLGenerator.Feature.USE_PLATFORM_LINE_BREAKS)); + private ObjectMapper objectMapper; + @BeforeEach + void setUp() { + this.objectMapper = new ObjectMapper(new YAMLFactory()); + this.objectMapper.registerModule(new JavaTimeModule()); + } + @Test - void lambda_sink_default_max_connection_retries_test(){ - MatcherAssert.assertThat(new LambdaSinkConfig().getMaxConnectionRetries(),equalTo(DEFAULT_MAX_RETRIES)); + void test_defaults(){ + MatcherAssert.assertThat(new LambdaSinkConfig().getDlq(),equalTo(null)); + } + + @Test + void lambda_sink_pipeline_config_test_with_client_options() throws JsonProcessingException { + final String config = + "function_name: test_function\n" + + "aws:\n" + + " region: ap-south-1\n" + + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + + " sts_header_overrides: {\"test\":\"test\"}\n" + + "client:\n" + + " max_retries: 5\n" + + " api_call_timeout: PT120S\n" + + " connection_timeout: PT30S\n" + + " max_concurrency: 150\n" + + " max_backoff: PT30S\n"; + + final LambdaSinkConfig lambdaSinkConfig = objectMapper.readValue(config, LambdaSinkConfig.class); + + // Assert AWS authentication options + assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsRegion(), equalTo(Region.AP_SOUTH_1)); + assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsRoleArn(), equalTo("arn:aws:iam::524239988912:role/app-test")); + assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsHeaderOverrides().get("test"), equalTo("test")); + + // Assert ClientOptions + ClientOptions clientOptions = lambdaSinkConfig.getClientOptions(); + assertThat(clientOptions.getMaxConnectionRetries(), equalTo(5)); + assertThat(clientOptions.getApiCallTimeout(), equalTo(Duration.ofSeconds(120))); + assertThat(clientOptions.getConnectionTimeout(), equalTo(Duration.ofSeconds(30))); + assertThat(clientOptions.getMaxConcurrency(), equalTo(150)); + assertThat(clientOptions.getMaxBackoff(), equalTo(Duration.ofSeconds(30))); } + @Test void lambda_sink_pipeline_config_test() throws JsonProcessingException { final String config = @@ -33,7 +75,6 @@ void lambda_sink_pipeline_config_test() throws JsonProcessingException { " region: ap-south-1\n" + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + " sts_header_overrides: {\"test\":\"test\"}\n" + - " max_retries: 10\n" + " dlq:\n" + " s3:\n" + " bucket: test\n" + @@ -41,7 +82,6 @@ void lambda_sink_pipeline_config_test() throws JsonProcessingException { " region: ap-south-1\n" + " sts_role_arn: test-role-arn\n"; final LambdaSinkConfig lambdaSinkConfig = objectMapper.readValue(config, LambdaSinkConfig.class); - assertThat(lambdaSinkConfig.getMaxConnectionRetries(),equalTo(10)); assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsRegion(),equalTo(Region.AP_SOUTH_1)); assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsRoleArn(),equalTo("arn:aws:iam::524239988912:role/app-test")); assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsHeaderOverrides().get("test"),equalTo("test")); @@ -50,21 +90,41 @@ void lambda_sink_pipeline_config_test() throws JsonProcessingException { } @Test - void lambda_sink_pipeline_config_test_with_no_dlq() throws JsonProcessingException { + void lambda_sink_pipeline_config_test_with_no_explicit_aws_config() throws JsonProcessingException { final String config = - " function_name: test_function\n" + + " function_name: test_function\n" + " aws:\n" + " region: ap-south-1\n" + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + " sts_header_overrides: {\"test\":\"test\"}\n" + - " max_retries: 10\n"; + " dlq:\n" + + " s3:\n" + + " bucket: test\n" + + " key_path_prefix: test\n"; final LambdaSinkConfig lambdaSinkConfig = objectMapper.readValue(config, LambdaSinkConfig.class); - assertThat(lambdaSinkConfig.getMaxConnectionRetries(),equalTo(10)); + assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsRegion(),equalTo(Region.AP_SOUTH_1)); assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsRoleArn(),equalTo("arn:aws:iam::524239988912:role/app-test")); assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsHeaderOverrides().get("test"),equalTo("test")); assertThat(lambdaSinkConfig.getDlqStsRegion(),equalTo("ap-south-1")); assertThat(lambdaSinkConfig.getDlqStsRoleARN(),equalTo("arn:aws:iam::524239988912:role/app-test")); - assertThat(lambdaSinkConfig.getDlqPluginSetting().get("key"),equalTo(null)); + } + + @Test + void lambda_sink_pipeline_config_test_with_no_dlq() throws JsonProcessingException { + final String config = + " function_name: test_function\n" + + " aws:\n" + + " region: ap-south-1\n" + + " sts_role_arn: arn:aws:iam::524239988912:role/app-test\n" + + " sts_header_overrides: {\"test\":\"test\"}\n" ; + final LambdaSinkConfig lambdaSinkConfig = objectMapper.readValue(config, LambdaSinkConfig.class); + assertThat(lambdaSinkConfig.getDlq(),equalTo(null)); + assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsRegion(),equalTo(Region.AP_SOUTH_1)); + assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsRoleArn(),equalTo("arn:aws:iam::524239988912:role/app-test")); + assertThat(lambdaSinkConfig.getAwsAuthenticationOptions().getAwsStsHeaderOverrides().get("test"),equalTo("test")); + assertThat(lambdaSinkConfig.getDlqStsRegion(),equalTo(null)); + assertThat(lambdaSinkConfig.getDlqStsRoleARN(),equalTo(null)); + assertThat(lambdaSinkConfig.getDlqPluginSetting(),equalTo(null)); } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java deleted file mode 100644 index 1c7b7df53d..0000000000 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkServiceTest.java +++ /dev/null @@ -1,271 +0,0 @@ -package org.opensearch.dataprepper.plugins.lambda.sink; - -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Timer; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import org.mockito.MockitoAnnotations; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; -import org.opensearch.dataprepper.expression.ExpressionEvaluator; -import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.codec.OutputCodec; -import org.opensearch.dataprepper.model.configuration.PluginSetting; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.event.EventHandle; -import org.opensearch.dataprepper.model.event.EventMetadata; -import org.opensearch.dataprepper.model.plugin.PluginFactory; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.sink.OutputCodecContext; -import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; -import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; -import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; -import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; -import software.amazon.awssdk.services.lambda.model.InvokeResponse; - -import java.io.IOException; -import java.lang.reflect.Field; -import java.time.Duration; -import java.util.Collection; -import java.util.Collections; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicLong; - -public class LambdaSinkServiceTest { - - @Mock - private LambdaAsyncClient lambdaAsyncClient; - - @Mock - private LambdaSinkConfig lambdaSinkConfig; - - @Mock - private PluginMetrics pluginMetrics; - - @Mock - private PluginFactory pluginFactory; - - @Mock - private PluginSetting pluginSetting; - - @Mock - private OutputCodecContext codecContext; - - @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; - - @Mock - private DlqPushHandler dlqPushHandler; - - @Mock - private BufferFactory bufferFactory; - - @Mock - private ExpressionEvaluator expressionEvaluator; - - @Mock - private Counter numberOfRecordsSuccessCounter; - - @Mock - private Counter numberOfRecordsFailedCounter; - - @Mock - private Timer lambdaLatencyMetric; - - @Mock - private OutputCodec requestCodec; - - @Mock - private Buffer currentBufferPerBatch; - - @Mock - private LambdaCommonHandler lambdaCommonHandler; - - @Mock - private Event event; - - @Mock - private EventHandle eventHandle; - - @Mock - private EventMetadata eventMetadata; - - @Mock - private InvokeResponse invokeResponse; - - private LambdaSinkService lambdaSinkService; - - @BeforeEach - public void setUp() { - MockitoAnnotations.openMocks(this); - - // Mock PluginMetrics counters and timers - when(pluginMetrics.counter("lambdaSinkObjectsEventsSucceeded")).thenReturn(numberOfRecordsSuccessCounter); - when(pluginMetrics.counter("lambdaSinkObjectsEventsFailed")).thenReturn(numberOfRecordsFailedCounter); - when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); - when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenReturn(new AtomicLong()); - - // Mock lambdaSinkConfig - when(lambdaSinkConfig.getFunctionName()).thenReturn("test-function"); - when(lambdaSinkConfig.getWhenCondition()).thenReturn(null); - when(lambdaSinkConfig.getInvocationType()).thenReturn(InvocationType.EVENT); - - // Mock BatchOptions and ThresholdOptions - BatchOptions batchOptions = mock(BatchOptions.class); - ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); - when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(thresholdOptions.getEventCount()).thenReturn(10); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("1mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(1)); - - // Mock JsonOutputCodec - requestCodec = mock(JsonOutputCodec.class); - when(pluginFactory.loadPlugin(eq(OutputCodec.class), any(PluginSetting.class))).thenReturn(requestCodec); - - // Initialize bufferFactory and buffer - bufferFactory = mock(BufferFactory.class); - currentBufferPerBatch = mock(Buffer.class); - when(currentBufferPerBatch.getEventCount()).thenReturn(0); - - // Mock LambdaCommonHandler - lambdaCommonHandler = mock(LambdaCommonHandler.class); - when(lambdaCommonHandler.createBuffer(bufferFactory)).thenReturn(currentBufferPerBatch); - doNothing().when(currentBufferPerBatch).reset(); - - lambdaSinkService = new LambdaSinkService( - lambdaAsyncClient, - lambdaSinkConfig, - pluginMetrics, - pluginFactory, - pluginSetting, - codecContext, - awsCredentialsSupplier, - dlqPushHandler, - bufferFactory, - expressionEvaluator - ); - - // Set private fields - setPrivateField(lambdaSinkService, "lambdaCommonHandler", lambdaCommonHandler); - setPrivateField(lambdaSinkService, "requestCodec", requestCodec); - setPrivateField(lambdaSinkService, "currentBufferPerBatch", currentBufferPerBatch); - } - - // Helper method to set private fields via reflection - private void setPrivateField(Object targetObject, String fieldName, Object value) { - try { - Field field = targetObject.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - field.set(targetObject, value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Test - public void testOutput_SuccessfulProcessing() throws Exception { - Event event = mock(Event.class); - Record record = new Record<>(event); - Collection> records = Collections.singletonList(record); - - when(expressionEvaluator.evaluateConditional(anyString(), eq(event))).thenReturn(true); - when(lambdaSinkConfig.getWhenCondition()).thenReturn(null); - when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); - doNothing().when(requestCodec).start(any(), eq(event), any()); - doNothing().when(requestCodec).writeEvent(eq(event), any()); - doNothing().when(currentBufferPerBatch).addRecord(eq(record)); - when(currentBufferPerBatch.getEventCount()).thenReturn(1); - when(currentBufferPerBatch.getSize()).thenReturn(100L); - when(currentBufferPerBatch.getDuration()).thenReturn(Duration.ofMillis(500)); - CompletableFuture future = CompletableFuture.completedFuture(invokeResponse); - when(currentBufferPerBatch.flushToLambda(any())).thenReturn(future); - when(invokeResponse.statusCode()).thenReturn(202); - when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); - doNothing().when(lambdaLatencyMetric).record(any(Duration.class)); - - lambdaSinkService.output(records); - - verify(currentBufferPerBatch, times(1)).addRecord(eq(record)); - verify(currentBufferPerBatch, times(1)).flushToLambda(any()); - verify(lambdaCommonHandler, times(1)).checkStatusCode(eq(invokeResponse)); - verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); - } - - @Test - public void testHandleFailure_WithDlq() { - Throwable throwable = new RuntimeException("Test Exception"); - SdkBytes payload = SdkBytes.fromUtf8String("test payload"); - when(currentBufferPerBatch.getEventCount()).thenReturn(1); - when(currentBufferPerBatch.getPayload()).thenReturn(payload); - - lambdaSinkService.handleFailure(throwable, currentBufferPerBatch); - - verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); - verify(dlqPushHandler, times(1)).perform(eq(pluginSetting), any(LambdaSinkFailedDlqData.class)); - } - - @Test - public void testHandleFailure_WithoutDlq() { - setPrivateField(lambdaSinkService, "dlqPushHandler", null); - Throwable throwable = new RuntimeException("Test Exception"); - when(currentBufferPerBatch.getEventCount()).thenReturn(1); - - lambdaSinkService.handleFailure(throwable, currentBufferPerBatch); - - verify(numberOfRecordsFailedCounter, times(1)).increment(1); - verify(dlqPushHandler, never()).perform(any(), any()); - } - - @Test - public void testOutput_ExceptionDuringProcessing() throws Exception { - // Arrange - Record record = new Record<>(event); - Collection> records = Collections.singletonList(record); - - // Mock whenCondition evaluation - when(expressionEvaluator.evaluateConditional(anyString(), eq(event))).thenReturn(true); - when(lambdaSinkConfig.getWhenCondition()).thenReturn(null); - - // Mock event handling to throw exception when writeEvent is called - when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); - when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); - doNothing().when(requestCodec).start(any(), eq(event), any()); - doThrow(new IOException("Test IOException")).when(requestCodec).writeEvent(eq(event), any()); - - // Mock buffer reset - doNothing().when(currentBufferPerBatch).reset(); - - // Mock flushToLambda to prevent NullPointerException - CompletableFuture future = CompletableFuture.completedFuture(invokeResponse); - when(currentBufferPerBatch.flushToLambda(any())).thenReturn(future); - - // Act - lambdaSinkService.output(records); - - // Assert - verify(requestCodec, times(1)).start(any(), eq(event), any()); - verify(requestCodec, times(1)).writeEvent(eq(event), any()); - verify(numberOfRecordsFailedCounter, times(1)).increment(1); - } - - -} diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java new file mode 100644 index 0000000000..185e781b0b --- /dev/null +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java @@ -0,0 +1,255 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.lambda.sink; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.Collections; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.expression.ExpressionEvaluator; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.codec.OutputCodec; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventHandle; +import org.opensearch.dataprepper.model.event.EventMetadata; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.record.RecordMetadata; +import org.opensearch.dataprepper.model.sink.OutputCodecContext; +import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.model.types.ByteCount; +import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; +import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; +import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; +import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; +import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; +import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; + +public class LambdaSinkTest { + + @Mock + SinkContext sinkContext; + @Mock + private LambdaAsyncClient lambdaAsyncClient; + @Mock + private LambdaSinkConfig lambdaSinkConfig; + @Mock + private PluginMetrics pluginMetrics; + @Mock + private PluginFactory pluginFactory; + + private PluginSetting pluginSetting; + @Mock + private OutputCodecContext codecContext; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + @Mock + private DlqPushHandler dlqPushHandler; + @Mock + private ExpressionEvaluator expressionEvaluator; + @Mock + private Counter numberOfRecordsSuccessCounter; + @Mock + private Counter numberOfRecordsFailedCounter; + @Mock + private Timer lambdaLatencyMetric; + @Mock + private OutputCodec requestCodec; + @Mock + private Buffer currentBufferPerBatch; + @Mock + private Event event; + @Mock + private EventHandle eventHandle; + @Mock + private EventMetadata eventMetadata; + @Mock + private InvokeResponse invokeResponse; + + private LambdaSink lambdaSink; + + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + + public static Record getSampleRecord() { + Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); + return new Record<>(event, RecordMetadata.defaultMetadata()); + } + + @BeforeEach + public void setUp() { + MockitoAnnotations.openMocks(this); + + // Mock PluginMetrics counters and timers + when(pluginMetrics.counter("lambdaSinkObjectsEventsSucceeded")).thenReturn( + numberOfRecordsSuccessCounter); + when(pluginMetrics.counter("lambdaSinkObjectsEventsFailed")).thenReturn( + numberOfRecordsFailedCounter); + when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); + when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenReturn(new AtomicLong()); + + // Mock lambdaSinkConfig + when(lambdaSinkConfig.getFunctionName()).thenReturn("test-function"); + when(lambdaSinkConfig.getInvocationType()).thenReturn(InvocationType.EVENT); + + // Mock BatchOptions and ThresholdOptions + BatchOptions batchOptions = mock(BatchOptions.class); + ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); + when(batchOptions.getKeyName()).thenReturn("test"); + when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + when(thresholdOptions.getEventCount()).thenReturn(10); + when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("1mb")); + when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(1)); + + // Mock JsonOutputCodec + requestCodec = mock(JsonOutputCodec.class); + when(pluginFactory.loadPlugin(eq(OutputCodec.class), any(PluginSetting.class))).thenReturn( + requestCodec); + + // Initialize bufferFactory and buffer + currentBufferPerBatch = mock(Buffer.class); + when(currentBufferPerBatch.getEventCount()).thenReturn(0); + when(lambdaSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of("us-east-1")); + this.pluginSetting = new PluginSetting("aws_lambda", Collections.emptyMap()); + this.pluginSetting.setPipelineName(UUID.randomUUID().toString()); + this.awsAuthenticationOptions = new AwsAuthenticationOptions(); + + ClientOptions clientOptions = new ClientOptions(); + when(lambdaSinkConfig.getClientOptions()).thenReturn(clientOptions); + + this.lambdaSink = new LambdaSink(pluginSetting, lambdaSinkConfig, pluginFactory, sinkContext, + awsCredentialsSupplier, expressionEvaluator); + } + + /* + @Test + public void testOutput_SuccessfulProcessing() throws Exception { + Event event = mock(Event.class); + Record record = new Record<>(event); + Collection> records = Collections.singletonList(record); + + when(expressionEvaluator.evaluateConditional(anyString(), eq(event))).thenReturn(true); + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); + doNothing().when(requestCodec).start(any(), eq(event), any()); + doNothing().when(requestCodec).writeEvent(eq(event), any()); + doNothing().when(currentBufferPerBatch).addRecord(eq(record)); + when(currentBufferPerBatch.getEventCount()).thenReturn(1); + when(currentBufferPerBatch.getSize()).thenReturn(100L); + when(currentBufferPerBatch.getDuration()).thenReturn(Duration.ofMillis(500)); + CompletableFuture future = CompletableFuture.completedFuture(invokeResponse); + when(currentBufferPerBatch.flushToLambda(any())).thenReturn(future); + when(invokeResponse.statusCode()).thenReturn(202); + when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); + doNothing().when(lambdaLatencyMetric).record(any(Duration.class)); + + lambdaSinkService.output(records); + + verify(currentBufferPerBatch, times(1)).addRecord(eq(record)); + verify(currentBufferPerBatch, times(1)).flushToLambda(any()); + verify(lambdaCommonHandler, times(1)).checkStatusCode(eq(invokeResponse)); + verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); + } + + */ + + // Helper method to set private fields via reflection + private void setPrivateField(Object targetObject, String fieldName, Object value) { + try { + Field field = targetObject.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(targetObject, value); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testHandleFailure_WithDlq() { + Throwable throwable = new RuntimeException("Test Exception"); + Buffer buffer = new InMemoryBuffer(UUID.randomUUID().toString()); + buffer.addRecord(getSampleRecord()); + setPrivateField(lambdaSink, "dlqPushHandler", dlqPushHandler); + setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); + lambdaSink.handleFailure(throwable, buffer); + verify(numberOfRecordsFailedCounter, times(1)).increment(1.0); + verify(dlqPushHandler, times(1)).perform(eq(pluginSetting), any(LambdaSinkFailedDlqData.class)); + } + + @Test + public void testHandleFailure_WithoutDlq() { + Throwable throwable = new RuntimeException("Test Exception"); + Buffer buffer = new InMemoryBuffer(UUID.randomUUID().toString()); + buffer.addRecord(getSampleRecord()); + when(lambdaSinkConfig.getDlqPluginSetting()).thenReturn(null); + setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); + lambdaSink.handleFailure(throwable, buffer); + verify(numberOfRecordsFailedCounter, times(1)).increment(1); + verify(dlqPushHandler, never()).perform(any(), any()); + } + + /* + @Test + public void testOutput_ExceptionDuringProcessing() throws Exception { + // Arrange + Record record = new Record<>(event); + Collection> records = Collections.singletonList(record); + + // Mock whenCondition evaluation + when(expressionEvaluator.evaluateConditional(anyString(), eq(event))).thenReturn(true); + + // Mock event handling to throw exception when writeEvent is called + when(currentBufferPerBatch.getEventCount()).thenReturn(0).thenReturn(1); + when(lambdaCommonHandler.checkStatusCode(any())).thenReturn(true); + doNothing().when(requestCodec).start(any(), eq(event), any()); + doThrow(new IOException("Test IOException")).when(requestCodec).writeEvent(eq(event), any()); + + // Mock buffer reset + doNothing().when(currentBufferPerBatch).reset(); + + // Mock flushToLambda to prevent NullPointerException + CompletableFuture future = CompletableFuture.completedFuture(invokeResponse); + when(currentBufferPerBatch.flushToLambda(any())).thenReturn(future); + + // Act + lambdaSinkService.output(records); + + // Assert + verify(requestCodec, times(1)).start(any(), eq(event), any()); + verify(requestCodec, times(1)).writeEvent(eq(event), any()); + verify(numberOfRecordsFailedCounter, times(1)).increment(1); + } + */ + + +}