Skip to content

Commit

Permalink
feat(rabbitmq): add graceful shutdown logic
Browse files Browse the repository at this point in the history
Add message processing tracking so it can wait on them at shutdown.

fix #688
  • Loading branch information
ttshivers committed Feb 16, 2024
1 parent 933f97a commit 8bddcf4
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 11 deletions.
82 changes: 82 additions & 0 deletions integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import {
AmqpConnection,
RabbitMQModule,
RabbitSubscribe,
} from '@golevelup/nestjs-rabbitmq';
import { INestApplication, Injectable } from '@nestjs/common';
import { Test } from '@nestjs/testing';

const testHandler = jest.fn();

const routingKey1 = 'longConsumerRoutingKey';
let subscriberResolve: (value: unknown) => void;

async function delay(milliseconds = 0, returnValue) {
return new Promise((done) =>
setTimeout(() => done(returnValue), milliseconds),
);
}

async function isFinished(promise) {
return await Promise.race([
delay(0, false),
promise.then(
() => true,
() => true,
),
]);
}

@Injectable()
class SubscribeService {
@RabbitSubscribe({
queue: routingKey1,
})
async handleSubscribe(message: object) {
await new Promise((resolve) => {
subscriberResolve = resolve;
});
testHandler(message);
}
}

describe('Rabbit Graceful Shutdown', () => {
let app: INestApplication;
let amqpConnection: AmqpConnection;

const rabbitHost =
process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_HOST : 'localhost';
const rabbitPort =
process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_PORT : '5672';
const uri = `amqp://rabbitmq:rabbitmq@${rabbitHost}:${rabbitPort}`;

beforeAll(async () => {
const moduleFixture = await Test.createTestingModule({
providers: [SubscribeService],
imports: [
RabbitMQModule.forRoot(RabbitMQModule, {
uri,
connectionInitOptions: { wait: true, reject: true, timeout: 3000 },
}),
],
}).compile();

app = moduleFixture.createNestApplication();
amqpConnection = app.get<AmqpConnection>(AmqpConnection);
await app.init();
});

it('should wait for consumers to finish', async () => {
await amqpConnection.publish('', routingKey1, 'testMessage');

await new Promise((resolve) => setTimeout(resolve, 100));
const closePromise = app.close();
await new Promise((resolve) => setTimeout(resolve, 100));
expect(testHandler).not.toHaveBeenCalled();
expect(isFinished(closePromise)).toBeFalsy();
subscriberResolve(true);
await new Promise((resolve) => setTimeout(resolve, 100));
expect(testHandler).toHaveBeenCalled();
expect(isFinished(closePromise)).toBeTruthy();
});
});
49 changes: 44 additions & 5 deletions packages/rabbitmq/src/amqp/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ export type ConsumerHandler<T, U> =
) => Promise<RpcResponse<U>>;
});

type Consumer = (msg: ConsumeMessage | null) => void | Promise<void>;

const defaultConfig = {
name: 'default',
prefetchCount: 10,
Expand Down Expand Up @@ -128,6 +130,8 @@ export class AmqpConnection {

private readonly config: Required<RabbitMQConfig>;

private readonly outstandingMessageProcessing = new Set<Promise<void>>();

constructor(config: RabbitMQConfig) {
this.config = {
deserializer: (message) => JSON.parse(message.toString()),
Expand Down Expand Up @@ -341,7 +345,7 @@ export class AmqpConnection {
// Set up a consumer on the Direct Reply-To queue to facilitate RPC functionality
await channel.consume(
DIRECT_REPLY_QUEUE,
async (msg) => {
(msg) => {
if (msg == null) {
return;
}
Expand Down Expand Up @@ -427,6 +431,20 @@ export class AmqpConnection {
});
}

/**
* Wrap a consumer with logic that tracks the outstanding message processing to
* be able to wait for them on shutdown.
*/
private wrapConsumer(consumer: Consumer): Consumer {
return (msg: ConsumeMessage | null) => {
const messageProcessingPromise = Promise.resolve(consumer(msg));
this.outstandingMessageProcessing.add(messageProcessingPromise);
messageProcessingPromise.finally(() =>
this.outstandingMessageProcessing.delete(messageProcessingPromise)
);
};
}

private async setupSubscriberChannel<T>(
handler: SubscriberHandler<T>,
msgOptions: MessageHandlerOptions,
Expand All @@ -438,7 +456,7 @@ export class AmqpConnection {

const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume(
queue,
async (msg) => {
this.wrapConsumer(async (msg) => {
try {
if (isNull(msg)) {
throw new Error('Received null message');
Expand Down Expand Up @@ -480,7 +498,7 @@ export class AmqpConnection {
await errorHandler(channel, msg, e);
}
}
},
}),
consumeOptions
);

Expand Down Expand Up @@ -534,7 +552,7 @@ export class AmqpConnection {

const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume(
queue,
async (msg) => {
this.wrapConsumer(async (msg) => {
try {
if (msg == null) {
throw new Error('Received null message');
Expand Down Expand Up @@ -582,7 +600,7 @@ export class AmqpConnection {
await errorHandler(channel, msg, e);
}
}
},
}),
rpcOptions?.queueOptions?.consumerOptions
);

Expand Down Expand Up @@ -804,4 +822,25 @@ export class AmqpConnection {
this.unregisterConsumerForQueue(consumerTag);
return newConsumerTag;
}

public async close(): Promise<void> {
const managedChannels = Object.values(this._managedChannels);

// First cancel all consumers so they stop getting new messages
await Promise.all(managedChannels.map((channel) => channel.cancelAll()));

// Wait for all the outstanding messages to be processed
if (this.outstandingMessageProcessing.size) {
this.logger.log(
{ outstandingMessageCount: this.outstandingMessageProcessing.size },
'Waiting for outstanding consumers'
);
}
await Promise.all(this.outstandingMessageProcessing);

// Close all channels
await Promise.all(managedChannels.map((channel) => channel.close()));

await this.managedConnection.close();
}
}
4 changes: 4 additions & 0 deletions packages/rabbitmq/src/amqp/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@ export class AmqpConnectionManager {
clearConnections() {
this.connections = [];
}

async close() {
await Promise.all(this.connections.map((connection) => connection.close()));
}
}
7 changes: 1 addition & 6 deletions packages/rabbitmq/src/rabbitmq.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,7 @@ export class RabbitMQModule

async onApplicationShutdown() {
this.logger.verbose?.('Closing AMQP Connections');

await Promise.all(
this.connectionManager
.getConnections()
.map((connection) => connection.managedConnection.close())
);
await this.connectionManager.close();

this.connectionManager.clearConnections();
RabbitMQModule.bootstrapped = false;
Expand Down

0 comments on commit 8bddcf4

Please sign in to comment.