Support passing a user-specific token to MCP server requests

which can subsequently be accessed by downstream calls #20
This commit is contained in:
Dominik Jain 2025-12-11 12:26:42 +01:00
parent 19725f5495
commit df7245cb9d
2 changed files with 89 additions and 35 deletions

View File

@ -1,4 +1,5 @@
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { AsyncLocalStorage } from "async_hooks";
import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js";
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
import { ExecuteCodeTool } from "./tools/ExecuteCodeTool";
@ -13,6 +14,13 @@ import { ImportImageTool } from "./tools/ImportImageTool";
import { ReplServer } from "./ReplServer";
import { ApiDocs } from "./ApiDocs";
/**
* Session context for request-scoped data.
*/
export interface SessionContext {
userToken?: string;
}
export class PenpotMcpServer {
private readonly logger = createLogger("PenpotMcpServer");
private readonly server: McpServer;
@ -23,9 +31,14 @@ export class PenpotMcpServer {
private readonly replServer: ReplServer;
private apiDocs: ApiDocs;
/**
* Manages session-specific context, particularly user tokens for each request.
*/
private readonly sessionContext = new AsyncLocalStorage<SessionContext>();
private readonly transports = {
streamable: {} as Record<string, StreamableHTTPServerTransport>,
sse: {} as Record<string, SSEServerTransport>,
sse: {} as Record<string, { transport: SSEServerTransport; userToken?: string }>,
};
constructor(
@ -59,6 +72,15 @@ export class PenpotMcpServer {
return instructions;
}
/**
* Retrieves the current session context.
*
* @returns The session context for the current request, or undefined if not in a request context
*/
public getSessionContext(): SessionContext | undefined {
return this.sessionContext.getStore();
}
private registerTools(): void {
const toolInstances: Tool<any>[] = [
new ExecuteCodeTool(this),
@ -88,51 +110,70 @@ export class PenpotMcpServer {
}
private setupHttpEndpoints(): void {
/**
* Modern Streamable HTTP connection endpoint
*/
this.app.all("/mcp", async (req: any, res: any) => {
const { randomUUID } = await import("node:crypto");
const userToken = req.query.userToken as string | undefined;
const sessionId = req.headers["mcp-session-id"] as string | undefined;
let transport: StreamableHTTPServerTransport;
await this.sessionContext.run({ userToken }, async () => {
const { randomUUID } = await import("node:crypto");
if (sessionId && this.transports.streamable[sessionId]) {
transport = this.transports.streamable[sessionId];
} else {
transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (id: string) => {
this.transports.streamable[id] = transport;
},
const sessionId = req.headers["mcp-session-id"] as string | undefined;
let transport: StreamableHTTPServerTransport;
if (sessionId && this.transports.streamable[sessionId]) {
transport = this.transports.streamable[sessionId];
} else {
transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (id: string) => {
this.transports.streamable[id] = transport;
},
});
transport.onclose = () => {
if (transport.sessionId) {
delete this.transports.streamable[transport.sessionId];
}
};
await this.server.connect(transport);
}
await transport.handleRequest(req, res, req.body);
});
});
/**
* Legacy SSE connection endpoint
*/
this.app.get("/sse", async (req: any, res: any) => {
const userToken = req.query.userToken as string | undefined;
await this.sessionContext.run({ userToken }, async () => {
const transport = new SSEServerTransport("/messages", res);
this.transports.sse[transport.sessionId] = { transport, userToken };
res.on("close", () => {
delete this.transports.sse[transport.sessionId];
});
transport.onclose = () => {
if (transport.sessionId) {
delete this.transports.streamable[transport.sessionId];
}
};
await this.server.connect(transport);
}
await transport.handleRequest(req, res, req.body);
});
this.app.get("/sse", async (_req: any, res: any) => {
const transport = new SSEServerTransport("/messages", res);
this.transports.sse[transport.sessionId] = transport;
res.on("close", () => {
delete this.transports.sse[transport.sessionId];
});
await this.server.connect(transport);
});
/**
* SSE message POST endpoint (using previously established session)
*/
this.app.post("/messages", async (req: any, res: any) => {
const sessionId = req.query.sessionId as string;
const transport = this.transports.sse[sessionId];
const session = this.transports.sse[sessionId];
if (transport) {
await transport.handlePostMessage(req, res, req.body);
if (session) {
await this.sessionContext.run({ userToken: session.userToken }, async () => {
await session.transport.handlePostMessage(req, res, req.body);
});
} else {
res.status(400).send("No transport found for sessionId");
}

View File

@ -1,7 +1,7 @@
import { z } from "zod";
import "reflect-metadata";
import { TextResponse, ToolResponse } from "./ToolResponse";
import type { PenpotMcpServer } from "./PenpotMcpServer";
import type { PenpotMcpServer, SessionContext } from "./PenpotMcpServer";
import { createLogger } from "./logger";
/**
@ -38,6 +38,10 @@ export abstract class Tool<TArgs extends object> {
let argsInstance: TArgs = args as TArgs;
this.logger.info("Executing tool: %s; arguments: %s", this.getToolName(), this.formatArgs(argsInstance));
// TODO: Remove; testing only
const sessionContext = this.mcpServer.getSessionContext();
this.logger.info("Session context: %s", sessionContext ? JSON.stringify(sessionContext) : "none");
// execute the actual tool logic
let result = await this.executeCore(argsInstance);
@ -89,6 +93,15 @@ export abstract class Tool<TArgs extends object> {
return formatted.length > 0 ? "\n" + formatted.join("\n") : "{}";
}
/**
* Retrieves the current session context.
*
* @returns The session context for the current request, or undefined if not in a request context
*/
protected getSessionContext(): SessionContext | undefined {
return this.mcpServer.getSessionContext();
}
public getInputSchema() {
return this.inputSchema;
}