Skip to content

Commit

Permalink
Addressed review comments.
Browse files Browse the repository at this point in the history
Signed-off-by: Kondaka <[email protected]>
  • Loading branch information
kkondaka committed Nov 19, 2024
1 parent 8645c1e commit 482f65c
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ public class PluginMetrics {

private final String metricsPrefix;

public static PluginMetrics fromPluginSetting(final PluginSetting pluginSetting) {
public static PluginMetrics fromPluginSetting(final PluginSetting pluginSetting, final String name) {
if(pluginSetting.getPipelineName() == null) {
throw new IllegalArgumentException("PluginSetting.pipelineName must not be null");
}
return PluginMetrics.fromNames(pluginSetting.getName(), pluginSetting.getPipelineName());
return PluginMetrics.fromNames(name, pluginSetting.getPipelineName());
}

public static PluginMetrics fromPluginSetting(final PluginSetting pluginSetting) {
return fromPluginSetting(pluginSetting, pluginSetting.getName());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ protected AbstractProcessor(final PluginMetrics pluginMetrics) {
timeElapsedTimer = pluginMetrics.timer(MetricNames.TIME_ELAPSED);
}

public PluginMetrics getPluginMetrics() {
return pluginMetrics;
}

/**
* @since 1.2
* This execute function calls the {@link AbstractProcessor#doExecute(Collection)} function of the implementation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,19 @@ public void testCounterWithMetricsPrefix() {
counter.getId().getName());
}

@Test
public void testCounterWithMetricsPrefixWithCustomMetricsName() {
final String customName = PLUGIN_NAME + "_custom";
objectUnderTest = PluginMetrics.fromPluginSetting(pluginSetting, customName);

final Counter counter = objectUnderTest.counter("counter");
assertEquals(
pluginSetting.getPipelineName() + MetricNames.DELIMITER +
customName + MetricNames.DELIMITER +
"counter",
counter.getId().getName());
}

@Test
public void testCounter() {
final Counter counter = objectUnderTest.counter("counter");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,20 @@ public class LambdaProcessorSinkIT {
private AcknowledgementSet acknowledgementSet;

private LambdaProcessor createLambdaProcessor(LambdaProcessorConfig processorConfig) {
return new LambdaProcessor(pluginFactory, pluginMetrics, processorConfig, awsCredentialsSupplier, expressionEvaluator);
return new LambdaProcessor(pluginFactory, pluginSetting, processorConfig, awsCredentialsSupplier, expressionEvaluator);
}

private LambdaSink createLambdaSink(LambdaSinkConfig lambdaSinkConfig) {
return new LambdaSink(pluginSetting, lambdaSinkConfig, pluginFactory, null, awsCredentialsSupplier, expressionEvaluator);

}

private void setPrivateField(Object targetObject, String fieldName, Object value) throws Exception {
Field field = targetObject.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
field.set(targetObject, value);
}

@BeforeEach
public void setup() {
lambdaRegion = System.getProperty("tests.lambda.processor.region");
Expand Down Expand Up @@ -212,19 +218,13 @@ public void setup() {

}

private void setPrivateField(Object targetObject, String fieldName, Object value)
throws Exception {
Field field = targetObject.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
field.set(targetObject, value);
}

@ParameterizedTest
@ValueSource(ints = {11})
public void testLambdaProcessorAndLambdaSink(int numRecords) {
public void testLambdaProcessorAndLambdaSink(int numRecords) throws Exception {
when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);
lambdaProcessor = createLambdaProcessor(lambdaProcessorConfig);
setPrivateField(lambdaProcessor, "pluginMetrics", pluginMetrics);
List<Record<Event>> records = createRecords(numRecords);

Collection<Record<Event>> results = lambdaProcessor.doExecute(records);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -73,6 +74,8 @@ public class LambdaProcessorIT {
@Mock
private PluginMetrics pluginMetrics;
@Mock
private PluginSetting pluginSetting;
@Mock
private ExpressionEvaluator expressionEvaluator;
@Mock
private Counter testCounter;
Expand All @@ -81,7 +84,13 @@ public class LambdaProcessorIT {
@Mock
InvocationType invocationType;
private LambdaProcessor createObjectUnderTest(LambdaProcessorConfig processorConfig) {
return new LambdaProcessor(pluginFactory, pluginMetrics, processorConfig, awsCredentialsSupplier, expressionEvaluator);
return new LambdaProcessor(pluginFactory, pluginSetting, processorConfig, awsCredentialsSupplier, expressionEvaluator);
}

private void setPrivateField(Object targetObject, String fieldName, Object value) throws Exception {
Field field = targetObject.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
field.set(targetObject, value);
}

@BeforeEach
Expand All @@ -90,8 +99,10 @@ public void setup() {
functionName = System.getProperty("tests.lambda.processor.functionName");
role = System.getProperty("tests.lambda.processor.sts_role_arn");
pluginMetrics = mock(PluginMetrics.class);
//when(pluginMetrics.gauge(any(), any(AtomicLong.class))).thenReturn(new AtomicLong());
//testCounter = mock(Counter.class);
pluginSetting = mock(PluginSetting.class);
when(pluginSetting.getPipelineName()).thenReturn("pipeline");
when(pluginSetting.getName()).thenReturn("name");
testCounter = mock(Counter.class);
try {
lenient().doAnswer(args -> {
return null;
Expand Down Expand Up @@ -166,10 +177,11 @@ public void testRequestResponseWithMatchingEventsAggregateMode(int numRecords) {

@ParameterizedTest
@ValueSource(ints = {1000})
public void testRequestResponse_WithMatchingEvents_StrictMode_WithMultipleThreads(int numRecords) throws InterruptedException {
public void testRequestResponse_WithMatchingEvents_StrictMode_WithMultipleThreads(int numRecords) throws Exception {
when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);
lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig);
setPrivateField(lambdaProcessor, "pluginMetrics", pluginMetrics);
int numThreads = 5;
ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
CountDownLatch latch = new CountDownLatch(numThreads);
Expand All @@ -191,10 +203,11 @@ public void testRequestResponse_WithMatchingEvents_StrictMode_WithMultipleThread

@ParameterizedTest
@ValueSource(strings = {"RequestResponse", "Event"})
public void testDifferentInvocationTypes(String invocationType) {
public void testDifferentInvocationTypes(String invocationType) throws Exception {
when(this.invocationType.getAwsLambdaValue()).thenReturn(invocationType);
when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);
lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig);
setPrivateField(lambdaProcessor, "pluginMetrics", pluginMetrics);
List<Record<Event>> records = createRecords(10);
Collection<Record<Event>> results = lambdaProcessor.doExecute(records);
if (invocationType.equals("RequestResponse")) {
Expand All @@ -207,11 +220,12 @@ public void testDifferentInvocationTypes(String invocationType) {
}

@Test
public void testWithFailureTags() {
public void testWithFailureTags() throws Exception {
when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false);
when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(Collections.singletonList("lambda_failure"));
LambdaProcessor spyLambdaProcessor = spy(createObjectUnderTest(lambdaProcessorConfig));
setPrivateField(spyLambdaProcessor, "pluginMetrics", pluginMetrics);
doThrow(new RuntimeException("Simulated Lambda failure"))
.when(spyLambdaProcessor).convertLambdaResponseToEvent(any(Buffer.class), any(InvokeResponse.class));
List<Record<Event>> records = createRecords(5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import javax.management.RuntimeMBeanException;

@DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class)
public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {

Expand All @@ -71,19 +73,21 @@ public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Eve
private final Counter numberOfRequestsSuccessCounter;
private final Counter numberOfRequestsFailedCounter;
private final Timer lambdaLatencyMetric;
private final List<String> tagsOnMatchFailure;
private final List<String> tagsOnFailure;
private final LambdaAsyncClient lambdaAsyncClient;
private final DistributionSummary requestPayloadMetric;
private final DistributionSummary responsePayloadMetric;
private final ResponseEventHandlingStrategy responseStrategy;
private final JsonOutputCodecConfig jsonOutputCodecConfig;
private final PluginMetrics pluginMetrics;

@DataPrepperPluginConstructor
public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pluginMetrics,
public LambdaProcessor(final PluginFactory pluginFactory, final PluginSetting pluginSetting,
final LambdaProcessorConfig lambdaProcessorConfig,
final AwsCredentialsSupplier awsCredentialsSupplier,
final ExpressionEvaluator expressionEvaluator) {
super(pluginMetrics);
super(PluginMetrics.fromPluginSetting(pluginSetting, pluginSetting.getName()+"_processor"));
pluginMetrics = getPluginMetrics();
this.expressionEvaluator = expressionEvaluator;
this.pluginFactory = pluginFactory;
this.lambdaProcessorConfig = lambdaProcessorConfig;
Expand All @@ -99,7 +103,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginMetrics pl
this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE);
this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE);
this.whenCondition = lambdaProcessorConfig.getWhenCondition();
this.tagsOnMatchFailure = lambdaProcessorConfig.getTagsOnFailure();
this.tagsOnFailure = lambdaProcessorConfig.getTagsOnFailure();

PluginModel responseCodecConfig = lambdaProcessorConfig.getResponseCodecConfig();

Expand Down Expand Up @@ -144,8 +148,7 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
for (Record<Event> record : records) {
final Event event = record.getData();
// If the condition is false, add the event to resultRecords as-is
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition,
event)) {
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) {
resultRecords.add(record);
continue;
}
Expand All @@ -155,28 +158,33 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = LambdaCommonHandler.sendRecords(
recordsToLambda, lambdaProcessorConfig, lambdaAsyncClient,
new OutputCodecContext());

for (Map.Entry<Buffer, CompletableFuture<InvokeResponse>> entry : bufferToFutureMap.entrySet()) {
CompletableFuture<InvokeResponse> future = entry.getValue();
Buffer inputBuffer = entry.getKey();
try {
InvokeResponse response = future.join();
Duration latency = inputBuffer.stopLatencyWatch();
lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
requestPayloadMetric.record(inputBuffer.getPayloadRequestSize());
if (isSuccess(response)) {
resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response));
numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount());
numberOfRequestsSuccessCounter.increment();
resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response));
if (response.payload() != null) {
responsePayloadMetric.record(response.payload().asByteArray().length);
}
continue;
} else {
LOG.error("Lambda invoke failed with error {} ", response.statusCode());
resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
/* fall through */
}
} catch (Exception e) {
LOG.error("Exception from Lambda invocation ", e);
numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
numberOfRequestsFailedCounter.increment();
resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
/* fall through */
}
numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
numberOfRequestsFailedCounter.increment();
resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
}
return resultRecords;
}
Expand All @@ -190,53 +198,52 @@ List<Record<Event>> convertLambdaResponseToEvent(Buffer flushedBuffer,
final InvokeResponse lambdaResponse) {
InputCodec responseCodec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSetting);
List<Record<Event>> originalRecords = flushedBuffer.getRecords();
try {
List<Event> parsedEvents = new ArrayList<>();

List<Record<Event>> resultRecords = new ArrayList<>();
SdkBytes payload = lambdaResponse.payload();
// Handle null or empty payload
if (payload == null || payload.asByteArray() == null
|| payload.asByteArray().length == 0) {
LOG.warn(NOISY,
"Lambda response payload is null or empty, dropping the original events");
} else {
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
//Convert to response codec
try {
responseCodec.parse(inputStream, record -> {
List<Event> parsedEvents = new ArrayList<>();

List<Record<Event>> resultRecords = new ArrayList<>();
SdkBytes payload = lambdaResponse.payload();
// Handle null or empty payload
if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) {
LOG.warn(NOISY, "Lambda response payload is null or empty, dropping the original events");
} else {
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
//Convert to response codec
try {
responseCodec.parse(inputStream, record -> {
Event event = record.getData();
parsedEvents.add(event);
});
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
} catch (IOException ex) {
LOG.error("Error while trying to parse response from Lambda", ex);
throw new RuntimeException(ex);
}
if (parsedEvents.size() == 0) {
throw new RuntimeException("Lambda Response could not be parsed, returning original events");
}

LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " +
LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " +
"FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(),
flushedBuffer.getSize());
responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords,
flushedBuffer);
}
return resultRecords;
} catch (Exception e) {
LOG.error(NOISY, "Error converting Lambda response to Event");
addFailureTags(flushedBuffer.getRecords());
return originalRecords;
responseStrategy.handleEvents(parsedEvents, originalRecords, resultRecords, flushedBuffer);
}
return resultRecords;
}

/*
* If one event in the Buffer fails, we consider that the entire
* Batch fails and tag each event in that Batch.
*/
private List<Record<Event>> addFailureTags(List<Record<Event>> records) {
if (tagsOnFailure == null || tagsOnFailure.isEmpty()) {
return records;
}
// Add failure tags to each event in the batch
for (Record<Event> record : records) {
Event event = record.getData();
EventMetadata metadata = event.getMetadata();
if (metadata != null) {
metadata.addTags(tagsOnMatchFailure);
metadata.addTags(tagsOnFailure);
} else {
LOG.warn("Event metadata is null, cannot add failure tags.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,15 @@ public void doOutput(final Collection<Record<Event>> records) {
InvokeResponse response = future.join();
Duration latency = inputBuffer.stopLatencyWatch();
lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
requestPayloadMetric.record(inputBuffer.getPayloadRequestSize());
if (isSuccess(response)) {
releaseEventHandlesPerBatch(true, inputBuffer);
numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount());
numberOfRequestsSuccessCounter.increment();
releaseEventHandlesPerBatch(true, inputBuffer);
if (response.payload() != null) {
responsePayloadMetric.record(response.payload().asByteArray().length);
}
continue;
} else {
LOG.error("Lambda invoke failed with error {} ", response.statusCode());
handleFailure(new RuntimeException("failed"), inputBuffer);
Expand Down
Loading

0 comments on commit 482f65c

Please sign in to comment.