diff --git a/packages/spanner-migrate/src/__tests__/apply.spec.ts b/packages/spanner-migrate/src/__tests__/apply.spec.ts index 00ccb034..b10d9c56 100644 --- a/packages/spanner-migrate/src/__tests__/apply.spec.ts +++ b/packages/spanner-migrate/src/__tests__/apply.spec.ts @@ -1,21 +1,23 @@ import { type Database, Spanner } from '@google-cloud/spanner' -import type { - ExecuteSqlRequest, - Transaction, -} from '@google-cloud/spanner/build/src/transaction' +import type { ExecuteSqlRequest } from '@google-cloud/spanner/build/src/transaction' import { applyDown, applyUp } from '../apply' import type { Migration } from '../types' describe('apply', () => { + let spanner: jest.Mocked let db: jest.Mocked beforeEach(() => { // Assume `Database` is mocked globally - db = new Spanner() + spanner = new Spanner() as jest.Mocked + db = spanner .instance('my-instance') .database('my-database') as jest.Mocked jest.clearAllMocks() }) + afterEach(() => { + spanner.close() + }) describe('applyUp', () => { it('should apply the up script and record the migration', async () => { diff --git a/packages/spanner-migrate/src/__tests__/db.spec.ts b/packages/spanner-migrate/src/__tests__/db.spec.ts index 975e9c83..a6d910cc 100644 --- a/packages/spanner-migrate/src/__tests__/db.spec.ts +++ b/packages/spanner-migrate/src/__tests__/db.spec.ts @@ -15,6 +15,9 @@ describe('prepare', () => { instance = spanner.instance('my-instance') as jest.Mocked database = instance.database('my-database') as jest.Mocked }) + afterEach(() => { + spanner.close() + }) describe('ensure ensureMigrationTable', () => { it('checks if table exists', async () => { await ensureMigrationTable(database) diff --git a/packages/spanner-mock/src/spanner.ts b/packages/spanner-mock/src/spanner.ts index 8340d4e0..2d31dd8e 100644 --- a/packages/spanner-mock/src/spanner.ts +++ b/packages/spanner-mock/src/spanner.ts @@ -1,7 +1,12 @@ import { jest } from '@jest/globals' import { createMockInstance } from './instance' -export const createSpanner = () => { +type SpannerArgs = { + projectId?: string +} + +const spannerClients = new Map() +const createNewSpannerClient = (projectId: string) => { const instances = new Map>() return { instance: jest.fn((instanceName: string) => { @@ -10,5 +15,16 @@ export const createSpanner = () => { } return instances.get(instanceName) }), + close: jest.fn(() => { + spannerClients.delete(projectId) + }), + } +} + +export const createSpanner = (args: SpannerArgs) => { + const projectId = args?.projectId || '' + if (!spannerClients.has(projectId)) { + spannerClients.set(projectId, createNewSpannerClient(projectId)) } + return spannerClients.get(projectId) }