diff --git a/.agents/skills/deepgram-js-maintaining-sdk/SKILL.md b/.agents/skills/deepgram-js-maintaining-sdk/SKILL.md index 654db3dc..2477e572 100644 --- a/.agents/skills/deepgram-js-maintaining-sdk/SKILL.md +++ b/.agents/skills/deepgram-js-maintaining-sdk/SKILL.md @@ -27,6 +27,7 @@ Current permanently frozen entries from `.fernignore` / `AGENTS.md`: - `src/CustomClient.ts` — custom wrapper for auth prefixing, session ID propagation, browser/Node WebSocket handling, reconnect behavior, binary message fixes, and `createConnection()` aliases - `src/index.ts` — curated re-exports with backwards-compatible namespace behavior +- `src/transport.ts` — hand-maintained pluggable transport interface for SageMaker and other custom streaming transports - `scripts/fix-wire-test-imports.js`, `scripts/revert-wire-test-imports.js` — post-generation import fixups - `scripts/proxy-server.js` — development proxy - `scripts/validate-esm-build.mjs` — ESM build validation @@ -105,7 +106,7 @@ Relevant underlying commands today: ## JS-specific maintainer notes 1. **Do not edit most generated files under `src/api/` directly.** Prefer Fern input changes or post-regen patch review. -2. **`src/CustomClient.ts` is the highest-risk permanent wrapper.** It carries auth prefixing, browser subprotocol auth, custom websocket startup, binary handling, and the user-facing `createConnection()` aliases used by examples. +2. **`src/CustomClient.ts` is the highest-risk permanent wrapper.** It carries auth prefixing, browser subprotocol auth, custom websocket startup, binary handling, the pluggable transport adapter used for SageMaker/custom transports, and the user-facing `createConnection()` aliases used by examples. 3. **The repo ships both CJS and ESM.** Validate both outputs after generator changes. 4. **Browser behavior matters.** The wrapper intentionally diverges for browser WebSocket auth because browsers cannot send arbitrary socket headers. 5. **`.agents/` is permanently frozen in `.fernignore`.** Treat these skills as hand-written documentation during regeneration; Fern will not touch the folder. Keep this note aligned with `AGENTS.md` whenever the frozen-file list changes. diff --git a/.fernignore b/.fernignore index ffda964c..1fc0f9d0 100644 --- a/.fernignore +++ b/.fernignore @@ -4,6 +4,7 @@ examples # Custom wrapper src/CustomClient.ts src/index.ts +src/transport.ts # Custom Scripts scripts/fix-wire-test-imports.js @@ -23,6 +24,7 @@ tests/unit/custom-client.test.ts tests/unit/error-handling.test.ts tests/unit/multiple-keyterms.test.ts tests/unit/nodejs-version-compatibility.test.ts +tests/unit/transport-factory.test.ts tests/unit/websocket-reconnection.test.ts tests/unit/websocket-wrappers.test.ts tests/wire/websocket diff --git a/AGENTS.md b/AGENTS.md index 010c0b35..c9546659 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -24,6 +24,7 @@ How to identify: Current permanently frozen files: - `src/CustomClient.ts` — entirely custom wrapper with WebSocket management, auth providers, and session ID handling; no Fern equivalent - `src/index.ts` — curated re-export file with custom namespace handling +- `src/transport.ts` — hand-maintained pluggable transport interface for SageMaker and other custom streaming transports - `scripts/fix-wire-test-imports.js`, `scripts/revert-wire-test-imports.js` — post-generation import fixup scripts - `scripts/proxy-server.js` — development proxy server - `scripts/validate-esm-build.mjs` — ESM build validation diff --git a/src/CustomClient.ts b/src/CustomClient.ts index 94e0b829..8eb3082b 100644 --- a/src/CustomClient.ts +++ b/src/CustomClient.ts @@ -17,8 +17,15 @@ import { V1Socket as SpeakV1Socket } from "./api/resources/speak/resources/v1/cl import { mergeHeaders } from "./core/headers.js"; import { fromJson } from "./core/json.js"; import * as core from "./core/index.js"; +import * as websocketEvents from "./core/websocket/events.js"; import * as environments from "./environments.js"; import { RUNTIME } from "./core/runtime/index.js"; +import type { + DeepgramTransport, + DeepgramTransportFactory, + DeepgramTransportMessage, + DeepgramTransportRequest, +} from "./transport.js"; // Default WebSocket connection timeout in milliseconds const DEFAULT_CONNECTION_TIMEOUT_MS = 10000; @@ -28,6 +35,7 @@ const DEFAULT_CONNECTION_TIMEOUT_MS = 10000; const WEBSOCKET_OPTION_KEYS = new Set([ "Authorization", "headers", + "protocols", "debug", "reconnectAttempts", "connectionTimeoutInSeconds", @@ -157,18 +165,23 @@ class AccessTokenAuthProviderWrapper implements core.AuthProvider { * Custom wrapper around DeepgramClient that ensures the custom websocket implementation * from ws.ts is always used, even if the auto-generated code changes. */ +export interface CustomDeepgramClientOptions extends DeepgramClient.Options { + accessToken?: core.Supplier; + transportFactory?: DeepgramTransportFactory; +} + export class CustomDeepgramClient extends DeepgramClient { private _customAgent: AgentClient | undefined; private _customListen: ListenClient | undefined; private _customSpeak: SpeakClient | undefined; private readonly _sessionId: string; - constructor(options: DeepgramClient.Options & { accessToken?: core.Supplier } = {}) { + constructor(options: CustomDeepgramClientOptions = {}) { // Generate a UUID for the session ID const sessionId = generateUUID(); // Add the session ID to headers so it's included in all REST requests - const optionsWithSessionId: DeepgramClient.Options = { + const optionsWithSessionId: CustomDeepgramClientOptions = { ...options, headers: { ...options.headers, @@ -320,12 +333,445 @@ function buildQueryParams(args: Record): Record): Record { + const result: Record = {}; + + for (const [key, value] of Object.entries(headers)) { + result[key] = String(value); + } + + return result; +} + +function buildWebSocketUrl(url: string, queryParams: Record): string { + const queryString = core.url.toQueryString(queryParams, { arrayFormat: "repeat" }); + return queryString ? `${url}?${queryString}` : url; +} + +function getTransportFactory(options: DeepgramClient.Options): DeepgramTransportFactory | undefined { + return (options as CustomDeepgramClientOptions).transportFactory; +} + +class TransportWebSocketAdapter { + private _listeners: ReconnectingWebSocket.ListenersMap = { + error: [], + message: [], + open: [], + close: [], + }; + private _retryCount = -1; + private _shouldReconnect = true; + private _connectLock = false; + private _binaryType: BinaryType = "blob"; + private _closeCalled = false; + private _messageQueue: DeepgramTransportMessage[] = []; + private _connectTimeout: ReturnType | undefined; + private _transport: DeepgramTransport | undefined; + private _readyState: ReconnectingWebSocket.ReadyState; + private _ws: + | { + OPEN: typeof ReconnectingWebSocket.OPEN; + readyState: ReconnectingWebSocket.ReadyState; + ping?: (data?: string | ArrayBuffer | Blob | ArrayBufferView) => void; + } + | undefined; + + private readonly _factory: DeepgramTransportFactory; + private readonly _request: DeepgramTransportRequest; + + constructor(args: { factory: DeepgramTransportFactory; request: DeepgramTransportRequest; startClosed?: boolean }) { + this._factory = args.factory; + this._request = args.request; + this._readyState = args.startClosed ? ReconnectingWebSocket.ReadyState.CLOSED : ReconnectingWebSocket.ReadyState.CONNECTING; + + if (this._request.abortSignal) { + this._request.abortSignal.addEventListener("abort", this._handleAbort, { once: true }); + } + + if (!args.startClosed) { + void this._connect(); + } + } + + public static readonly CONNECTING = ReconnectingWebSocket.CONNECTING; + public static readonly OPEN = ReconnectingWebSocket.OPEN; + public static readonly CLOSING = ReconnectingWebSocket.CLOSING; + public static readonly CLOSED = ReconnectingWebSocket.CLOSED; + + public readonly CONNECTING: typeof ReconnectingWebSocket.CONNECTING = ReconnectingWebSocket.CONNECTING; + public readonly OPEN: typeof ReconnectingWebSocket.OPEN = ReconnectingWebSocket.OPEN; + public readonly CLOSING: typeof ReconnectingWebSocket.CLOSING = ReconnectingWebSocket.CLOSING; + public readonly CLOSED: typeof ReconnectingWebSocket.CLOSED = ReconnectingWebSocket.CLOSED; + + public onclose: ((event: websocketEvents.CloseEvent) => void) | null = null; + public onerror: ((event: websocketEvents.ErrorEvent) => void) | null = null; + public onmessage: ((event: MessageEvent) => void) | null = null; + public onopen: ((event: websocketEvents.Event) => void) | null = null; + + get binaryType(): BinaryType { + return this._binaryType; + } + + set binaryType(value: BinaryType) { + this._binaryType = value; + } + + get retryCount(): number { + return Math.max(this._retryCount, 0); + } + + get bufferedAmount(): number { + return this._messageQueue.reduce((acc, message) => { + if (typeof message === "string") { + return acc + message.length; + } + if (message instanceof Blob) { + return acc + message.size; + } + return acc + message.byteLength; + }, 0); + } + + get extensions(): string { + return ""; + } + + get protocol(): string { + return this._request.protocols[0] ?? ""; + } + + get readyState(): ReconnectingWebSocket.ReadyState { + return this._readyState; + } + + get url(): string { + return this._request.url; + } + + public close(code = 1000, reason?: string): void { + this._closeCalled = true; + this._shouldReconnect = false; + this._clearConnectTimeout(); + this._readyState = ReconnectingWebSocket.ReadyState.CLOSING; + + const transport = this._transport; + this._transport = undefined; + this._setTransportHandle(undefined); + + if (!transport) { + this._readyState = ReconnectingWebSocket.ReadyState.CLOSED; + return; + } + + void transport.close(code, reason); + this._readyState = ReconnectingWebSocket.ReadyState.CLOSED; + } + + public reconnect(code?: number, reason?: string): void { + this._shouldReconnect = true; + this._closeCalled = false; + this._retryCount = -1; + this._readyState = ReconnectingWebSocket.ReadyState.CONNECTING; + + const transport = this._transport; + this._transport = undefined; + this._setTransportHandle(undefined); + + if (transport) { + void transport.close(code, reason); + } + + void this._connect(); + } + + public send(data: DeepgramTransportMessage): void { + if (this._transport?.isOpen()) { + void this._transport.send(data); + return; + } + + this._messageQueue.push(data); + } + + public addEventListener( + type: T, + listener: websocketEvents.WebSocketEventListenerMap[T], + ): void { + if (this._listeners[type]) { + (this._listeners[type] as Array).push(listener); + } + } + + public dispatchEvent(event: websocketEvents.Event): boolean { + const listeners = this._listeners[event.type as keyof websocketEvents.WebSocketEventListenerMap]; + + if (listeners) { + for (const listener of listeners) { + this._callEventListener(event as never, listener as never); + } + } + + return true; + } + + public removeEventListener( + type: T, + listener: websocketEvents.WebSocketEventListenerMap[T], + ): void { + if (this._listeners[type]) { + this._listeners[type] = this._listeners[type].filter((registered) => registered !== listener) as never; + } + } + + private _debug(...args: unknown[]): void { + if (this._request.debug) { + // biome-ignore lint/suspicious/noConsole: transport debug logging mirrors websocket debug logging + console.log.apply(console, ["DG-TRANSPORT>", ...args]); + } + } + + private _handleAbort = () => { + if (this._closeCalled) { + return; + } + + this._debug("abort signal fired"); + this._closeCalled = true; + this._shouldReconnect = false; + this._clearConnectTimeout(); + + const transport = this._transport; + this._transport = undefined; + this._setTransportHandle(undefined); + + if (transport) { + void transport.close(1000, "aborted"); + } + + this._readyState = ReconnectingWebSocket.ReadyState.CLOSED; + this._emitClose(1000, "aborted"); + }; + + private async _connect(): Promise { + if (this._connectLock || !this._shouldReconnect || this._request.abortSignal?.aborted) { + return; + } + + if (this._retryCount >= this._request.reconnectAttempts) { + this._debug("max retries reached", this._retryCount, ">=", this._request.reconnectAttempts); + return; + } + + this._connectLock = true; + this._retryCount++; + this._readyState = ReconnectingWebSocket.ReadyState.CONNECTING; + this._clearConnectTimeout(); + + try { + const transport = await this._factory(this._request.url, this._request.headers, this._request); + + if (this._closeCalled || this._request.abortSignal?.aborted) { + this._connectLock = false; + await transport.close(1000, "aborted"); + return; + } + + this._transport = transport; + this._setTransportHandle(transport); + this._bindTransport(transport); + this._armConnectTimeout(); + this._connectLock = false; + + if (transport.isOpen()) { + this._handleOpen(transport); + } + } catch (error) { + this._connectLock = false; + this._handleError(error instanceof Error ? error : new Error(String(error))); + } + } + + private _bindTransport(transport: DeepgramTransport): void { + transport.onOpen(() => { + if (this._transport !== transport) { + return; + } + + this._handleOpen(transport); + }); + + transport.onMessage((message) => { + if (this._transport !== transport) { + return; + } + + this._handleMessage(message); + }); + + transport.onError((error) => { + if (this._transport !== transport) { + return; + } + + this._handleError(error); + }); + + transport.onClose((event) => { + if (this._transport !== transport) { + return; + } + + this._handleClose(event.code ?? 1000, event.reason ?? ""); + }); + } + + private _armConnectTimeout(): void { + const timeoutMs = + this._request.connectionTimeoutInSeconds != null + ? this._request.connectionTimeoutInSeconds * 1000 + : DEFAULT_CONNECTION_TIMEOUT_MS; + + this._connectTimeout = setTimeout(() => { + this._handleError(new Error("TIMEOUT")); + }, timeoutMs); + } + + private _clearConnectTimeout(): void { + if (this._connectTimeout != null) { + clearTimeout(this._connectTimeout); + this._connectTimeout = undefined; + } + } + + private _handleOpen(transport: DeepgramTransport): void { + if (this._transport !== transport || this._readyState === ReconnectingWebSocket.ReadyState.OPEN) { + return; + } + + this._debug("open event"); + this._clearConnectTimeout(); + this._readyState = ReconnectingWebSocket.ReadyState.OPEN; + + const queued = [...this._messageQueue]; + this._messageQueue = []; + for (const message of queued) { + void transport.send(message); + } + + const event = new websocketEvents.Event("open", this); + if (this.onopen) { + this.onopen(event); + } + this._listeners.open.forEach((listener) => this._callEventListener(event, listener)); + } + + private _handleMessage(message: DeepgramTransportMessage): void { + const event = { type: "message", data: message, target: this } as unknown as MessageEvent; + + if (this.onmessage) { + this.onmessage(event); + } + this._listeners.message.forEach((listener) => this._callEventListener(event, listener)); + } + + private _handleError(error: Error): void { + this._debug("error event", error.message); + this._clearConnectTimeout(); + this._readyState = ReconnectingWebSocket.ReadyState.CLOSED; + + const event = new websocketEvents.ErrorEvent(error, this); + if (this.onerror) { + this.onerror(event); + } + this._listeners.error.forEach((listener) => this._callEventListener(event, listener)); + + const transport = this._transport; + this._transport = undefined; + this._setTransportHandle(undefined); + + if (transport) { + void transport.close(1011, error.message); + } + + if (this._shouldReconnect && !this._closeCalled) { + void this._connect(); + } + } + + private _handleClose(code: number, reason: string): void { + this._debug("close event", code, reason); + this._clearConnectTimeout(); + this._transport = undefined; + this._readyState = ReconnectingWebSocket.ReadyState.CLOSED; + this._setTransportHandle(undefined); + + if (code === 1000) { + this._shouldReconnect = false; + } + + this._emitClose(code, reason); + + if (this._shouldReconnect && !this._closeCalled) { + void this._connect(); + } + } + + private _emitClose(code: number, reason: string): void { + const event = new websocketEvents.CloseEvent(code, reason, this); + + if (this.onclose) { + this.onclose(event); + } + this._listeners.close.forEach((listener) => this._callEventListener(event, listener)); + } + + private _setTransportHandle(transport: DeepgramTransport | undefined): void { + if (!transport) { + this._ws = undefined; + return; + } + + this._ws = { + OPEN: this.OPEN, + get readyState() { + return transport.isOpen() + ? ReconnectingWebSocket.ReadyState.OPEN + : ReconnectingWebSocket.ReadyState.CLOSED; + }, + ping: transport.ping + ? (data?: string | ArrayBuffer | Blob | ArrayBufferView) => { + void transport.ping?.(data); + } + : undefined, + }; + } + + private _callEventListener( + event: websocketEvents.WebSocketEventMap[T], + listener: websocketEvents.WebSocketEventListenerMap[T], + ): void { + if (typeof listener === "object" && listener && "handleEvent" in listener) { + (listener as { handleEvent: (event: websocketEvents.WebSocketEventMap[T]) => void }).handleEvent(event); + } else { + (listener as (event: websocketEvents.WebSocketEventMap[T]) => void)(event); + } + } +} + /** * Helper function to get WebSocket class and handle headers/protocols based on runtime. * In Node.js, use the 'ws' library which supports headers. * In browser, use Sec-WebSocket-Protocol for authentication since headers aren't supported. */ -function getWebSocketOptions(headers: Record): { +function getWebSocketOptions(headers: Record, requestedProtocols: string[]): { WebSocket?: any; headers?: Record; protocols?: string[]; @@ -342,6 +788,9 @@ function getWebSocketOptions(headers: Record): { if (RUNTIME.type === "node" && NodeWebSocket) { options.WebSocket = NodeWebSocket; options.headers = headers; + if (requestedProtocols.length > 0) { + options.protocols = requestedProtocols; + } } else if (isBrowser) { // In browser, native WebSocket doesn't support custom headers // Extract Authorization header and use Sec-WebSocket-Protocol instead @@ -357,7 +806,7 @@ function getWebSocketOptions(headers: Record): { options.headers = browserHeaders; // Build protocols array for browser WebSocket - const protocols: string[] = []; + const protocols = [...requestedProtocols]; // If we have an Authorization header, extract the token and format as protocols // Deepgram expects: @@ -390,6 +839,9 @@ function getWebSocketOptions(headers: Record): { } else { // Fallback for other environments options.headers = headers; + if (requestedProtocols.length > 0) { + options.protocols = requestedProtocols; + } } return options; @@ -479,6 +931,8 @@ async function createWebSocketConnection({ urlPath, environmentKey, queryParams, + protocols, + service, headers, debug, reconnectAttempts, @@ -487,8 +941,10 @@ async function createWebSocketConnection({ }: { options: DeepgramClient.Options; urlPath: string; - environmentKey: 'agent' | 'production'; + environmentKey: "agent" | "production"; queryParams: Record; + protocols?: string | string[]; + service: DeepgramTransportRequest["service"]; headers?: Record; debug?: boolean; reconnectAttempts?: number; @@ -510,9 +966,7 @@ async function createWebSocketConnection({ // Resolve any Suppliers in headers to actual values const _headers = await resolveHeaders(mergedHeaders); - - // Get WebSocket options with proper header handling - const wsOptions = getWebSocketOptions(_headers); + const normalizedProtocols = normalizeProtocols(protocols); // Get the appropriate base URL for the environment const baseUrl = (await core.Supplier.get(options.baseUrl)) ?? @@ -521,9 +975,37 @@ async function createWebSocketConnection({ environments.DeepgramEnvironment.Production )[environmentKey]; + const url = core.url.join(baseUrl, urlPath); + const fullUrl = buildWebSocketUrl(url, queryParams); + const transportFactory = getTransportFactory(options); + + if (transportFactory) { + const request: DeepgramTransportRequest = { + url: fullUrl, + headers: stringifyHeaders(_headers), + protocols: normalizedProtocols, + path: urlPath, + service, + queryParams, + debug: debug ?? false, + reconnectAttempts: reconnectAttempts ?? 30, + connectionTimeoutInSeconds, + abortSignal, + }; + + return new TransportWebSocketAdapter({ + factory: transportFactory, + request, + startClosed: true, + }) as unknown as ReconnectingWebSocket; + } + + // Get WebSocket options with proper header handling + const wsOptions = getWebSocketOptions(_headers, normalizedProtocols); + // Create and return the ReconnectingWebSocket return new ReconnectingWebSocket({ - url: core.url.join(baseUrl, urlPath), + url, protocols: wsOptions.protocols ?? [], queryParameters: queryParams, headers: wsOptions.headers, @@ -548,13 +1030,15 @@ async function createWebSocketConnection({ */ class WrappedAgentV1Client extends AgentV1Client { public async connect(args: Omit & { Authorization?: string } = {}): Promise { - const { headers, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; + const { headers, protocols, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; const socket = await createWebSocketConnection({ options: this._options, urlPath: "/v1/agent/converse", environmentKey: 'agent', queryParams: buildQueryParams(args as Record), + protocols, + service: "agent.v1", headers, debug, reconnectAttempts, @@ -631,13 +1115,15 @@ class WrappedAgentV1Socket extends AgentV1Socket { */ class WrappedListenV1Client extends ListenV1Client { public async connect(args: Omit & { Authorization?: string }): Promise { - const { headers, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; + const { headers, protocols, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; const socket = await createWebSocketConnection({ options: this._options, urlPath: "/v1/listen", environmentKey: 'production', queryParams: buildQueryParams(args as Record), + protocols, + service: "listen.v1", headers, debug, reconnectAttempts, @@ -718,13 +1204,15 @@ class WrappedListenV2Client extends ListenV2Client { keyterm?: string | string[]; } ): Promise { - const { headers, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; + const { headers, protocols, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; const socket = await createWebSocketConnection({ options: this._options, urlPath: "/v2/listen", environmentKey: 'production', queryParams: buildQueryParams(args as Record), + protocols, + service: "listen.v2", headers, debug, reconnectAttempts, @@ -836,13 +1324,15 @@ class WrappedListenV2Socket extends ListenV2Socket { */ class WrappedSpeakV1Client extends SpeakV1Client { public async connect(args: Omit & { Authorization?: string }): Promise { - const { headers, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; + const { headers, protocols, debug, reconnectAttempts, connectionTimeoutInSeconds, abortSignal } = args; const socket = await createWebSocketConnection({ options: this._options, urlPath: "/v1/speak", environmentKey: 'production', queryParams: buildQueryParams(args as Record), + protocols, + service: "speak.v1", headers, debug, reconnectAttempts, diff --git a/src/index.ts b/src/index.ts index aba2bad2..a7b806e4 100644 --- a/src/index.ts +++ b/src/index.ts @@ -11,6 +11,8 @@ export * from "./api/resources/index.js"; export type { BaseClientOptions, BaseRequestOptions } from "./BaseClient.js"; export { DeepgramClient as DefaultDeepgramClient } from "./Client.js"; export { CustomDeepgramClient as DeepgramClient } from "./CustomClient.js"; +export type { CustomDeepgramClientOptions } from "./CustomClient.js"; export { DeepgramEnvironment, type DeepgramEnvironmentUrls } from "./environments.js"; export { DeepgramError, DeepgramTimeoutError } from "./errors/index.js"; +export * from "./transport.js"; export * from "./exports.js"; diff --git a/src/transport.ts b/src/transport.ts new file mode 100644 index 00000000..e8e5da3c --- /dev/null +++ b/src/transport.ts @@ -0,0 +1,81 @@ +/** + * Message payloads exchanged over Deepgram streaming transports. + * + * A transport can carry JSON control messages as strings and audio or synthesized + * audio as binary payloads. + */ +export type DeepgramTransportMessage = string | ArrayBuffer | Blob | ArrayBufferView; + +/** Close metadata reported by a custom transport. */ +export interface DeepgramTransportCloseEvent { + code?: number; + reason?: string; +} + +/** + * Metadata passed to a transport factory when a streaming connection is created. + * + * The first two factory arguments intentionally match the Python and Java SDKs: + * `factory(url, headers)`. JavaScript also passes this third metadata object so + * custom transports can inspect the target streaming API and connection settings. + */ +export interface DeepgramTransportRequest { + /** Full Deepgram websocket URL including query parameters. */ + url: string; + /** Resolved request headers including auth and session headers. */ + headers: Record; + /** Requested websocket subprotocols, if any. */ + protocols: string[]; + /** Deepgram websocket path (for example `/v1/listen`). */ + path: string; + /** Streaming API being targeted. */ + service: "agent.v1" | "listen.v1" | "listen.v2" | "speak.v1"; + /** Query parameters before they are encoded into the URL. */ + queryParams: Record; + /** Whether debug logging was requested for the connection. */ + debug: boolean; + /** Requested reconnect attempts for this connection. */ + reconnectAttempts: number; + /** Optional connection timeout in seconds. */ + connectionTimeoutInSeconds?: number; + /** Optional abort signal for the connection attempt. */ + abortSignal?: AbortSignal; +} + +/** + * Transport interface for replacing the SDK's default websocket transport. + * + * This is the seam used by SageMaker support and other non-websocket streaming + * implementations. The SDK adapts this transport to its existing socket APIs, so + * callers still use `client.listen.v1.createConnection()` and related methods. + */ +export interface DeepgramTransport { + /** Send either a JSON string or binary payload to the transport. */ + send(data: DeepgramTransportMessage): void | Promise; + /** Register a listener fired once the transport is ready to exchange messages. */ + onOpen(listener: () => void): void; + /** Register a listener for inbound text or binary messages. */ + onMessage(listener: (message: DeepgramTransportMessage) => void): void; + /** Register a listener for transport-level errors. */ + onError(listener: (error: Error) => void): void; + /** Register a listener for transport close events. */ + onClose(listener: (event: DeepgramTransportCloseEvent) => void): void; + /** Returns true while the transport is open and able to send data. */ + isOpen(): boolean; + /** Close the transport gracefully. */ + close(code?: number, reason?: string): void | Promise; + /** Optional ping hook for transports that expose an explicit keepalive primitive. */ + ping?(data?: string | ArrayBuffer | Blob | ArrayBufferView): void | Promise; +} + +/** + * Factory for creating custom streaming transports. + * + * The first two arguments mirror the Python and Java SDKs. JavaScript also passes + * a third metadata argument for transports that need more connection context. + */ +export type DeepgramTransportFactory = ( + url: string, + headers: Record, + request: DeepgramTransportRequest, +) => DeepgramTransport | Promise; diff --git a/tests/unit/transport-factory.test.ts b/tests/unit/transport-factory.test.ts new file mode 100644 index 00000000..55be560b --- /dev/null +++ b/tests/unit/transport-factory.test.ts @@ -0,0 +1,128 @@ +import { describe, expect, it } from "vitest"; + +import { DeepgramClient, type DeepgramTransport, type DeepgramTransportFactory } from "../../src"; + +type ListenerMap = { + open?: () => void; + message?: (message: string | ArrayBuffer | Blob | ArrayBufferView) => void; + error?: (error: Error) => void; + close?: (event: { code?: number; reason?: string }) => void; +}; + +class FakeTransport implements DeepgramTransport { + public readonly listeners: ListenerMap = {}; + public readonly sent: Array = []; + public closed = false; + public pingPayloads: Array = []; + private open = false; + + public send(data: string | ArrayBuffer | Blob | ArrayBufferView): void { + this.sent.push(data); + } + + public onOpen(listener: () => void): void { + this.listeners.open = listener; + } + + public onMessage(listener: (message: string | ArrayBuffer | Blob | ArrayBufferView) => void): void { + this.listeners.message = listener; + } + + public onError(listener: (error: Error) => void): void { + this.listeners.error = listener; + } + + public onClose(listener: (event: { code?: number; reason?: string }) => void): void { + this.listeners.close = listener; + } + + public isOpen(): boolean { + return this.open; + } + + public close(code?: number, reason?: string): void { + this.closed = true; + this.open = false; + this.listeners.close?.({ code, reason }); + } + + public ping(data?: string | ArrayBuffer | Blob | ArrayBufferView): void { + this.pingPayloads.push(data); + } + + public emitOpen(): void { + this.open = true; + this.listeners.open?.(); + } + + public emitMessage(message: string | ArrayBuffer | Blob | ArrayBufferView): void { + this.listeners.message?.(message); + } +} + +describe("transportFactory", () => { + it("routes listen websocket connections through the custom transport", async () => { + const created: Array<{ url: string; headers: Record; request: { service: string } }> = []; + const transport = new FakeTransport(); + + const transportFactory: DeepgramTransportFactory = (url, headers, request) => { + created.push({ url, headers, request: { service: request.service } }); + return transport; + }; + + const client = new DeepgramClient({ + apiKey: "test-api-key", + transportFactory, + }); + + const socket = await client.listen.v1.createConnection({ model: "nova-3" }); + let opened = false; + let receivedType: string | undefined; + + socket.on("open", () => { + opened = true; + }); + socket.on("message", (message: any) => { + receivedType = message.type; + }); + + socket.connect(); + await Promise.resolve(); + + expect(created).toHaveLength(1); + expect(created[0]?.url).toContain("wss://api.deepgram.com/v1/listen"); + expect(created[0]?.url).toContain("model=nova-3"); + expect(created[0]?.headers.Authorization ?? created[0]?.headers.authorization).toBe("Token test-api-key"); + expect(created[0]?.headers["x-deepgram-session-id"]).toBeTruthy(); + expect(created[0]?.request.service).toBe("listen.v1"); + + transport.emitOpen(); + expect(opened).toBe(true); + + socket.sendMedia(new Uint8Array([1, 2, 3])); + expect(transport.sent).toHaveLength(1); + expect(transport.sent[0]).toBeInstanceOf(Uint8Array); + + transport.emitMessage('{"type":"Results"}'); + expect(receivedType).toBe("Results"); + + socket.close(); + expect(transport.closed).toBe(true); + }); + + it("exposes transport ping through listen.v2 sockets when available", async () => { + const transport = new FakeTransport(); + const client = new DeepgramClient({ + apiKey: "test-api-key", + transportFactory: () => transport, + }); + + const socket = await client.listen.v2.createConnection({ model: "flux-general-en" }); + socket.connect(); + await Promise.resolve(); + transport.emitOpen(); + + socket.ping("keepalive"); + expect(transport.pingPayloads).toEqual(["keepalive"]); + }); +});