diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/native.ts b/rivetkit-typescript/packages/rivetkit/src/registry/native.ts index 633bddd9df..60b74faaba 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/native.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/native.ts @@ -711,7 +711,7 @@ async function callNative(invoke: () => Promise): Promise { } } -function callNativeSync(invoke: () => T): T { +export function callNativeSync(invoke: () => T): T { try { return invoke(); } catch (error) { @@ -1206,7 +1206,7 @@ function toActorKey( ); } -class NativeConnAdapter { +export class NativeConnAdapter { #runtime: CoreRuntime; #conn: ConnHandle; #schemas: NativeValidationConfig; @@ -2368,6 +2368,88 @@ class TrackedWebSocketHandleAdapter implements UniversalWebSocket { } } +class NativeConnectionMap implements ReadonlyMap { + #runtime: CoreRuntime; + #ctx: ActorContextHandle; + #schemas: NativeValidationConfig; + + constructor( + runtime: CoreRuntime, + ctx: ActorContextHandle, + schemas: NativeValidationConfig, + ) { + this.#runtime = runtime; + this.#ctx = ctx; + this.#schemas = schemas; + } + + #connToAdapter(conn: ConnHandle): NativeConnAdapter { + return new NativeConnAdapter( + this.#runtime, + conn, + this.#schemas, + this.#ctx, + (connId) => + callNativeSync(() => + this.#runtime.actorQueueHibernationRemoval( + this.#ctx, + connId, + ), + ), + ); + } + + get size(): number { + return callNativeSync(() => this.#runtime.actorConns(this.#ctx)).length; + } + + get(key: string): NativeConnAdapter | undefined { + const conns = callNativeSync(() => this.#runtime.actorConns(this.#ctx)); + const conn = conns.find( + (c) => this.#runtime.connId(c) === key, + ); + if (!conn) return undefined; + return this.#connToAdapter(conn); + } + + has(key: string): boolean { + const conns = callNativeSync(() => this.#runtime.actorConns(this.#ctx)); + return conns.some((c) => this.#runtime.connId(c) === key); + } + + keys(): MapIterator { + const conns = callNativeSync(() => this.#runtime.actorConns(this.#ctx)); + return conns.map((c) => this.#runtime.connId(c))[Symbol.iterator]() as MapIterator; + } + + values(): MapIterator { + const conns = callNativeSync(() => this.#runtime.actorConns(this.#ctx)); + return conns.map((c) => this.#connToAdapter(c))[Symbol.iterator]() as MapIterator; + } + + entries(): MapIterator<[string, NativeConnAdapter]> { + const conns = callNativeSync(() => this.#runtime.actorConns(this.#ctx)); + return conns.map( + (c) => [this.#runtime.connId(c), this.#connToAdapter(c)] as [string, NativeConnAdapter], + )[Symbol.iterator]() as MapIterator<[string, NativeConnAdapter]>; + } + + forEach( + callback: (value: NativeConnAdapter, key: string, map: ReadonlyMap) => void, + thisArg?: unknown, + ): void { + const conns = callNativeSync(() => this.#runtime.actorConns(this.#ctx)); + for (const conn of conns) { + const id = this.#runtime.connId(conn); + callback.call(thisArg, this.#connToAdapter(conn), id, this); + } + } + + [Symbol.iterator](): MapIterator<[string, NativeConnAdapter]> { + return this.entries(); + } +} + export class ActorContextHandleAdapter { #runtime: CoreRuntime; #ctx: ActorContextHandle; @@ -2556,27 +2638,8 @@ export class ActorContextHandleAdapter { return callNativeSync(() => this.#runtime.actorRegion(this.#ctx)); } - get conns(): Map { - return new Map( - callNativeSync(() => this.#runtime.actorConns(this.#ctx)).map( - (conn) => [ - this.#runtime.connId(conn), - new NativeConnAdapter( - this.#runtime, - conn, - this.#schemas, - this.#ctx, - (connId) => - callNativeSync(() => - this.#runtime.actorQueueHibernationRemoval( - this.#ctx, - connId, - ), - ), - ), - ], - ), - ); + get conns(): NativeConnectionMap { + return new NativeConnectionMap(this.#runtime, this.#ctx, this.#schemas); } get log() {