Skip to content

Commit

Permalink
fix: Moved the middleware to the MessageHandlerRegistry (openwallet-f…
Browse files Browse the repository at this point in the history
…oundation#1896)



Signed-off-by: Tom Lanser <[email protected]>
  • Loading branch information
Tommylans authored Jun 11, 2024
1 parent 558f877 commit a648af5
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 16 deletions.
19 changes: 19 additions & 0 deletions packages/core/src/agent/MessageHandlerRegistry.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { AgentMessage } from './AgentMessage'
import type { MessageHandler } from './MessageHandler'
import type { MessageHandlerMiddleware } from './MessageHandlerMiddleware'
import type { ParsedDidCommProtocolUri } from '../utils/messageType'

import { injectable } from 'tsyringe'
Expand All @@ -9,11 +10,29 @@ import { supportsIncomingDidCommProtocolUri, canHandleMessageType, parseMessageT
@injectable()
export class MessageHandlerRegistry {
private messageHandlers: MessageHandler[] = []
public readonly messageHandlerMiddlewares: MessageHandlerMiddleware[] = []
private _fallbackMessageHandler?: MessageHandler['handle']

public registerMessageHandler(messageHandler: MessageHandler) {
this.messageHandlers.push(messageHandler)
}

public registerMessageHandlerMiddleware(messageHandlerMiddleware: MessageHandlerMiddleware) {
this.messageHandlerMiddlewares.push(messageHandlerMiddleware)
}

public get fallbackMessageHandler() {
return this._fallbackMessageHandler
}

/**
* Sets the fallback message handler, the message handler that will be called if no handler
* is registered for an incoming message type.
*/
public setFallbackMessageHandler(fallbackMessageHandler: MessageHandler['handle']) {
this._fallbackMessageHandler = fallbackMessageHandler
}

public getHandlerForMessageType(messageType: string): MessageHandler | undefined {
const incomingMessageType = parseMessageType(messageType)

Expand Down
38 changes: 30 additions & 8 deletions packages/core/src/agent/__tests__/Dispatcher.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,13 @@ describe('Dispatcher', () => {
it('calls the middleware in the order they are registered', async () => {
const agentContext = getAgentContext()

// Replace the MessageHandlerRegistry instance with a empty one
agentContext.dependencyManager.registerInstance(MessageHandlerRegistry, new MessageHandlerRegistry())

const dispatcher = new Dispatcher(
new MessageSenderMock(),
eventEmitter,
new MessageHandlerRegistry(),
agentContext.dependencyManager.resolve(MessageHandlerRegistry),
agentConfig.logger
)

Expand Down Expand Up @@ -108,10 +111,13 @@ describe('Dispatcher', () => {
it('calls the middleware in the order they are registered', async () => {
const agentContext = getAgentContext()

// Replace the MessageHandlerRegistry instance with a empty one
agentContext.dependencyManager.registerInstance(MessageHandlerRegistry, new MessageHandlerRegistry())

const dispatcher = new Dispatcher(
new MessageSenderMock(),
eventEmitter,
new MessageHandlerRegistry(),
agentContext.dependencyManager.resolve(MessageHandlerRegistry),
agentConfig.logger
)

Expand Down Expand Up @@ -139,10 +145,13 @@ describe('Dispatcher', () => {
it('correctly calls the fallback message handler if no message handler is registered for the message type', async () => {
const agentContext = getAgentContext()

// Replace the MessageHandlerRegistry instance with a empty one
agentContext.dependencyManager.registerInstance(MessageHandlerRegistry, new MessageHandlerRegistry())

const dispatcher = new Dispatcher(
new MessageSenderMock(),
eventEmitter,
new MessageHandlerRegistry(),
agentContext.dependencyManager.resolve(MessageHandlerRegistry),
agentConfig.logger
)

Expand All @@ -160,13 +169,15 @@ describe('Dispatcher', () => {
})

it('will not call the message handler if the middleware does not call next (intercept incoming message handling)', async () => {
const messageHandlerRegistry = new MessageHandlerRegistry()
const agentContext = getAgentContext()

// Replace the MessageHandlerRegistry instance with a empty one
agentContext.dependencyManager.registerInstance(MessageHandlerRegistry, new MessageHandlerRegistry())

const dispatcher = new Dispatcher(
new MessageSenderMock(),
eventEmitter,
messageHandlerRegistry,
agentContext.dependencyManager.resolve(MessageHandlerRegistry),
agentConfig.logger
)

Expand All @@ -176,7 +187,12 @@ describe('Dispatcher', () => {
const inboundMessageContext = new InboundMessageContext(customProtocolMessage, { agentContext })

const mockHandle = jest.fn()
messageHandlerRegistry.registerMessageHandler({ supportedMessages: [CustomProtocolMessage], handle: mockHandle })
agentContext.dependencyManager.registerMessageHandlers([
{
supportedMessages: [CustomProtocolMessage],
handle: mockHandle,
},
])

const middleware = jest.fn()
agentContext.dependencyManager.registerMessageHandlerMiddleware(middleware)
Expand All @@ -192,10 +208,13 @@ describe('Dispatcher', () => {
it('calls the message handler set by the middleware', async () => {
const agentContext = getAgentContext()

// Replace the MessageHandlerRegistry instance with a empty one
agentContext.dependencyManager.registerInstance(MessageHandlerRegistry, new MessageHandlerRegistry())

const dispatcher = new Dispatcher(
new MessageSenderMock(),
eventEmitter,
new MessageHandlerRegistry(),
agentContext.dependencyManager.resolve(MessageHandlerRegistry),
agentConfig.logger
)

Expand Down Expand Up @@ -228,10 +247,13 @@ describe('Dispatcher', () => {
})
const messageSenderMock = new MessageSenderMock()

// Replace the MessageHandlerRegistry instance with a empty one
agentContext.dependencyManager.registerInstance(MessageHandlerRegistry, new MessageHandlerRegistry())

const dispatcher = new Dispatcher(
messageSenderMock,
eventEmitter,
new MessageHandlerRegistry(),
agentContext.dependencyManager.resolve(MessageHandlerRegistry),
agentConfig.logger
)

Expand Down
21 changes: 15 additions & 6 deletions packages/core/src/plugins/DependencyManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ export class DependencyManager {
public readonly container: DependencyContainer
public readonly registeredModules: ModulesMap

public readonly messageHandlerMiddlewares: MessageHandlerMiddleware[] = []
private _fallbackMessageHandler?: MessageHandler['handle']

public constructor(
container: DependencyContainer = rootContainer.createChildContainer(),
registeredModules: ModulesMap = {}
Expand Down Expand Up @@ -54,19 +51,31 @@ export class DependencyManager {
}

public registerMessageHandlerMiddleware(messageHandlerMiddleware: MessageHandlerMiddleware) {
this.messageHandlerMiddlewares.push(messageHandlerMiddleware)
const messageHandlerRegistry = this.resolve(MessageHandlerRegistry)

messageHandlerRegistry.registerMessageHandlerMiddleware(messageHandlerMiddleware)
}

public get fallbackMessageHandler() {
return this._fallbackMessageHandler
const messageHandlerRegistry = this.resolve(MessageHandlerRegistry)

return messageHandlerRegistry.fallbackMessageHandler
}

public get messageHandlerMiddlewares() {
const messageHandlerRegistry = this.resolve(MessageHandlerRegistry)

return messageHandlerRegistry.messageHandlerMiddlewares
}

/**
* Sets the fallback message handler, the message handler that will be called if no handler
* is registered for an incoming message type.
*/
public setFallbackMessageHandler(fallbackMessageHandler: MessageHandler['handle']) {
this._fallbackMessageHandler = fallbackMessageHandler
const messageHandlerRegistry = this.resolve(MessageHandlerRegistry)

messageHandlerRegistry.setFallbackMessageHandler(fallbackMessageHandler)
}

public registerSingleton<T>(from: InjectionToken<T>, to: InjectionToken<T>): void
Expand Down
29 changes: 27 additions & 2 deletions packages/tenants/tests/tenants.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import { SubjectInboundTransport } from '../../../tests/transport/SubjectInbound
import { SubjectOutboundTransport } from '../../../tests/transport/SubjectOutboundTransport'
import { uuid } from '../../core/src/utils/uuid'
import { testLogger } from '../../core/tests'

import { TenantsModule } from '@credo-ts/tenants'
import { TenantsModule } from '../src/TenantsModule'

const agent1Config: InitConfig = {
label: 'Tenant Agent 1',
Expand Down Expand Up @@ -251,4 +250,30 @@ describe('Tenants E2E', () => {

await agent1.modules.tenants.deleteTenantById(tenantRecord.id)
})

test('fallback middleware for the tenant manager propagated to the tenant', async () => {
expect(agent1.dependencyManager.fallbackMessageHandler).toBeUndefined()

const fallbackFunction = async () => {
// empty
}

agent1.dependencyManager.setFallbackMessageHandler(fallbackFunction)

expect(agent1.dependencyManager.fallbackMessageHandler).toBe(fallbackFunction)

const tenantRecord = await agent1.modules.tenants.createTenant({
config: {
label: 'Agent 1 Tenant 1',
},
})

const tenantAgent = await agent1.modules.tenants.getTenantAgent({
tenantId: tenantRecord.id,
})

expect(tenantAgent.dependencyManager.fallbackMessageHandler).toBe(fallbackFunction)

await tenantAgent.endSession()
})
})

0 comments on commit a648af5

Please sign in to comment.