Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/client/actor-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ export type ActorActionFunction<
...args: Args extends [unknown, ...infer Rest] ? Rest : Args
) => Promise<Response>;

export interface ActorGatewayOptions {
bypassConnectable?: boolean;
}

export interface ActorFetchInit extends RequestInit {
gateway?: ActorGatewayOptions;
}

export interface ActorWebSocketOptions {
gateway?: ActorGatewayOptions;
}

/**
* Maps action methods from actor definition to typed function signatures.
*/
Expand Down
22 changes: 17 additions & 5 deletions rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import { decodeCborCompat, deserializeWithEncoding, encodeCborCompat } from "@/s
import { bufferToArrayBuffer } from "@/utils";
import type {
ActorDefinitionActions,
ActorFetchInit,
ActorDefinitionQueueSend,
ActorWebSocketOptions,
} from "./actor-common";
import { type ActorConn, ActorConnRaw } from "./actor-conn";
import {
Expand Down Expand Up @@ -575,16 +577,17 @@ export class ActorHandleRaw {
* Fetches a resource from this actor via the /request endpoint. This is a
* convenience wrapper around the raw HTTP API.
*/
fetch(input: string | URL | Request, init?: RequestInit) {
fetch(input: string | URL | Request, init?: ActorFetchInit) {
return this.#fetchWithResolvedActor(input, init);
}

async #fetchWithResolvedActor(
input: string | URL | Request,
init?: RequestInit,
init?: ActorFetchInit,
) {
const maxAttempts = this.#getDynamicQueryMaxAttempts();
let useQueryTarget = false;
const { gateway, ...requestInit } = init ?? {};

for (let attempt = 0; attempt < maxAttempts; attempt++) {
let actorId: string | undefined;
Expand All @@ -596,7 +599,8 @@ export class ActorHandleRaw {
target,
this.#params,
input,
init,
requestInit,
gateway,
);
const retry = await this.#shouldRetryRawFetchResponse(
response,
Expand Down Expand Up @@ -783,14 +787,22 @@ export class ActorHandleRaw {
/**
* Opens a raw WebSocket connection to this actor.
*/
async webSocket(path?: string, protocols?: string | string[]) {
async webSocket(
path?: string,
protocols?: string | string[],
options: ActorWebSocketOptions = {},
) {
const params = await this.#resolveConnectionParams();
const target = options.gateway?.bypassConnectable
? await this.#resolveActionTarget(false)
: getGatewayTarget(this.#actorResolutionState);
return await rawWebSocket(
this.#driver,
getGatewayTarget(this.#actorResolutionState),
target,
params,
path,
protocols,
options.gateway,
);
}

Expand Down
13 changes: 11 additions & 2 deletions rivetkit-typescript/packages/rivetkit/src/client/raw-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { deconstructError } from "@/common/utils";
import {
type GatewayTarget,
type EngineControlClient,
type GatewayRequestOptions,
} from "@/engine-client/driver";
import { HEADER_CONN_PARAMS } from "@/common/actor-router-consts";
import { ActorError } from "./errors";
Expand All @@ -17,6 +18,7 @@ export async function rawHttpFetch(
params: unknown,
input: string | URL | Request,
init?: RequestInit,
options: GatewayRequestOptions = {},
): Promise<Response> {
// Extract path and merge init options
let path: string;
Expand Down Expand Up @@ -91,7 +93,7 @@ export async function rawHttpFetch(
headers: proxyRequestHeaders,
});

return driver.sendRequest(target, proxyRequest);
return driver.sendRequest(target, proxyRequest, options);
} catch (err) {
// Standardize to ClientActorError instead of the native backend error
const { group, code, message, metadata } = deconstructError(
Expand All @@ -114,6 +116,7 @@ export async function rawWebSocket(
path?: string,
// TODO: Supportp rotocols
_protocols?: string | string[],
options: GatewayRequestOptions = {},
): Promise<any> {
// TODO: Do we need encoding in rawWebSocket?
const encoding = "bare";
Expand Down Expand Up @@ -145,7 +148,13 @@ export async function rawWebSocket(
});

// Open WebSocket
const ws = await driver.openWebSocket(fullPath, target, encoding, params);
const ws = await driver.openWebSocket(
fullPath,
target,
encoding,
params,
options,
);

// Node & browser WebSocket types are incompatible
return ws as any;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ export const HEADER_RIVET_TOKEN = "x-rivet-token";
export const HEADER_RIVET_TARGET = "x-rivet-target";
export const HEADER_RIVET_ACTOR = "x-rivet-actor";
export const HEADER_RIVET_NAMESPACE = "x-rivet-namespace";
export const HEADER_RIVET_BYPASS_CONNECTABLE =
"x-rivet-bypass-connectable";

// MARK: WebSocket Protocol Prefixes
/** Some servers (such as node-ws & Cloudflare) require explicitly match a certain WebSocket protocol. This gives us a static protocol to match against. */
Expand All @@ -30,6 +32,7 @@ export const WS_PROTOCOL_ACTOR = "rivet_actor.";
export const WS_PROTOCOL_ENCODING = "rivet_encoding.";
export const WS_PROTOCOL_CONN_PARAMS = "rivet_conn_params.";
export const WS_PROTOCOL_TOKEN = "rivet_token.";
export const WS_PROTOCOL_BYPASS_CONNECTABLE = "rivet_bypass_connectable";
export const WS_PROTOCOL_TEST_ACK_HOOK = "rivet_test_ack_hook.";

// MARK: WebSocket Inline Test Protocol Prefixes
Expand All @@ -51,4 +54,5 @@ export const ALLOWED_PUBLIC_HEADERS = [
HEADER_RIVET_ACTOR,
HEADER_RIVET_NAMESPACE,
HEADER_RIVET_TOKEN,
HEADER_RIVET_BYPASS_CONNECTABLE,
];
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
import type { ClientConfig } from "@/client/config";
import { HEADER_RIVET_TOKEN } from "@/common/actor-router-consts";
import {
HEADER_RIVET_ACTOR,
HEADER_RIVET_BYPASS_CONNECTABLE,
HEADER_RIVET_TARGET,
HEADER_RIVET_TOKEN,
} from "@/common/actor-router-consts";
import type { GatewayRequestOptions } from "./driver";

export interface HttpGatewayRequestOptions extends GatewayRequestOptions {
directActorId?: string;
}

export async function sendHttpRequestToGateway(
runConfig: ClientConfig,
gatewayUrl: string,
actorRequest: Request,
options: HttpGatewayRequestOptions = {},
): Promise<Response> {
// Handle body properly based on method and presence
let bodyToSend: ArrayBuffer | null = null;
const guardHeaders = buildGuardHeaders(runConfig, actorRequest);
const guardHeaders = buildGuardHeaders(runConfig, actorRequest, options);

if (actorRequest.method !== "GET" && actorRequest.method !== "HEAD") {
if (actorRequest.bodyUsed) {
Expand Down Expand Up @@ -49,6 +60,7 @@ function mutableResponse(fetchRes: Response): Response {
function buildGuardHeaders(
runConfig: ClientConfig,
actorRequest: Request,
options: HttpGatewayRequestOptions,
): Headers {
const headers = new Headers();
// Copy all headers from the original request
Expand All @@ -63,5 +75,12 @@ function buildGuardHeaders(
if (runConfig.token) {
headers.set(HEADER_RIVET_TOKEN, runConfig.token);
}
if (options.directActorId !== undefined) {
headers.set(HEADER_RIVET_TARGET, "actor");
headers.set(HEADER_RIVET_ACTOR, options.directActorId);
}
if (options.bypassConnectable) {
headers.set(HEADER_RIVET_BYPASS_CONNECTABLE, "1");
}
return headers;
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
WS_PROTOCOL_STANDARD as WS_PROTOCOL_RIVETKIT,
WS_PROTOCOL_TARGET,
WS_PROTOCOL_ACTOR,
WS_PROTOCOL_BYPASS_CONNECTABLE,
WS_PROTOCOL_TEST_ACK_HOOK,
WS_PROTOCOL_TOKEN,
} from "@/common/actor-router-consts";
Expand All @@ -17,6 +18,7 @@ import type { ActorGatewayQuery, CrashPolicy } from "@/client/query";
import type { Encoding, UniversalWebSocket } from "@/mod";
import { encodeCborCompat, uint8ArrayToBase64 } from "@/serde";
import { combineUrlPath } from "@/utils";
import type { GatewayRequestOptions } from "./driver";
import { logger } from "./log";

class BufferedRemoteWebSocket implements UniversalWebSocket {
Expand Down Expand Up @@ -211,6 +213,7 @@ export function buildActorQueryGatewayUrl(
maxInputSize = DEFAULT_MAX_QUERY_INPUT_SIZE,
crashPolicy: CrashPolicy | undefined = undefined,
runnerName?: string,
options: GatewayRequestOptions = {},
): string {
if (namespace.length === 0) {
throw new Error("actor query namespace must not be empty");
Expand Down Expand Up @@ -266,6 +269,9 @@ export function buildActorQueryGatewayUrl(
if (token !== undefined) {
params.append("rvt-token", token);
}
if (options.bypassConnectable) {
params.append("rvt-bypass_connectable", "true");
}

const queryString = params.toString();
let separator: string;
Expand Down Expand Up @@ -318,6 +324,7 @@ export async function openWebSocketToGateway(
gatewayUrl: string,
encoding: Encoding,
params: unknown,
options: GatewayRequestOptions & { directActorId?: string } = {},
): Promise<UniversalWebSocket> {
const WebSocket = await importWebSocket();

Expand All @@ -334,7 +341,19 @@ export async function openWebSocketToGateway(
// Create WebSocket connection
const ws = new WebSocket(
gatewayUrl,
buildWebSocketProtocols(runConfig, encoding, params, ackHookToken),
buildWebSocketProtocols(
runConfig,
encoding,
params,
ackHookToken,
options.directActorId
? {
target: "actor",
actorId: options.directActorId,
}
: undefined,
options,
),
);

// The WebSocket is returned before the connection is open. This follows
Expand Down Expand Up @@ -364,6 +383,7 @@ export function buildWebSocketProtocols(
target: "actor";
actorId: string;
},
options: GatewayRequestOptions = {},
): string[] {
const protocols: string[] = [];
protocols.push(WS_PROTOCOL_RIVETKIT);
Expand All @@ -372,6 +392,9 @@ export function buildWebSocketProtocols(
protocols.push(`${WS_PROTOCOL_TARGET}${target.target}`);
protocols.push(`${WS_PROTOCOL_ACTOR}${target.actorId}`);
}
if (options.bypassConnectable) {
protocols.push(WS_PROTOCOL_BYPASS_CONNECTABLE);
}
if (params) {
protocols.push(
`${WS_PROTOCOL_CONN_PARAMS}${encodeURIComponent(JSON.stringify(params))}`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import type { ActorQuery, CrashPolicy } from "@/client/query";

export type GatewayTarget = { directId: string } | ActorQuery;

export interface GatewayRequestOptions {
bypassConnectable?: boolean;
}

export interface EngineControlClient {
getForId(input: GetForIdInput): Promise<ActorOutput | undefined>;
getWithKey(input: GetWithKeyInput): Promise<ActorOutput | undefined>;
Expand All @@ -16,12 +20,14 @@ export interface EngineControlClient {
sendRequest(
target: GatewayTarget,
actorRequest: Request,
options?: GatewayRequestOptions,
): Promise<Response>;
openWebSocket(
path: string,
target: GatewayTarget,
encoding: Encoding,
params: unknown,
options?: GatewayRequestOptions,
): Promise<UniversalWebSocket>;
proxyRequest(
c: HonoContext,
Expand All @@ -35,7 +41,10 @@ export interface EngineControlClient {
encoding: Encoding,
params: unknown,
): Promise<Response>;
buildGatewayUrl(target: GatewayTarget): Promise<string>;
buildGatewayUrl(
target: GatewayTarget,
options?: GatewayRequestOptions,
): Promise<string>;
displayInformation(): RuntimeDisplayInformation;
extraStartupLog?: () => Record<string, unknown>;
modifyRuntimeRouter?: (config: RegistryConfig, router: Hono) => void;
Expand Down
Loading
Loading