Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rabbitmq): add graceful shutdown logic #697

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import {
AmqpConnection,
RabbitMQModule,
RabbitSubscribe,
} from '@golevelup/nestjs-rabbitmq';
import { INestApplication, Injectable } from '@nestjs/common';
import { Test } from '@nestjs/testing';

const testHandler = jest.fn();

const routingKey1 = 'longConsumerRoutingKey';
let subscriberResolve: (value: unknown) => void;

async function delay(milliseconds = 0, returnValue) {
return new Promise((done) =>
setTimeout(() => done(returnValue), milliseconds),
);
}

async function isFinished(promise) {
return await Promise.race([
delay(0, false),
promise.then(
() => true,
() => true,
),
]);
}

@Injectable()
class SubscribeService {
@RabbitSubscribe({
queue: routingKey1,
})
async handleSubscribe(message: object) {
await new Promise((resolve) => {
subscriberResolve = resolve;
});
testHandler(message);
}
}

describe('Rabbit Graceful Shutdown', () => {
let app: INestApplication;
let amqpConnection: AmqpConnection;

const rabbitHost =
process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_HOST : 'localhost';
const rabbitPort =
process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_PORT : '5672';
const uri = `amqp://rabbitmq:rabbitmq@${rabbitHost}:${rabbitPort}`;

beforeAll(async () => {
const moduleFixture = await Test.createTestingModule({
providers: [SubscribeService],
imports: [
RabbitMQModule.forRoot(RabbitMQModule, {
uri,
connectionInitOptions: { wait: true, reject: true, timeout: 3000 },
}),
],
}).compile();

app = moduleFixture.createNestApplication();
app.enableShutdownHooks();
amqpConnection = app.get<AmqpConnection>(AmqpConnection);
await app.init();
});

it('should wait for consumers to finish', async () => {
await amqpConnection.publish('', routingKey1, 'testMessage');

await new Promise((resolve) => setTimeout(resolve, 100));
const closePromise = app.close();
await new Promise((resolve) => setTimeout(resolve, 100));
expect(testHandler).not.toHaveBeenCalled();
expect(await isFinished(closePromise)).toBeFalsy();
subscriberResolve(true);
await closePromise;
expect(testHandler).toHaveBeenCalled();
});
});
49 changes: 44 additions & 5 deletions packages/rabbitmq/src/amqp/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ export type ConsumerHandler<T, U> =
) => Promise<RpcResponse<U>>;
});

type Consumer = (msg: ConsumeMessage | null) => void | Promise<void>;

const defaultConfig = {
name: 'default',
prefetchCount: 10,
Expand Down Expand Up @@ -128,6 +130,8 @@ export class AmqpConnection {

private readonly config: Required<RabbitMQConfig>;

private readonly outstandingMessageProcessing = new Set<Promise<void>>();

constructor(config: RabbitMQConfig) {
this.config = {
deserializer: (message) => JSON.parse(message.toString()),
Expand Down Expand Up @@ -341,7 +345,7 @@ export class AmqpConnection {
// Set up a consumer on the Direct Reply-To queue to facilitate RPC functionality
await channel.consume(
DIRECT_REPLY_QUEUE,
async (msg) => {
(msg) => {
if (msg == null) {
return;
}
Expand Down Expand Up @@ -427,6 +431,20 @@ export class AmqpConnection {
});
}

/**
* Wrap a consumer with logic that tracks the outstanding message processing to
* be able to wait for them on shutdown.
*/
private wrapConsumer(consumer: Consumer): Consumer {
return (msg: ConsumeMessage | null) => {
const messageProcessingPromise = Promise.resolve(consumer(msg));
this.outstandingMessageProcessing.add(messageProcessingPromise);
messageProcessingPromise.finally(() =>
this.outstandingMessageProcessing.delete(messageProcessingPromise)
);
};
}

private async setupSubscriberChannel<T>(
handler: SubscriberHandler<T>,
msgOptions: MessageHandlerOptions,
Expand All @@ -438,7 +456,7 @@ export class AmqpConnection {

const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume(
queue,
async (msg) => {
this.wrapConsumer(async (msg) => {
try {
if (isNull(msg)) {
throw new Error('Received null message');
Expand Down Expand Up @@ -480,7 +498,7 @@ export class AmqpConnection {
await errorHandler(channel, msg, e);
}
}
},
}),
consumeOptions
);

Expand Down Expand Up @@ -534,7 +552,7 @@ export class AmqpConnection {

const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume(
queue,
async (msg) => {
this.wrapConsumer(async (msg) => {
try {
if (msg == null) {
throw new Error('Received null message');
Expand Down Expand Up @@ -582,7 +600,7 @@ export class AmqpConnection {
await errorHandler(channel, msg, e);
}
}
},
}),
rpcOptions?.queueOptions?.consumerOptions
);

Expand Down Expand Up @@ -804,4 +822,25 @@ export class AmqpConnection {
this.unregisterConsumerForQueue(consumerTag);
return newConsumerTag;
}

public async close(): Promise<void> {
const managedChannels = Object.values(this._managedChannels);

// First cancel all consumers so they stop getting new messages
await Promise.all(managedChannels.map((channel) => channel.cancelAll()));

// Wait for all the outstanding messages to be processed
if (this.outstandingMessageProcessing.size) {
this.logger.log(
{ outstandingMessageCount: this.outstandingMessageProcessing.size },
'Waiting for outstanding consumers'
);
}
await Promise.all(this.outstandingMessageProcessing);

// Close all channels
await Promise.all(managedChannels.map((channel) => channel.close()));

await this.managedConnection.close();
}
}
4 changes: 4 additions & 0 deletions packages/rabbitmq/src/amqp/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@ export class AmqpConnectionManager {
clearConnections() {
this.connections = [];
}

async close() {
await Promise.all(this.connections.map((connection) => connection.close()));
}
}
7 changes: 1 addition & 6 deletions packages/rabbitmq/src/rabbitmq.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,7 @@ export class RabbitMQModule

async onApplicationShutdown() {
this.logger.verbose?.('Closing AMQP Connections');

await Promise.all(
this.connectionManager
.getConnections()
.map((connection) => connection.managedConnection.close())
);
await this.connectionManager.close();

this.connectionManager.clearConnections();
RabbitMQModule.bootstrapped = false;
Expand Down
Loading