diff --git a/integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts b/integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts new file mode 100644 index 000000000..5c568c28f --- /dev/null +++ b/integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts @@ -0,0 +1,82 @@ +import { + AmqpConnection, + RabbitMQModule, + RabbitSubscribe, +} from '@golevelup/nestjs-rabbitmq'; +import { INestApplication, Injectable } from '@nestjs/common'; +import { Test } from '@nestjs/testing'; + +const testHandler = jest.fn(); + +const routingKey1 = 'longConsumerRoutingKey'; +let subscriberResolve: (value: unknown) => void; + +async function delay(milliseconds = 0, returnValue) { + return new Promise((done) => + setTimeout(() => done(returnValue), milliseconds), + ); +} + +async function isFinished(promise) { + return await Promise.race([ + delay(0, false), + promise.then( + () => true, + () => true, + ), + ]); +} + +@Injectable() +class SubscribeService { + @RabbitSubscribe({ + queue: routingKey1, + }) + async handleSubscribe(message: object) { + await new Promise((resolve) => { + subscriberResolve = resolve; + }); + testHandler(message); + } +} + +describe('Rabbit Graceful Shutdown', () => { + let app: INestApplication; + let amqpConnection: AmqpConnection; + + const rabbitHost = + process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_HOST : 'localhost'; + const rabbitPort = + process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_PORT : '5672'; + const uri = `amqp://rabbitmq:rabbitmq@${rabbitHost}:${rabbitPort}`; + + beforeAll(async () => { + const moduleFixture = await Test.createTestingModule({ + providers: [SubscribeService], + imports: [ + RabbitMQModule.forRoot(RabbitMQModule, { + uri, + connectionInitOptions: { wait: true, reject: true, timeout: 3000 }, + }), + ], + }).compile(); + + app = moduleFixture.createNestApplication(); + amqpConnection = app.get<AmqpConnection>(AmqpConnection); + await app.init(); + }); + + it('should wait for consumers to finish', async () => { + await amqpConnection.publish('', routingKey1, 'testMessage'); + + await new Promise((resolve) => setTimeout(resolve, 100)); + const closePromise = app.close(); + await new Promise((resolve) => setTimeout(resolve, 100)); + expect(testHandler).not.toHaveBeenCalled(); + expect(isFinished(closePromise)).toBeFalsy(); + subscriberResolve(true); + await new Promise((resolve) => setTimeout(resolve, 100)); + expect(testHandler).toHaveBeenCalled(); + expect(isFinished(closePromise)).toBeTruthy(); + }); +}); diff --git a/packages/rabbitmq/src/amqp/connection.ts b/packages/rabbitmq/src/amqp/connection.ts index 39d414abc..3ac0ee5ca 100644 --- a/packages/rabbitmq/src/amqp/connection.ts +++ b/packages/rabbitmq/src/amqp/connection.ts @@ -81,6 +81,8 @@ export type ConsumerHandler<T, U> = ) => Promise<RpcResponse<U>>; }); +type Consumer = (msg: ConsumeMessage | null) => void | Promise<void>; + const defaultConfig = { name: 'default', prefetchCount: 10, @@ -128,6 +130,8 @@ export class AmqpConnection { private readonly config: Required<RabbitMQConfig>; + private readonly outstandingMessageProcessing = new Set<Promise<void>>(); + constructor(config: RabbitMQConfig) { this.config = { deserializer: (message) => JSON.parse(message.toString()), @@ -341,7 +345,7 @@ export class AmqpConnection { // Set up a consumer on the Direct Reply-To queue to facilitate RPC functionality await channel.consume( DIRECT_REPLY_QUEUE, - async (msg) => { + (msg) => { if (msg == null) { return; } @@ -427,6 +431,20 @@ export class AmqpConnection { }); } + /** + * Wrap a consumer with logic that tracks the outstanding message processing to + * be able to wait for them on shutdown. + */ + private wrapConsumer(consumer: Consumer): Consumer { + return (msg: ConsumeMessage | null) => { + const messageProcessingPromise = Promise.resolve(consumer(msg)); + this.outstandingMessageProcessing.add(messageProcessingPromise); + messageProcessingPromise.finally(() => + this.outstandingMessageProcessing.delete(messageProcessingPromise) + ); + }; + } + private async setupSubscriberChannel<T>( handler: SubscriberHandler<T>, msgOptions: MessageHandlerOptions, @@ -438,7 +456,7 @@ export class AmqpConnection { const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume( queue, - async (msg) => { + this.wrapConsumer(async (msg) => { try { if (isNull(msg)) { throw new Error('Received null message'); @@ -480,7 +498,7 @@ export class AmqpConnection { await errorHandler(channel, msg, e); } } - }, + }), consumeOptions ); @@ -534,7 +552,7 @@ export class AmqpConnection { const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume( queue, - async (msg) => { + this.wrapConsumer(async (msg) => { try { if (msg == null) { throw new Error('Received null message'); @@ -582,7 +600,7 @@ export class AmqpConnection { await errorHandler(channel, msg, e); } } - }, + }), rpcOptions?.queueOptions?.consumerOptions ); @@ -804,4 +822,25 @@ export class AmqpConnection { this.unregisterConsumerForQueue(consumerTag); return newConsumerTag; } + + public async close(): Promise<void> { + const managedChannels = Object.values(this._managedChannels); + + // First cancel all consumers so they stop getting new messages + await Promise.all(managedChannels.map((channel) => channel.cancelAll())); + + // Wait for all the outstanding messages to be processed + if (this.outstandingMessageProcessing.size) { + this.logger.log( + { outstandingMessageCount: this.outstandingMessageProcessing.size }, + 'Waiting for outstanding consumers' + ); + } + await Promise.all(this.outstandingMessageProcessing); + + // Close all channels + await Promise.all(managedChannels.map((channel) => channel.close())); + + await this.managedConnection.close(); + } } diff --git a/packages/rabbitmq/src/amqp/connectionManager.ts b/packages/rabbitmq/src/amqp/connectionManager.ts index 7375daf1a..79ea8899d 100644 --- a/packages/rabbitmq/src/amqp/connectionManager.ts +++ b/packages/rabbitmq/src/amqp/connectionManager.ts @@ -20,4 +20,8 @@ export class AmqpConnectionManager { clearConnections() { this.connections = []; } + + async close() { + await Promise.all(this.connections.map((connection) => connection.close())); + } } diff --git a/packages/rabbitmq/src/rabbitmq.module.ts b/packages/rabbitmq/src/rabbitmq.module.ts index e0d912132..03043ff85 100644 --- a/packages/rabbitmq/src/rabbitmq.module.ts +++ b/packages/rabbitmq/src/rabbitmq.module.ts @@ -138,12 +138,7 @@ export class RabbitMQModule async onApplicationShutdown() { this.logger.verbose?.('Closing AMQP Connections'); - - await Promise.all( - this.connectionManager - .getConnections() - .map((connection) => connection.managedConnection.close()) - ); + await this.connectionManager.close(); this.connectionManager.clearConnections(); RabbitMQModule.bootstrapped = false;