Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 1 addition & 23 deletions rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ import {
type QueueSendResult,
type QueueSendWaitOptions,
} from "./queue";
import { resolveGatewayTarget } from "./resolve-gateway-target";
import {
type WebSocketMessage as ConnMessage,
messageLength,
Expand Down Expand Up @@ -578,9 +577,7 @@ export class ActorConnRaw {

async #connectWebSocket() {
const params = await this.#resolveConnectionParams();
const target = this.#gatewayOptions.skipReadyWait
? await this.#resolveGatewayTargetForSkipReadyWait()
: getGatewayTarget(this.#actorResolutionState);
const target = getGatewayTarget(this.#actorResolutionState);
const ws = await this.#driver.openWebSocket(
PATH_CONNECT,
target,
Expand Down Expand Up @@ -634,25 +631,6 @@ export class ActorConnRaw {
});
}

async #resolveGatewayTargetForSkipReadyWait() {
if ("getForId" in this.#actorResolutionState) {
return {
directId: this.#actorResolutionState.getForId.actorId,
} as const;
}

if (this.#actorId) {
return { directId: this.#actorId } as const;
}

return {
directId: await resolveGatewayTarget(
this.#driver,
this.#actorResolutionState,
),
} as const;
}

/** Called by the onopen event from drivers. */
#handleOnOpen() {
// Connection was disposed before Init message arrived - close the websocket to avoid leak
Expand Down
40 changes: 31 additions & 9 deletions rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,13 @@ export class ActorHandleRaw {
for (let attempt = 0; attempt < maxAttempts; attempt++) {
let actorId: string | undefined;
try {
const target = await this.#resolveActionTarget(useQueryTarget);
const gatewayOptions = resolveActorGatewayOptions(
this.#gatewayOptions,
);
const target = await this.#resolveGatewayRequestTarget(
useQueryTarget,
gatewayOptions,
);
actorId = "directId" in target ? target.directId : undefined;

return await createQueueSender({
Expand All @@ -150,9 +156,7 @@ export class ActorHandleRaw {
return await this.#driver.sendRequest(
target,
request,
resolveActorGatewayOptions(
this.#gatewayOptions,
),
gatewayOptions,
);
},
}).send(name, body, options as any);
Expand Down Expand Up @@ -269,7 +273,10 @@ export class ActorHandleRaw {
for (let attempt = 0; attempt < maxAttempts; attempt++) {
let actorId: string | undefined;
try {
const target = await this.#resolveActionTarget(useQueryTarget);
const target = await this.#resolveGatewayRequestTarget(
useQueryTarget,
gatewayOptions,
);
actorId = "directId" in target ? target.directId : undefined;

logger().debug(
Expand Down Expand Up @@ -558,6 +565,17 @@ export class ActorHandleRaw {
}
}

async #resolveGatewayRequestTarget(
useQueryTarget: boolean,
gatewayOptions: ActorGatewayOptions,
) {
if (gatewayOptions.skipReadyWait) {
return getGatewayTarget(this.#actorResolutionState);
}

return await this.#resolveActionTarget(useQueryTarget);
}

/**
* Establishes a persistent connection to the actor.
*
Expand Down Expand Up @@ -616,7 +634,10 @@ export class ActorHandleRaw {
for (let attempt = 0; attempt < maxAttempts; attempt++) {
let actorId: string | undefined;
try {
const target = await this.#resolveActionTarget(useQueryTarget);
const target = await this.#resolveGatewayRequestTarget(
useQueryTarget,
gatewayOptions,
);
actorId = "directId" in target ? target.directId : undefined;
const response = await rawHttpFetch(
this.#driver,
Expand Down Expand Up @@ -836,9 +857,10 @@ export class ActorHandleRaw {
this.#gatewayOptions,
options,
);
const target = gatewayOptions.skipReadyWait
? await this.#resolveActionTarget(false)
: getGatewayTarget(this.#actorResolutionState);
const target = await this.#resolveGatewayRequestTarget(
false,
gatewayOptions,
);
return await rawWebSocket(
this.#driver,
target,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { ClientConfigSchema } from "@/client/config";
import { createClient } from "@/client/mod";
import {
HEADER_RIVET_ACTOR,
HEADER_RIVET_SKIP_READY_WAIT,
Expand All @@ -10,7 +11,6 @@ import {
WS_PROTOCOL_TARGET,
WS_PROTOCOL_TOKEN,
} from "@/common/actor-router-consts";
import { createClient } from "@/client/mod";
import { RemoteEngineControlClient } from "@/engine-client/mod";

describe.sequential("RemoteEngineControlClient public token usage", () => {
Expand Down Expand Up @@ -162,6 +162,48 @@ describe.sequential("RemoteEngineControlClient public token usage", () => {
);
});

test("query handle fetch keeps skip ready wait on gateway URL", async () => {
const fetchCalls: Request[] = [];
const fetchMock = vi.fn(async (input: Request | URL | string) => {
const request = normalizeRequest(input);
fetchCalls.push(request);
return new Response("ok");
});
vi.stubGlobal("fetch", fetchMock);

const client = createClient({
endpoint: "https://api.rivet.dev",
disableMetadataLookup: true,
gateway: { skipReadyWait: true },
});
const handle = client.getOrCreate("mockAgenticLoop", [
"query-http-skip-ready-wait",
]);

const response = await handle.fetch("/skip-ready-wait");

expect(response.status).toBe(200);
expect(fetchCalls).toHaveLength(1);

const actorRequest = fetchCalls[0];
expect(actorRequest).toBeDefined();
if (!actorRequest) throw new Error("missing actor request");
const url = new URL(actorRequest.url);
expect(url.pathname).toBe(
"/gateway/mockAgenticLoop/request/skip-ready-wait",
);
expect(url.searchParams.get("rvt-method")).toBe("getOrCreate");
expect(url.searchParams.get("rvt-key")).toBe(
"query-http-skip-ready-wait",
);
expect(url.searchParams.get("rvt-skip-ready-wait")).toBe("true");
expect(actorRequest?.headers.get(HEADER_RIVET_TARGET)).toBeNull();
expect(actorRequest?.headers.get(HEADER_RIVET_ACTOR)).toBeNull();
expect(actorRequest?.headers.get(HEADER_RIVET_SKIP_READY_WAIT)).toBe(
"1",
);
});

test("uses metadata clientToken for actor websocket gateway requests", async () => {
const fetchMock = vi.fn(async (input: Request | URL | string) => {
const request = normalizeRequest(input);
Expand Down Expand Up @@ -258,6 +300,36 @@ describe.sequential("RemoteEngineControlClient public token usage", () => {
WS_PROTOCOL_SKIP_READY_WAIT,
]),
);

const client = createClient({
endpoint: "https://api.rivet.dev",
disableMetadataLookup: true,
gateway: { skipReadyWait: true },
});
const handle = client.getOrCreate("mockAgenticLoop", [
"query-ws-skip-ready-wait",
]);

await handle.webSocket("/skip-ready-wait");

expect(fetchMock).toHaveBeenCalledTimes(1);
expect(sockets).toHaveLength(4);
const querySocket = sockets[3];
expect(querySocket).toBeDefined();
if (!querySocket) throw new Error("missing query websocket");
const url = new URL(querySocket.url);
expect(url.pathname).toBe(
"/gateway/mockAgenticLoop/websocket/skip-ready-wait",
);
expect(url.searchParams.get("rvt-method")).toBe("getOrCreate");
expect(url.searchParams.get("rvt-key")).toBe(
"query-ws-skip-ready-wait",
);
expect(url.searchParams.get("rvt-skip-ready-wait")).toBe("true");
expect(querySocket.protocols).toContain(WS_PROTOCOL_SKIP_READY_WAIT);
expect(querySocket.protocols).not.toContain(
`${WS_PROTOCOL_TARGET}actor`,
);
});
});

Expand Down
Loading