Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type { Config } from "../config.d.ts";
import { normalizeVerifiedConfig } from "./config.js";
import { TOOLS } from "./tools/index.js";
import { RESOURCE_TEMPLATES } from "./mcp/resources.js";
import { registerServerCleanup } from "./server.js";

import {
ListResourcesRequestSchema,
Expand Down Expand Up @@ -145,6 +146,9 @@ export default function ({ config }: { config: z.infer<typeof configSchema> }) {
// Create the context, passing server instance and config
const contextId = randomUUID();
const context = new Context(server.server, internalConfig, contextId);
registerServerCleanup(server.server, () =>
context.getSessionManager().closeAllSessions(),
);

server.server.registerCapabilities({
resources: {
Expand Down
49 changes: 49 additions & 0 deletions src/server.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { describe, expect, it, vi } from "vitest";
import type { Server } from "@modelcontextprotocol/sdk/server/index.js";

import { registerServerCleanup, ServerList } from "./server.js";

function fakeServer(closeImpl?: () => Promise<void> | void): Server {
return {
close: vi.fn(closeImpl ?? (async () => undefined)),
} as unknown as Server;
}

describe("ServerList lifecycle cleanup", () => {
it("runs registered session cleanup before closing a server", async () => {
const cleanup = vi.fn(async () => undefined);
const server = fakeServer();
registerServerCleanup(server, cleanup);

const serverList = new ServerList(async () => server);
const created = await serverList.create();

await serverList.close(created);

expect(cleanup).toHaveBeenCalledOnce();
expect(server.close).toHaveBeenCalledOnce();
expect(cleanup.mock.invocationCallOrder[0]).toBeLessThan(
vi.mocked(server.close).mock.invocationCallOrder[0],
);
});

it("keeps closeAll idempotent when a transport close re-enters ServerList.close", async () => {
let created!: Server;
const serverList = new ServerList(async () => {
created = fakeServer(async () => {
await serverList.close(created);
});
return created;
});
const cleanup = vi.fn(async () => undefined);

const server = await serverList.create();
registerServerCleanup(server, cleanup);

await serverList.closeAll();
await serverList.closeAll();

expect(cleanup).toHaveBeenCalledOnce();
expect(server.close).toHaveBeenCalledOnce();
});
});
32 changes: 29 additions & 3 deletions src/server.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
import { Server } from "@modelcontextprotocol/sdk/server/index.js";

type ServerCleanup = () => Promise<void> | void;

const serverCleanups = new WeakMap<Server, ServerCleanup>();

export function registerServerCleanup(
server: Server,
cleanup: ServerCleanup,
) {
serverCleanups.set(server, cleanup);
}

export class ServerList {
private _servers: Server[] = [];
private _closing = new WeakSet<Server>();
private _serverFactory: () => Promise<Server>;

constructor(serverFactory: () => Promise<Server>) {
Expand All @@ -15,12 +27,26 @@ export class ServerList {
}

async close(server: Server) {
await server.close();
if (this._closing.has(server)) return;

const index = this._servers.indexOf(server);
if (index !== -1) this._servers.splice(index, 1);
if (index === -1) return;

this._closing.add(server);
this._servers.splice(index, 1);

try {
await serverCleanups.get(server)?.();
} finally {
try {
await server.close();
} finally {
this._closing.delete(server);
}
}
}

async closeAll() {
await Promise.all(this._servers.map((server) => server.close()));
await Promise.all([...this._servers].map((server) => this.close(server)));
}
}
5 changes: 4 additions & 1 deletion src/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ async function handleStreamable(
sessionIdGenerator: () => sessionId,
});
sessions.set(sessionId, transport);
const server = await serverList.create();
transport.onclose = () => {
if (transport.sessionId) sessions.delete(transport.sessionId);
void serverList.close(server).catch((error: unknown) => {
console.error("Error closing HTTP MCP session:", error);
});
};
const server = await serverList.create();
await server.connect(transport);
return await transport.handleRequest(req, res);
}
Expand Down