Skip to content

Commit

Permalink
Merge pull request #134 from modelcontextprotocol/justin/sse-auth
Browse files Browse the repository at this point in the history
Use `eventsource` package, to permit custom headers for SSE
  • Loading branch information
jspahrsummers authored Jan 23, 2025
2 parents 99fe193 + 87bfb61 commit edfdea5
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 17 deletions.
28 changes: 20 additions & 8 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
},
"dependencies": {
"content-type": "^1.0.5",
"eventsource": "^3.0.2",
"raw-body": "^3.0.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.24.1"
Expand All @@ -61,7 +62,6 @@
"@types/node": "^22.0.2",
"@types/ws": "^8.5.12",
"eslint": "^9.8.0",
"eventsource": "^2.0.2",
"express": "^4.19.2",
"jest": "^29.7.0",
"ts-jest": "^29.2.4",
Expand Down
3 changes: 0 additions & 3 deletions src/cli.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import EventSource from "eventsource";
import WebSocket from "ws";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
(global as any).EventSource = EventSource;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(global as any).WebSocket = WebSocket;

Expand Down
287 changes: 287 additions & 0 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
import { createServer, type IncomingMessage, type Server } from "http";
import { AddressInfo } from "net";
import { JSONRPCMessage } from "../types.js";
import { SSEClientTransport } from "./sse.js";

describe("SSEClientTransport", () => {
let server: Server;
let transport: SSEClientTransport;
let baseUrl: URL;
let lastServerRequest: IncomingMessage;
let sendServerMessage: ((message: string) => void) | null = null;

beforeEach((done) => {
// Reset state
lastServerRequest = null as unknown as IncomingMessage;
sendServerMessage = null;

// Create a test server that will receive the EventSource connection
server = createServer((req, res) => {
lastServerRequest = req;

// Send SSE headers
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});

// Send the endpoint event
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);

// Store reference to send function for tests
sendServerMessage = (message: string) => {
res.write(`data: ${message}\n\n`);
};

// Handle request body for POST endpoints
if (req.method === "POST") {
let body = "";
req.on("data", (chunk) => {
body += chunk;
});
req.on("end", () => {
(req as IncomingMessage & { body: string }).body = body;
res.end();
});
}
});

// Start server on random port
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
done();
});
});

afterEach(async () => {
await transport.close();
await server.close();
});

describe("connection handling", () => {
it("establishes SSE connection and receives endpoint", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();

expect(lastServerRequest.headers.accept).toBe("text/event-stream");
expect(lastServerRequest.method).toBe("GET");
});

it("rejects if server returns non-200 status", async () => {
// Create a server that returns 403
server.close();
await new Promise((resolve) => server.on("close", resolve));

server = createServer((req, res) => {
res.writeHead(403);
res.end();
});

await new Promise<void>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});

transport = new SSEClientTransport(baseUrl);
await expect(transport.start()).rejects.toThrow();
});

it("closes EventSource connection on close()", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();

const closePromise = new Promise((resolve) => {
lastServerRequest.on("close", resolve);
});

await transport.close();
await closePromise;
});
});

describe("message handling", () => {
it("receives and parses JSON-RPC messages", async () => {
const receivedMessages: JSONRPCMessage[] = [];
transport = new SSEClientTransport(baseUrl);
transport.onmessage = (msg) => receivedMessages.push(msg);

await transport.start();

const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: { foo: "bar" },
};

sendServerMessage!(JSON.stringify(testMessage));

// Wait for message processing
await new Promise((resolve) => setTimeout(resolve, 50));

expect(receivedMessages).toHaveLength(1);
expect(receivedMessages[0]).toEqual(testMessage);
});

it("handles malformed JSON messages", async () => {
const errors: Error[] = [];
transport = new SSEClientTransport(baseUrl);
transport.onerror = (err) => errors.push(err);

await transport.start();

sendServerMessage!("invalid json");

// Wait for message processing
await new Promise((resolve) => setTimeout(resolve, 50));

expect(errors).toHaveLength(1);
expect(errors[0].message).toMatch(/JSON/);
});

it("handles messages via POST requests", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();

const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: { foo: "bar" },
};

await transport.send(testMessage);

// Wait for request processing
await new Promise((resolve) => setTimeout(resolve, 50));

expect(lastServerRequest.method).toBe("POST");
expect(lastServerRequest.headers["content-type"]).toBe(
"application/json",
);
expect(
JSON.parse(
(lastServerRequest as IncomingMessage & { body: string }).body,
),
).toEqual(testMessage);
});

it("handles POST request failures", async () => {
// Create a server that returns 500 for POST
server.close();
await new Promise((resolve) => server.on("close", resolve));

server = createServer((req, res) => {
if (req.method === "GET") {
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
} else {
res.writeHead(500);
res.end("Internal error");
}
});

await new Promise<void>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});

transport = new SSEClientTransport(baseUrl);
await transport.start();

const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: {},
};

await expect(transport.send(testMessage)).rejects.toThrow(/500/);
});
});

describe("header handling", () => {
it("uses custom fetch implementation from EventSourceInit to add auth headers", async () => {
const authToken = "Bearer test-token";

// Create a fetch wrapper that adds auth header
const fetchWithAuth = (url: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
headers.set("Authorization", authToken);
return fetch(url.toString(), { ...init, headers });
};

transport = new SSEClientTransport(baseUrl, {
eventSourceInit: {
fetch: fetchWithAuth,
},
});

await transport.start();

// Verify the auth header was received by the server
expect(lastServerRequest.headers.authorization).toBe(authToken);
});

it("passes custom headers to fetch requests", async () => {
const customHeaders = {
Authorization: "Bearer test-token",
"X-Custom-Header": "custom-value",
};

transport = new SSEClientTransport(baseUrl, {
requestInit: {
headers: customHeaders,
},
});

await transport.start();

// Mock fetch for the message sending test
global.fetch = jest.fn().mockResolvedValue({
ok: true,
});

const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};

await transport.send(message);

// Verify fetch was called with correct headers
expect(global.fetch).toHaveBeenCalledWith(
expect.any(URL),
expect.objectContaining({
headers: expect.any(Headers),
}),
);

const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1]
.headers;
expect(calledHeaders.get("Authorization")).toBe(
customHeaders.Authorization,
);
expect(calledHeaders.get("X-Custom-Header")).toBe(
customHeaders["X-Custom-Header"],
);
expect(calledHeaders.get("content-type")).toBe("application/json");
});
});
});
Loading

0 comments on commit edfdea5

Please sign in to comment.