diff --git a/integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts b/integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts
new file mode 100644
index 000000000..5c568c28f
--- /dev/null
+++ b/integration/rabbitmq/e2e/graceful-shutdown.e2e-spec.ts
@@ -0,0 +1,82 @@
+import {
+  AmqpConnection,
+  RabbitMQModule,
+  RabbitSubscribe,
+} from '@golevelup/nestjs-rabbitmq';
+import { INestApplication, Injectable } from '@nestjs/common';
+import { Test } from '@nestjs/testing';
+
+const testHandler = jest.fn();
+
+const routingKey1 = 'longConsumerRoutingKey';
+let subscriberResolve: (value: unknown) => void;
+
+async function delay(milliseconds = 0, returnValue) {
+  return new Promise((done) =>
+    setTimeout(() => done(returnValue), milliseconds),
+  );
+}
+
+async function isFinished(promise) {
+  return await Promise.race([
+    delay(0, false),
+    promise.then(
+      () => true,
+      () => true,
+    ),
+  ]);
+}
+
+@Injectable()
+class SubscribeService {
+  @RabbitSubscribe({
+    queue: routingKey1,
+  })
+  async handleSubscribe(message: object) {
+    await new Promise((resolve) => {
+      subscriberResolve = resolve;
+    });
+    testHandler(message);
+  }
+}
+
+describe('Rabbit Graceful Shutdown', () => {
+  let app: INestApplication;
+  let amqpConnection: AmqpConnection;
+
+  const rabbitHost =
+    process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_HOST : 'localhost';
+  const rabbitPort =
+    process.env.NODE_ENV === 'ci' ? process.env.RABBITMQ_PORT : '5672';
+  const uri = `amqp://rabbitmq:rabbitmq@${rabbitHost}:${rabbitPort}`;
+
+  beforeAll(async () => {
+    const moduleFixture = await Test.createTestingModule({
+      providers: [SubscribeService],
+      imports: [
+        RabbitMQModule.forRoot(RabbitMQModule, {
+          uri,
+          connectionInitOptions: { wait: true, reject: true, timeout: 3000 },
+        }),
+      ],
+    }).compile();
+
+    app = moduleFixture.createNestApplication();
+    amqpConnection = app.get<AmqpConnection>(AmqpConnection);
+    await app.init();
+  });
+
+  it('should wait for consumers to finish', async () => {
+    await amqpConnection.publish('', routingKey1, 'testMessage');
+
+    await new Promise((resolve) => setTimeout(resolve, 100));
+    const closePromise = app.close();
+    await new Promise((resolve) => setTimeout(resolve, 100));
+    expect(testHandler).not.toHaveBeenCalled();
+    expect(isFinished(closePromise)).toBeFalsy();
+    subscriberResolve(true);
+    await new Promise((resolve) => setTimeout(resolve, 100));
+    expect(testHandler).toHaveBeenCalled();
+    expect(isFinished(closePromise)).toBeTruthy();
+  });
+});
diff --git a/packages/rabbitmq/src/amqp/connection.ts b/packages/rabbitmq/src/amqp/connection.ts
index 39d414abc..3ac0ee5ca 100644
--- a/packages/rabbitmq/src/amqp/connection.ts
+++ b/packages/rabbitmq/src/amqp/connection.ts
@@ -81,6 +81,8 @@ export type ConsumerHandler<T, U> =
       ) => Promise<RpcResponse<U>>;
     });
 
+type Consumer = (msg: ConsumeMessage | null) => void | Promise<void>;
+
 const defaultConfig = {
   name: 'default',
   prefetchCount: 10,
@@ -128,6 +130,8 @@ export class AmqpConnection {
 
   private readonly config: Required<RabbitMQConfig>;
 
+  private readonly outstandingMessageProcessing = new Set<Promise<void>>();
+
   constructor(config: RabbitMQConfig) {
     this.config = {
       deserializer: (message) => JSON.parse(message.toString()),
@@ -341,7 +345,7 @@ export class AmqpConnection {
     // Set up a consumer on the Direct Reply-To queue to facilitate RPC functionality
     await channel.consume(
       DIRECT_REPLY_QUEUE,
-      async (msg) => {
+      (msg) => {
         if (msg == null) {
           return;
         }
@@ -427,6 +431,20 @@ export class AmqpConnection {
     });
   }
 
+  /**
+   * Wrap a consumer with logic that tracks the outstanding message processing to
+   * be able to wait for them on shutdown.
+   */
+  private wrapConsumer(consumer: Consumer): Consumer {
+    return (msg: ConsumeMessage | null) => {
+      const messageProcessingPromise = Promise.resolve(consumer(msg));
+      this.outstandingMessageProcessing.add(messageProcessingPromise);
+      messageProcessingPromise.finally(() =>
+        this.outstandingMessageProcessing.delete(messageProcessingPromise)
+      );
+    };
+  }
+
   private async setupSubscriberChannel<T>(
     handler: SubscriberHandler<T>,
     msgOptions: MessageHandlerOptions,
@@ -438,7 +456,7 @@ export class AmqpConnection {
 
     const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume(
       queue,
-      async (msg) => {
+      this.wrapConsumer(async (msg) => {
         try {
           if (isNull(msg)) {
             throw new Error('Received null message');
@@ -480,7 +498,7 @@ export class AmqpConnection {
             await errorHandler(channel, msg, e);
           }
         }
-      },
+      }),
       consumeOptions
     );
 
@@ -534,7 +552,7 @@ export class AmqpConnection {
 
     const { consumerTag }: { consumerTag: ConsumerTag } = await channel.consume(
       queue,
-      async (msg) => {
+      this.wrapConsumer(async (msg) => {
         try {
           if (msg == null) {
             throw new Error('Received null message');
@@ -582,7 +600,7 @@ export class AmqpConnection {
             await errorHandler(channel, msg, e);
           }
         }
-      },
+      }),
       rpcOptions?.queueOptions?.consumerOptions
     );
 
@@ -804,4 +822,25 @@ export class AmqpConnection {
     this.unregisterConsumerForQueue(consumerTag);
     return newConsumerTag;
   }
+
+  public async close(): Promise<void> {
+    const managedChannels = Object.values(this._managedChannels);
+
+    // First cancel all consumers so they stop getting new messages
+    await Promise.all(managedChannels.map((channel) => channel.cancelAll()));
+
+    // Wait for all the outstanding messages to be processed
+    if (this.outstandingMessageProcessing.size) {
+      this.logger.log(
+        { outstandingMessageCount: this.outstandingMessageProcessing.size },
+        'Waiting for outstanding consumers'
+      );
+    }
+    await Promise.all(this.outstandingMessageProcessing);
+
+    // Close all channels
+    await Promise.all(managedChannels.map((channel) => channel.close()));
+
+    await this.managedConnection.close();
+  }
 }
diff --git a/packages/rabbitmq/src/amqp/connectionManager.ts b/packages/rabbitmq/src/amqp/connectionManager.ts
index 7375daf1a..79ea8899d 100644
--- a/packages/rabbitmq/src/amqp/connectionManager.ts
+++ b/packages/rabbitmq/src/amqp/connectionManager.ts
@@ -20,4 +20,8 @@ export class AmqpConnectionManager {
   clearConnections() {
     this.connections = [];
   }
+
+  async close() {
+    await Promise.all(this.connections.map((connection) => connection.close()));
+  }
 }
diff --git a/packages/rabbitmq/src/rabbitmq.module.ts b/packages/rabbitmq/src/rabbitmq.module.ts
index e0d912132..03043ff85 100644
--- a/packages/rabbitmq/src/rabbitmq.module.ts
+++ b/packages/rabbitmq/src/rabbitmq.module.ts
@@ -138,12 +138,7 @@ export class RabbitMQModule
 
   async onApplicationShutdown() {
     this.logger.verbose?.('Closing AMQP Connections');
-
-    await Promise.all(
-      this.connectionManager
-        .getConnections()
-        .map((connection) => connection.managedConnection.close())
-    );
+    await this.connectionManager.close();
 
     this.connectionManager.clearConnections();
     RabbitMQModule.bootstrapped = false;