Skip to content

Commit

Permalink
Add stateful buffer for lambda sink
Browse files Browse the repository at this point in the history
Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg committed Jan 24, 2025
1 parent 9c61e03 commit 3fa74cb
Show file tree
Hide file tree
Showing 3 changed files with 383 additions and 222 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public InMemoryBuffer(String batchOptionKeyName, OutputCodecContext outputCodecC
this.outputCodecContext = outputCodecContext;
}

@Override
public void addRecord(Record<Event> record) {
records.add(record);
Event event = record.getData();
Expand All @@ -72,6 +73,7 @@ public void addRecord(Record<Event> record) {
eventCount++;
}

// Returns the buffer's live internal record list — NOT a defensive copy.
// Callers share state with this buffer; mutations are visible both ways.
// NOTE(review): looks intentional (flush paths iterate it), but confirm
// no caller mutates the returned list concurrently.
@Override
public List<Record<Event>> getRecords() {
    return records;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import org.opensearch.dataprepper.model.sink.OutputCodecContext;
import org.opensearch.dataprepper.model.sink.Sink;
import org.opensearch.dataprepper.model.sink.SinkContext;
import org.opensearch.dataprepper.model.types.ByteCount;
import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck;
import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler;
import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData;
import org.opensearch.dataprepper.model.failures.DlqObject;
Expand All @@ -38,6 +41,7 @@
import java.time.Duration;
import java.util.Collection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -75,6 +79,13 @@ public class LambdaSink extends AbstractSink<Record<Event>> {
private final OutputCodecContext outputCodecContext;
private volatile boolean sinkInitialized;
private DlqPushHandler dlqPushHandler = null;
final int maxEvents;
final long maxBytes;
final Duration maxCollectTime;

// The partial buffer that may not yet have reached threshold.
// Access must be synchronized
private Buffer statefulBuffer;

@DataPrepperPluginConstructor
public LambdaSink(final PluginSetting pluginSetting,
Expand All @@ -90,6 +101,9 @@ public LambdaSink(final PluginSetting pluginSetting,
this.lambdaSinkConfig = lambdaSinkConfig;
this.expressionEvaluator = expressionEvaluator;
this.outputCodecContext = OutputCodecContext.fromSinkContext(sinkContext);
this.maxEvents = lambdaSinkConfig.getBatchOptions().getThresholdOptions().getEventCount();
this.maxBytes = lambdaSinkConfig.getBatchOptions().getThresholdOptions().getMaximumSize().getBytes();
this.maxCollectTime = lambdaSinkConfig.getBatchOptions().getThresholdOptions().getEventCollectTimeOut();

this.numberOfRecordsSuccessCounter = pluginMetrics.counter(
NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS);
Expand Down Expand Up @@ -138,57 +152,59 @@ public void doInitialize() {
}

/**
 * Creates the initial partial (stateful) buffer and marks the sink as
 * initialized so {@code doOutput} will accept records.
 */
private void doInitializeInternal() {
    // The partial buffer accumulates records across doOutput() calls
    // until a threshold (count/size/time) is reached.
    statefulBuffer = new InMemoryBuffer(
            lambdaSinkConfig.getBatchOptions().getKeyName(),
            outputCodecContext
    );
    // Plain boolean literal; Boolean.TRUE forced a pointless unboxing
    // when assigned to the volatile primitive field.
    sinkInitialized = true;
}

/**
* @param records Records to be output
* We only flush the partial buffer if we're shutting down or if we want to
* do a time-based flush.
*/
@Override
public void doOutput(final Collection<Record<Event>> records) {
public synchronized void shutdown() {
// Flush the partial buffer if any leftover
if (statefulBuffer.getEventCount() > 0) {
flushBuffers(Collections.singletonList(statefulBuffer));
}
}

@Override
public synchronized void doOutput(final Collection<Record<Event>> records) {
if (!sinkInitialized) {
LOG.warn("LambdaSink doOutput called before initialization");
return;
}
if (records.isEmpty()) {
return;
}

Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = new HashMap<>();
try {
//Result from lambda is not currently processes.
bufferToFutureMap = LambdaCommonHandler.sendRecords(
records,
lambdaSinkConfig,
lambdaAsyncClient,
outputCodecContext);
} catch (Exception e) {
LOG.error("Exception while processing records ", e);
handleFailure(records, e, HttpURLConnection.HTTP_BAD_REQUEST);
}
// We'll collect any "full" buffers in a local list, flush them at the end
List<Buffer> fullBuffers = new ArrayList<>();

for (Map.Entry<Buffer, CompletableFuture<InvokeResponse>> entry : bufferToFutureMap.entrySet()) {
CompletableFuture<InvokeResponse> future = entry.getValue();
Buffer inputBuffer = entry.getKey();
try {
InvokeResponse response = future.join();
Duration latency = inputBuffer.stopLatencyWatch();
lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
requestPayloadMetric.record(inputBuffer.getPayloadRequestSize());
if (!isSuccess(response)) {
String errorMessage = String.format("Lambda invoke failed with status code %s error %s ",
response.statusCode(), response.payload().asUtf8String());
throw new RuntimeException(errorMessage);
}

releaseEventHandles(inputBuffer.getRecords(), true);
numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount());
numberOfRequestsSuccessCounter.increment();
if (response.payload() != null) {
responsePayloadMetric.record(response.payload().asByteArray().length);
}
// Add to the persistent buffer, check threshold
for (Record<Event> record : records) {
//statefulBuffer is either empty or partially filled(from previous run)
statefulBuffer.addRecord(record);

} catch (Exception e) {
LOG.error(NOISY, e.getMessage(), e);
handleFailure(inputBuffer.getRecords(), new RuntimeException("failed"), HttpURLConnection.HTTP_INTERNAL_ERROR);
if (isThresholdExceeded(statefulBuffer)) {
// This buffer is full
fullBuffers.add(statefulBuffer);
// Create new partial buffer
statefulBuffer = new InMemoryBuffer(
lambdaSinkConfig.getBatchOptions().getKeyName(),
outputCodecContext
);
}
}

// Flush any full buffers
if (!fullBuffers.isEmpty()) {
flushBuffers(fullBuffers);
}
}


Expand All @@ -210,7 +226,7 @@ private DlqObject createDlqObjectFromEvent(final Event event,
.build();
}

void handleFailure(Collection<Record<Event>> failedRecords, Throwable throwable, int statusCode) {
synchronized void handleFailure(Collection<Record<Event>> failedRecords, Throwable throwable, int statusCode) {
if (failedRecords.isEmpty()) {
return;
}
Expand Down Expand Up @@ -249,4 +265,65 @@ private void releaseEventHandles(Collection<Record<Event>> records, boolean succ
}
}
}

/**
 * Sends the contents of the given buffers to Lambda via a single
 * {@code sendRecords} call, then waits on each buffer's future: on success,
 * records metrics and releases event handles; on any failure, routes that
 * buffer's records to {@code handleFailure}.
 */
private synchronized void flushBuffers(final List<Buffer> buffersToFlush) {
    // Union of every buffer's records so sendRecords is invoked exactly once.
    final List<Record<Event>> allRecords = new ArrayList<>();
    for (final Buffer buffer : buffersToFlush) {
        allRecords.addAll(buffer.getRecords());
    }

    final Map<Buffer, CompletableFuture<InvokeResponse>> futuresByBuffer;
    try {
        futuresByBuffer = LambdaCommonHandler.sendRecords(
                allRecords,
                lambdaSinkConfig,
                lambdaAsyncClient,
                outputCodecContext
        );
    } catch (Exception e) {
        // Whole send failed — every record is treated as failed.
        LOG.error(NOISY, "Error sending buffers to Lambda", e);
        handleFailure(allRecords, e, HttpURLConnection.HTTP_INTERNAL_ERROR);
        return;
    }

    for (final Map.Entry<Buffer, CompletableFuture<InvokeResponse>> entry : futuresByBuffer.entrySet()) {
        final Buffer buffer = entry.getKey();
        try {
            final InvokeResponse response = entry.getValue().join();

            // Per-buffer invocation metrics.
            final Duration latency = buffer.stopLatencyWatch();
            lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
            requestPayloadMetric.record(buffer.getPayloadRequestSize());

            if (!isSuccess(response)) {
                // Non-success status is handled through the same catch path below.
                throw new RuntimeException(String.format(
                        "Lambda invoke failed with code %d, error: %s",
                        response.statusCode(),
                        response.payload() != null ? response.payload().asUtf8String() : "No payload"));
            }

            releaseEventHandles(buffer.getRecords(), true);
            numberOfRecordsSuccessCounter.increment(buffer.getEventCount());
            numberOfRequestsSuccessCounter.increment();
            if (response.payload() != null) {
                responsePayloadMetric.record(response.payload().asByteArray().length);
            }
        } catch (Exception ex) {
            LOG.error(NOISY, "Error handling future response from Lambda", ex);
            handleFailure(buffer.getRecords(), ex, HttpURLConnection.HTTP_INTERNAL_ERROR);
        }
    }
}

/**
 * Returns true once the buffer has reached any configured flush limit:
 * event count, accumulated byte size, or the collect-time window.
 */
private boolean isThresholdExceeded(Buffer buffer) {
    final ByteCount sizeLimit = ByteCount.ofBytes(maxBytes);
    return ThresholdCheck.checkThresholdExceed(buffer, maxEvents, sizeLimit, maxCollectTime);
}
}
Loading

0 comments on commit 3fa74cb

Please sign in to comment.