Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQS connector batch send #2889

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions documentation/src/main/docs/sqs/sending-aws-sqs-messages.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ explicitly specify metadata on the outgoing message:
{{ insert('sqs/outbound/SqsMessageStringProducer.java') }}
```

## Sending messages in batch

You can configure the outbound channel to send messages in batch of maximum 10 messages (AWS SQS limitation).

You can customize the size of batches, `10` being the default batch size, and the delay to wait for new messages to be added to the batch, 3000ms being the default delay:

``` java
mp.messaging.outgoing.prices.connector=smallrye-sqs
mp.messaging.outgoing.prices.queue=prices
mp.messaging.outgoing.prices.batch=true
mp.messaging.outgoing.prices.batch-size=5
mp.messaging.outgoing.prices.batch-delay=3000
```

## Serialization

When sending a `Message<T>`, the connector converts the message into a AWS SQS Message.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.smallrye.reactive.messaging.aws.sqs;

import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;

/**
* Exception thrown when a send message batch result contains an error.
*
* @see BatchResultErrorEntry
*/
public class BatchResultErrorException extends Exception {

public BatchResultErrorException(BatchResultErrorEntry entry) {
super("BatchResultError " + entry.code() + " " + entry.message() + ", senderFault = " + entry.senderFault());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
@ConnectorAttribute(name = "health-enabled", type = "boolean", direction = ConnectorAttribute.Direction.INCOMING_AND_OUTGOING, description = "Whether health reporting is enabled (default) or disabled", defaultValue = "true")

@ConnectorAttribute(name = "group.id", type = "string", direction = ConnectorAttribute.Direction.OUTGOING, description = "When set, sends messages with the specified group id")
@ConnectorAttribute(name = "batch", type = "boolean", direction = ConnectorAttribute.Direction.OUTGOING, description = "When set, sends messages in batches of maximum 10 messages", defaultValue = "false")
@ConnectorAttribute(name = "batch-size", type = "int", direction = ConnectorAttribute.Direction.OUTGOING, description = "In batch send mode, the maximum number of messages to include in batch, currently SQS maximum is 10 messages", defaultValue = "10")
@ConnectorAttribute(name = "batch-delay", type = "int", direction = ConnectorAttribute.Direction.OUTGOING, description = "In batch send mode, the maximum delay in milliseconds to wait for messages to be included in the batch", defaultValue = "3000")

@ConnectorAttribute(name = "wait-time-seconds", type = "int", direction = ConnectorAttribute.Direction.INCOMING, description = "The maximum amount of time in seconds to wait for messages to be received", defaultValue = "1")
@ConnectorAttribute(name = "max-number-of-messages", type = "int", direction = ConnectorAttribute.Direction.INCOMING, description = "The maximum number of messages to receive", defaultValue = "10")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.smallrye.reactive.messaging.aws.sqs;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -8,18 +9,25 @@
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.eclipse.microprofile.reactive.messaging.Message;

import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.reactive.messaging.OutgoingMessageMetadata;
import io.smallrye.reactive.messaging.aws.sqs.i18n.AwsSqsLogging;
import io.smallrye.reactive.messaging.health.HealthReport;
import io.smallrye.reactive.messaging.json.JsonMapping;
import io.smallrye.reactive.messaging.providers.helpers.MultiUtils;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResultEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;

public class SqsOutboundChannel {

Expand All @@ -32,17 +40,31 @@ public class SqsOutboundChannel {
private final List<Throwable> failures = new ArrayList<>();
private final boolean healthEnabled;
private final String groupId;
private final boolean batch;
private final Duration batchDelay;
private final int batchSize;

public SqsOutboundChannel(SqsConnectorOutgoingConfiguration conf, SqsManager sqsManager, JsonMapping jsonMapping) {
this.channel = conf.getChannel();
this.healthEnabled = conf.getHealthEnabled();
this.client = sqsManager.getClient(conf);
this.batch = conf.getBatch();
this.batchSize = conf.getBatchSize();
this.batchDelay = Duration.ofMillis(conf.getBatchDelay());
this.queueUrlUni = sqsManager.getQueueUrl(conf).memoize().indefinitely();
this.groupId = conf.getGroupId().orElse(null);
this.jsonMapping = jsonMapping;
this.subscriber = MultiUtils.via(multi -> multi
.onSubscription().call(s -> queueUrlUni)
.call(m -> publishMessage(this.client, m))
.plug(stream -> {
if (batch) {
return stream.group().intoLists().of(batchSize, batchDelay)
.call(l -> publishMessage(this.client, l))
.onItem().transformToMultiAndConcatenate(l -> Multi.createFrom().iterable(l));
} else {
return stream.call(m -> publishMessage(this.client, m));
}
})
.onFailure().invoke(f -> {
AwsSqsLogging.log.unableToDispatch(channel, f);
reportFailure(f);
Expand Down Expand Up @@ -73,6 +95,82 @@ private Uni<Void> publishMessage(SqsAsyncClient client, Message<?> m) {
});
}

private Uni<Void> publishMessage(SqsAsyncClient client, List<Message<?>> messages) {
if (closed.get()) {
return Uni.createFrom().voidItem();
}
if (messages.isEmpty()) {
return Uni.createFrom().nullItem();
}
if (messages.size() == 1) {
return publishMessage(client, messages.get(0));
}
return queueUrlUni.map(queueUrl -> getSendMessageRequest(queueUrl, messages))
.chain(request -> Uni.createFrom().completionStage(() -> client.sendMessageBatch(request)))
.onItem().transformToUni(response -> {
List<Uni<Void>> results = new ArrayList<>();
for (BatchResultErrorEntry entry : response.failed()) {
int index = Integer.parseInt(entry.id());
if (messages.size() > index) {
Message<?> m = messages.get(index);
results.add(Uni.createFrom().completionStage(m.nack(new BatchResultErrorException(entry))));
}
}
for (SendMessageBatchResultEntry entry : response.successful()) {
int index = Integer.parseInt(entry.id());
if (messages.size() > index) {
Message<?> m = messages.get(index);
SendMessageResponse r = SendMessageResponse.builder()
.messageId(entry.messageId())
.sequenceNumber(entry.sequenceNumber())
.md5OfMessageBody(entry.md5OfMessageBody())
.md5OfMessageAttributes(entry.md5OfMessageAttributes())
.md5OfMessageSystemAttributes(entry.md5OfMessageSystemAttributes())
.build();
AwsSqsLogging.log.messageSentToChannel(channel, r.messageId(), r.sequenceNumber());
OutgoingMessageMetadata.setResultOnMessage(m, r);
results.add(Uni.createFrom().completionStage(m.ack()));
}
}
return Uni.combine().all().unis(results).discardItems();
})
.onFailure().recoverWithUni(t -> {
List<Uni<Void>> results = new ArrayList<>();
for (Message<?> m : messages) {
results.add(Uni.createFrom().completionStage(m.nack(t)));
}
return Uni.combine().all().unis(results).discardItems();
});
}

private SendMessageBatchRequest getSendMessageRequest(String channelQueueUrl, List<Message<?>> messages) {
List<SendMessageBatchRequestEntry> entries = getSendMessageBatchEntry(channelQueueUrl, messages);
return SendMessageBatchRequest.builder()
.entries(entries)
.queueUrl(channelQueueUrl)
.build();
}

private List<SendMessageBatchRequestEntry> getSendMessageBatchEntry(String channelQueueUrl, List<Message<?>> messages) {
// Use message index in the list as the id to identify the message in the batch result.
return IntStream.range(0, messages.size())
.mapToObj(i -> sendMessageBatchRequestEntry(channelQueueUrl, String.valueOf(i), messages.get(i)))
.collect(Collectors.toList());
}

private SendMessageBatchRequestEntry sendMessageBatchRequestEntry(String channelQueueUrl, String id, Message<?> message) {
SendMessageRequest request = getSendMessageRequest(channelQueueUrl, message);
return SendMessageBatchRequestEntry.builder()
.id(id)
.delaySeconds(request.delaySeconds())
.messageAttributes(request.messageAttributes())
.messageGroupId(request.messageGroupId())
.messageDeduplicationId(request.messageDeduplicationId())
.messageSystemAttributes(request.messageSystemAttributes())
.messageBody(request.messageBody())
.build();
}

private SendMessageRequest getSendMessageRequest(String channelQueueUrl, Message<?> m) {
Object payload = m.getPayload();
String queueUrl = channelQueueUrl;
Expand Down
Loading
Loading