diff --git a/src/index.ts b/src/index.ts index 38c0086..ba0a92a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,6 +12,7 @@ import type { Config } from "../config.d.ts"; import { normalizeVerifiedConfig } from "./config.js"; import { TOOLS } from "./tools/index.js"; import { RESOURCE_TEMPLATES } from "./mcp/resources.js"; +import { registerServerCleanup } from "./server.js"; import { ListResourcesRequestSchema, @@ -145,6 +146,9 @@ export default function ({ config }: { config: z.infer }) { // Create the context, passing server instance and config const contextId = randomUUID(); const context = new Context(server.server, internalConfig, contextId); + registerServerCleanup(server.server, () => + context.getSessionManager().closeAllSessions(), + ); server.server.registerCapabilities({ resources: { diff --git a/src/server.test.ts b/src/server.test.ts new file mode 100644 index 0000000..a9e8a17 --- /dev/null +++ b/src/server.test.ts @@ -0,0 +1,49 @@ +import { describe, expect, it, vi } from "vitest"; +import type { Server } from "@modelcontextprotocol/sdk/server/index.js"; + +import { registerServerCleanup, ServerList } from "./server.js"; + +function fakeServer(closeImpl?: () => Promise | void): Server { + return { + close: vi.fn(closeImpl ?? (async () => undefined)), + } as unknown as Server; +} + +describe("ServerList lifecycle cleanup", () => { + it("runs registered session cleanup before closing a server", async () => { + const cleanup = vi.fn(async () => undefined); + const server = fakeServer(); + registerServerCleanup(server, cleanup); + + const serverList = new ServerList(async () => server); + const created = await serverList.create(); + + await serverList.close(created); + + expect(cleanup).toHaveBeenCalledOnce(); + expect(server.close).toHaveBeenCalledOnce(); + expect(cleanup.mock.invocationCallOrder[0]).toBeLessThan( + vi.mocked(server.close).mock.invocationCallOrder[0], + ); + }); + + it("keeps closeAll idempotent when a transport close re-enters ServerList.close", async () => { + let created!: Server; + const serverList = new ServerList(async () => { + created = fakeServer(async () => { + await serverList.close(created); + }); + return created; + }); + const cleanup = vi.fn(async () => undefined); + + const server = await serverList.create(); + registerServerCleanup(server, cleanup); + + await serverList.closeAll(); + await serverList.closeAll(); + + expect(cleanup).toHaveBeenCalledOnce(); + expect(server.close).toHaveBeenCalledOnce(); + }); +}); diff --git a/src/server.ts b/src/server.ts index 1c5b833..7f0313f 100644 --- a/src/server.ts +++ b/src/server.ts @@ -1,7 +1,19 @@ import { Server } from "@modelcontextprotocol/sdk/server/index.js"; +type ServerCleanup = () => Promise | void; + +const serverCleanups = new WeakMap(); + +export function registerServerCleanup( + server: Server, + cleanup: ServerCleanup, +) { + serverCleanups.set(server, cleanup); +} + export class ServerList { private _servers: Server[] = []; + private _closing = new WeakSet(); private _serverFactory: () => Promise; constructor(serverFactory: () => Promise) { @@ -15,12 +27,26 @@ export class ServerList { } async close(server: Server) { - await server.close(); + if (this._closing.has(server)) return; + const index = this._servers.indexOf(server); - if (index !== -1) this._servers.splice(index, 1); + if (index === -1) return; + + this._closing.add(server); + this._servers.splice(index, 1); + + try { + await serverCleanups.get(server)?.(); + } finally { + try { + await server.close(); + } finally { + this._closing.delete(server); + } + } } async closeAll() { - await Promise.all(this._servers.map((server) => server.close())); + await Promise.all([...this._servers].map((server) => this.close(server))); } } diff --git a/src/transport.ts b/src/transport.ts index 4885bf9..b6adfb1 100644 --- a/src/transport.ts +++ b/src/transport.ts @@ -53,10 +53,13 @@ async function handleStreamable( sessionIdGenerator: () => sessionId, }); sessions.set(sessionId, transport); + const server = await serverList.create(); transport.onclose = () => { if (transport.sessionId) sessions.delete(transport.sessionId); + void serverList.close(server).catch((error: unknown) => { + console.error("Error closing HTTP MCP session:", error); + }); }; - const server = await serverList.create(); await server.connect(transport); return await transport.handleRequest(req, res); }