diff --git a/.changeset/fix-pglite-socket-ssl-request.md b/.changeset/fix-pglite-socket-ssl-request.md new file mode 100644 index 000000000..958380139 --- /dev/null +++ b/.changeset/fix-pglite-socket-ssl-request.md @@ -0,0 +1,5 @@ +--- +'@electric-sql/pglite-socket': patch +--- + +Handle the `SSLRequest` startup packet per the PostgreSQL wire protocol: when SSL is not available, respond with `N` so the client may continue with a cleartext `StartupMessage`. Improves interoperability with JDBC clients such as DBeaver that probe TLS first without requiring manual SSL mode tweaks. See https://www.postgresql.org/docs/current/protocol-message-formats.html . diff --git a/packages/pglite-socket/src/index.ts b/packages/pglite-socket/src/index.ts index 8dfd4ce01..aeae7e26b 100644 --- a/packages/pglite-socket/src/index.ts +++ b/packages/pglite-socket/src/index.ts @@ -343,6 +343,23 @@ export class PGLiteSocketHandler extends EventTarget { let totalProcessed = 0 while (this.messageBuffer.length > 0) { + // SSLRequest: first Int32 is length (8); second Int32 is fixed 80877103. + // This and other frontend/backend message layouts are specified in PostgreSQL docs: + // https://www.postgresql.org/docs/current/protocol-message-formats.html + // Rules: server must reply 'S' or 'N' before the client sends StartupMessage. + // pglite-socket has no TLS/SSL, so always 'N' (decline SSL). + if (this.messageBuffer.length >= 8) { + const len = this.messageBuffer.readInt32BE(0) + const code = this.messageBuffer.readInt32BE(4) + if (len === 8 && code === 80877103) { + if (this.socket?.writable) { + this.socket.write(Buffer.from('N')) + } + this.messageBuffer = this.messageBuffer.slice(8) + continue + } + } + // Determine message length let messageLength = 0 let isComplete = false diff --git a/packages/pglite-socket/tests/ssl-request.test.ts b/packages/pglite-socket/tests/ssl-request.test.ts new file mode 100644 index 000000000..3bb6e6960 --- /dev/null +++ b/packages/pglite-socket/tests/ssl-request.test.ts @@ -0,0 +1,76 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { PGLiteSocketHandler } from '../src' + +/** Second Int32 of SSLRequest — https://www.postgresql.org/docs/current/protocol-message-formats.html */ +const PG_PROTOCOL_SSL_REQUEST_CODE = 80877103 + +function createNetSocketStub() { + const eventHandlers: Record void>> = {} + const socket = { + writable: true, + remoteAddress: '127.0.0.1', + remotePort: 12345, + setNoDelay: vi.fn(), + write: vi.fn(), + removeAllListeners: vi.fn(), + end: vi.fn(), + destroy: vi.fn(), + on: vi.fn((event: string, callback: (data?: unknown) => void) => { + if (!eventHandlers[event]) eventHandlers[event] = [] + eventHandlers[event].push(callback) + return socket + }), + emit(event: string, data?: unknown) { + eventHandlers[event]?.forEach((h) => h(data)) + }, + } + return socket as any +} + +function createQueryQueueStub() { + return { + enqueue: vi.fn().mockResolvedValue(0), + clearQueueForHandler: vi.fn(), + clearTransactionIfNeeded: vi.fn().mockResolvedValue(undefined), + getQueueLength: vi.fn().mockReturnValue(0), + } +} + +async function flushEventLoop(): Promise { + await new Promise((r) => setImmediate(r)) + await new Promise((r) => setImmediate(r)) +} + +describe('PGLiteSocketHandler PostgreSQL SSLRequest (protocol-message-formats)', () => { + let handler: PGLiteSocketHandler + let socketStub: ReturnType + let queryQueueStub: ReturnType + + beforeEach(() => { + queryQueueStub = createQueryQueueStub() + handler = new PGLiteSocketHandler({ + queryQueue: queryQueueStub as any, + }) + socketStub = createNetSocketStub() + }) + + afterEach(async () => { + if (handler?.isAttached) { + await handler.detach(true) + } + }) + + it("consumes SSLRequest (8 bytes) and writes 'N' without queueing PGlite protocol", async () => { + await handler.attach(socketStub) + + const sslRequest = Buffer.alloc(8) + sslRequest.writeInt32BE(8, 0) + sslRequest.writeInt32BE(PG_PROTOCOL_SSL_REQUEST_CODE, 4) + socketStub.emit('data', sslRequest) + + await flushEventLoop() + + expect(socketStub.write).toHaveBeenCalledWith(Buffer.from('N')) + expect(queryQueueStub.enqueue).not.toHaveBeenCalled() + }) +})