Skip to content

Commit

Permalink
Address thread safety for lambda processor and lambda sink (#5181)
Browse files Browse the repository at this point in the history
* Address thread safety for lambda processor and additional fixes

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address comments

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address comments

Signed-off-by: Srikanth Govindarajan <[email protected]>

---------

Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg authored Nov 13, 2024
1 parent d096aae commit 059e1c5
Show file tree
Hide file tree
Showing 9 changed files with 431 additions and 377 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.opensearch.dataprepper.plugins.lambda.common;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.BufferFactory;
import org.slf4j.Logger;
Expand All @@ -17,26 +16,22 @@ public class LambdaCommonHandler {
private final String functionName;
private final String invocationType;
BufferFactory bufferFactory;
private final ObjectMapper objectMapper = new ObjectMapper();

public LambdaCommonHandler(
final Logger log,
final LambdaAsyncClient lambdaAsyncClient,
final String functionName,
final String invocationType,
BufferFactory bufferFactory){
final String invocationType){
this.LOG = log;
this.lambdaAsyncClient = lambdaAsyncClient;
this.functionName = functionName;
this.invocationType = invocationType;
this.bufferFactory = bufferFactory;
}

public Buffer createBuffer(Buffer currentBuffer) {
public Buffer createBuffer(BufferFactory bufferFactory) {
try {
LOG.debug("Resetting buffer");
currentBuffer = bufferFactory.getBuffer(lambdaAsyncClient, functionName, invocationType);
return currentBuffer;
return bufferFactory.getBuffer(lambdaAsyncClient, functionName, invocationType);
} catch (IOException e) {
throw new RuntimeException("Failed to reset buffer", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ public interface Buffer {

public Long getPayloadRequestSize();

public Long getPayloadResponseSize();

public Duration stopLatencyWatch();


void reset();

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public class InMemoryBuffer implements Buffer {
private StopWatch lambdaLatencyWatch;
private long payloadRequestSize;
private long payloadResponseSize;
private boolean isCodecStarted;
private final List<Record<Event>> records;


Expand All @@ -53,7 +52,6 @@ public InMemoryBuffer(LambdaAsyncClient lambdaAsyncClient, String functionName,
bufferWatch.start();
lambdaLatencyWatch = new StopWatch();
eventCount = 0;
isCodecStarted = false;
payloadRequestSize = 0;
payloadResponseSize = 0;
}
Expand Down Expand Up @@ -86,7 +84,6 @@ public void reset() {
eventCount = 0;
bufferWatch.reset();
lambdaLatencyWatch.reset();
isCodecStarted = false;
payloadRequestSize = 0;
payloadResponseSize = 0;
}
Expand Down Expand Up @@ -160,13 +157,10 @@ public Long getPayloadRequestSize() {
return payloadRequestSize;
}

public Long getPayloadResponseSize() {
return payloadResponseSize;
}

public StopWatch getBufferWatch() {return bufferWatch;}

public StopWatch getLambdaLatencyWatch(){return lambdaLatencyWatch;}


}

Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

public class AggregateResponseEventHandlingStrategy implements ResponseEventHandlingStrategy {

private static final Logger LOG = LoggerFactory.getLogger(AggregateResponseEventHandlingStrategy.class);

@Override
public void handleEvents(List<Event> parsedEvents, List<Record<Event>> originalRecords, List<Record<Event>> resultRecords, Buffer flushedBuffer) {
public void handleEvents(List<Event> parsedEvents, List<Record<Event>> originalRecords,
List<Record<Event>> resultRecords, Buffer flushedBuffer) {

Event originalEvent = originalRecords.get(0).getData();
DefaultEventHandle eventHandle = (DefaultEventHandle) originalEvent.getEventHandle();
AcknowledgementSet originalAcknowledgementSet = eventHandle.getAcknowledgementSet();
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import io.micrometer.core.instrument.Timer;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.expression.ExpressionEvaluator;
import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.codec.OutputCodec;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
Expand Down Expand Up @@ -37,6 +38,7 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -67,7 +69,6 @@ public class LambdaSinkService {
private final String invocationType;
private final BufferFactory bufferFactory;
private final DlqPushHandler dlqPushHandler;
private final List<Event> events;
private final BatchOptions batchOptions;
private int maxEvents = 0;
private ByteCount maxBytes = null;
Expand Down Expand Up @@ -107,14 +108,13 @@ public LambdaSinkService(final LambdaAsyncClient lambdaAsyncClient, final Lambda
maxBytes = batchOptions.getThresholdOptions().getMaximumSize();
maxCollectionDuration = batchOptions.getThresholdOptions().getEventCollectTimeOut();
invocationType = lambdaSinkConfig.getInvocationType().getAwsLambdaValue();
events = new ArrayList();
futureList = new ArrayList<>();
futureList = Collections.synchronizedList(new ArrayList<>());

this.bufferFactory = bufferFactory;

LOG.info("LambdaFunctionName:{} , invocationType:{}", functionName, invocationType);
// Initialize LambdaCommonHandler
lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, functionName, invocationType, bufferFactory);
lambdaCommonHandler = new LambdaCommonHandler(LOG, lambdaAsyncClient, functionName, invocationType);
}


Expand All @@ -123,14 +123,16 @@ public void output(Collection<Record<Event>> records) {
return;
}

List<Record<Event>> resultRecords = new ArrayList<>();
//Result from lambda is not currently processes.
List<Record<Event>> resultRecords = null;

reentrantLock.lock();
try {
for (Record<Event> record : records) {
final Event event = record.getData();

if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, event)) {
resultRecords.add(record);
releaseEventHandle(event, true);
continue;
}
try {
Expand Down Expand Up @@ -167,6 +169,12 @@ public void output(Collection<Record<Event>> records) {
// Wait for all futures to complete
lambdaCommonHandler.waitForFutures(futureList);

// Release event handles for records not sent to Lambda
for (Record<Event> record : records) {
Event event = record.getData();
releaseEventHandle(event, true);
}

}

void flushToLambdaIfNeeded(List<Record<Event>> resultRecords, boolean forceFlush) {
Expand All @@ -182,22 +190,13 @@ void flushToLambdaIfNeeded(List<Record<Event>> resultRecords, boolean forceFlush

// Handle future
CompletableFuture<Void> processingFuture = future.thenAccept(response -> {
// Success handler
boolean success = lambdaCommonHandler.checkStatusCode(response);
if(success) {
LOG.info("Successfully flushed {} events", eventCount);
numberOfRecordsSuccessCounter.increment(eventCount);
requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
Duration latency = flushedBuffer.stopLatencyWatch();
lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
} else {
// Non-2xx status code treated as failure
handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()),
flushedBuffer);
}
handleLambdaResponse(flushedBuffer, eventCount, response);
}).exceptionally(throwable -> {
// Failure handler
LOG.error("Exception occurred while invoking Lambda. Function: {}, event in batch:{} | Exception: ", functionName, currentBufferPerBatch.getRecords().get(0), throwable);
List<Record<Event>> bufferRecords = flushedBuffer.getRecords();
Record<Event> eventRecord = bufferRecords.isEmpty() ? null : bufferRecords.get(0);
LOG.error(NOISY, "Exception occurred while invoking Lambda. Function: {} , Event: {}",
functionName, eventRecord == null? "null":eventRecord.getData(), throwable);
requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
responsePayloadMetric.set(0);
Duration latency = flushedBuffer.stopLatencyWatch();
Expand All @@ -209,28 +208,30 @@ void flushToLambdaIfNeeded(List<Record<Event>> resultRecords, boolean forceFlush
futureList.add(processingFuture);

// Create a new buffer for the next batch
currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch);
currentBufferPerBatch = lambdaCommonHandler.createBuffer(bufferFactory);
} catch (IOException e) {
LOG.error("Exception while flushing to lambda", e);
handleFailure(e, currentBufferPerBatch);
currentBufferPerBatch = lambdaCommonHandler.createBuffer(currentBufferPerBatch);
currentBufferPerBatch = lambdaCommonHandler.createBuffer(bufferFactory);
}
}
}

void handleFailure(Throwable throwable, Buffer flushedBuffer) {
if (currentBufferPerBatch.getEventCount() > 0) {
numberOfRecordsFailedCounter.increment(currentBufferPerBatch.getEventCount());
} else {
numberOfRecordsFailedCounter.increment();
}
try {
if (flushedBuffer.getEventCount() > 0) {
numberOfRecordsFailedCounter.increment(flushedBuffer.getEventCount());
}

SdkBytes payload = currentBufferPerBatch.getPayload();
if (dlqPushHandler != null) {
dlqPushHandler.perform(pluginSetting, new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0));
releaseEventHandlesPerBatch(true, flushedBuffer);
} else {
releaseEventHandlesPerBatch(false, flushedBuffer);
SdkBytes payload = flushedBuffer.getPayload();
if (dlqPushHandler != null) {
dlqPushHandler.perform(pluginSetting, new LambdaSinkFailedDlqData(payload, throwable.getMessage(), 0));
releaseEventHandlesPerBatch(true, flushedBuffer);
} else {
releaseEventHandlesPerBatch(false, flushedBuffer);
}
} catch (Exception ex){
LOG.error("Exception occured during error handling");
}
}

Expand All @@ -241,11 +242,45 @@ private void releaseEventHandlesPerBatch(boolean success, Buffer flushedBuffer)
List<Record<Event>> records = flushedBuffer.getRecords();
for (Record<Event> record : records) {
Event event = record.getData();
releaseEventHandle(event, success);
}
}

/**
* Releases the event handle based on processing success.
*
* @param event the event to release
* @param success indicates if processing was successful
*/
private void releaseEventHandle(Event event, boolean success) {
if (event != null) {
EventHandle eventHandle = event.getEventHandle();
if (eventHandle != null) {
eventHandle.release(success);
}
}
}

private void handleLambdaResponse(Buffer flushedBuffer, int eventCount, InvokeResponse response) {
boolean success = lambdaCommonHandler.checkStatusCode(response);
if (success) {
LOG.info("Successfully flushed {} events", eventCount);
SdkBytes payload = response.payload();
if (payload == null || payload.asByteArray() == null || payload.asByteArray().length == 0) {
responsePayloadMetric.set(0);
} else {
responsePayloadMetric.set(payload.asByteArray().length);
}
//metrics
requestPayloadMetric.set(flushedBuffer.getPayloadRequestSize());
numberOfRecordsSuccessCounter.increment(eventCount);
Duration latency = flushedBuffer.stopLatencyWatch();
lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
}
else {
// Non-2xx status code treated as failure
handleFailure(new RuntimeException("Non-success Lambda status code: " + response.statusCode()), flushedBuffer);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class LambdaCommonHandlerTest {
@BeforeEach
public void setUp() {
MockitoAnnotations.openMocks(this);
lambdaCommonHandler = new LambdaCommonHandler(mockLogger, mockLambdaAsyncClient, functionName, invocationType, mockBufferFactory);
lambdaCommonHandler = new LambdaCommonHandler(mockLogger, mockLambdaAsyncClient, functionName, invocationType);
}

@Test
Expand All @@ -61,7 +61,7 @@ public void testCreateBuffer_success() throws IOException {
when(mockBufferFactory.getBuffer(any(), anyString(), any())).thenReturn(mockBuffer);

// Act
Buffer result = lambdaCommonHandler.createBuffer(mockBuffer);
Buffer result = lambdaCommonHandler.createBuffer(mockBufferFactory);

// Assert
verify(mockBufferFactory, times(1)).getBuffer(mockLambdaAsyncClient, functionName, invocationType);
Expand All @@ -76,7 +76,7 @@ public void testCreateBuffer_throwsException() throws IOException {

// Act & Assert
try {
lambdaCommonHandler.createBuffer(mockBuffer);
lambdaCommonHandler.createBuffer(mockBufferFactory);
} catch (RuntimeException e) {
assert e.getMessage().contains("Failed to reset buffer");
}
Expand Down
Loading

0 comments on commit 059e1c5

Please sign in to comment.