diff --git a/.changeset/sweet-cameras-act.md b/.changeset/sweet-cameras-act.md new file mode 100644 index 000000000..fa945249f --- /dev/null +++ b/.changeset/sweet-cameras-act.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": minor +--- + +add `context.setDomainPolicy({blockedDomains: ["some.domain"]})` which allows users to define a list of domains that will be blocked by stagehand diff --git a/packages/core/lib/v3/types/public/context.ts b/packages/core/lib/v3/types/public/context.ts index f5a173436..030572c85 100644 --- a/packages/core/lib/v3/types/public/context.ts +++ b/packages/core/lib/v3/types/public/context.ts @@ -32,3 +32,9 @@ export interface ClearCookieOptions { domain?: string | RegExp; path?: string | RegExp; } + +/** Context-wide network domain policy. */ +export interface DomainPolicy { + /** Domain-only block patterns, e.g. "example.com" or "*.example.com". */ + blockedDomains?: string[]; +} diff --git a/packages/core/lib/v3/types/public/sdkErrors.ts b/packages/core/lib/v3/types/public/sdkErrors.ts index ae5d580e7..d31528cdf 100644 --- a/packages/core/lib/v3/types/public/sdkErrors.ts +++ b/packages/core/lib/v3/types/public/sdkErrors.ts @@ -423,6 +423,17 @@ export class StagehandSetExtraHTTPHeadersError extends StagehandError { } } +export class StagehandSetDomainPolicyError extends StagehandError { + public readonly failures: string[]; + + constructor(failures: string[]) { + super( + `setDomainPolicy failed for ${failures.length} session(s): ${failures.join(", ")}`, + ); + this.failures = failures; + } +} + export class StagehandSnapshotError extends StagehandError { constructor(cause?: unknown) { const suffix = diff --git a/packages/core/lib/v3/understudy/context.ts b/packages/core/lib/v3/understudy/context.ts index 1ac471502..c0439b5eb 100644 --- a/packages/core/lib/v3/understudy/context.ts +++ b/packages/core/lib/v3/understudy/context.ts @@ -17,6 +17,7 @@ import { CookieSetError, PageNotFoundError, StagehandSetExtraHTTPHeadersError, + StagehandSetDomainPolicyError, } from "../types/public/sdkErrors.js"; import { getEnvTimeoutMs, withTimeout } from "../timeoutConfig.js"; import { @@ -29,7 +30,10 @@ import { Cookie, ClearCookieOptions, CookieParam, + DomainPolicy, } from "../types/public/context.js"; +import { normalizeDomainPolicy, shouldBlockUrl } from "./domainPolicy.js"; +import type { NormalizedDomainPolicy } from "./domainPolicy.js"; type TargetId = string; type SessionId = string; @@ -113,6 +117,10 @@ export class V3Context { // Timestamp for most recent popup/open signal private _lastPopupSignalAt = 0; private readonly _targetSessionListeners = new Set(); + private readonly _domainPolicySessionListeners = new Map< + SessionId, + (evt: Protocol.Fetch.RequestPausedEvent) => void + >(); private readonly _sessionInit = new Set(); private pagesByTarget = new Map(); @@ -124,8 +132,10 @@ export class V3Context { private typeByTarget = new Map(); private _pageOrder: TargetId[] = []; private pendingCreatedTargetUrl = new Map(); + private pageCreationFailures = new Map(); private readonly initScripts: string[] = []; private extraHttpHeaders: Record | null = null; + private domainPolicy: NormalizedDomainPolicy | null = null; private _clipboard?: ContextClipboard; private installTargetSessionListeners(session: CDPSessionLike): void { @@ -425,6 +435,143 @@ export class V3Context { } } + public getDomainPolicy(): DomainPolicy | null { + if (!this.domainPolicy) return null; + return { blockedDomains: [...this.domainPolicy.blockedDomains] }; + } + + public async setDomainPolicy(policy: DomainPolicy | null): Promise { + const nextPolicy = normalizeDomainPolicy(policy); + this.domainPolicy = nextPolicy; + + const sessions: CDPSessionLike[] = []; + for (const sessionId of this._sessionInit) { + const session = this.conn.getSession(sessionId); + if (session) sessions.push(session); + } + + if (!sessions.length) return; + + const results = await Promise.allSettled( + sessions.map(async (session) => { + if (!nextPolicy) { + try { + await session.send("Fetch.disable"); + this.uninstallDomainPolicyHandler(session); + } catch (error) { + throw { action: "disable", error }; + } + return; + } + + this.installDomainPolicyHandler(session); + try { + await session.send("Fetch.enable", { + patterns: nextPolicy.fetchPatterns, + }); + } catch (error) { + throw { action: "enable", error }; + } + }), + ); + + const failures = results + .map((result, index) => ({ result, session: sessions[index] })) + .filter( + ( + entry, + ): entry is { + result: PromiseRejectedResult; + session: CDPSessionLike; + } => entry.result.status === "rejected", + ) + .map((entry) => { + const failure = entry.result.reason as { + action?: "enable" | "disable"; + error?: unknown; + }; + if (failure?.action === "enable") { + this.uninstallDomainPolicyHandler(entry.session); + } + const reason = failure?.error ?? entry.result.reason; + const sid = entry.session.id ?? "unknown"; + const message = + reason instanceof Error ? reason.message : String(reason); + return `session=${sid} error=${message}`; + }); + + if (failures.length) { + throw new StagehandSetDomainPolicyError(failures); + } + } + + private installDomainPolicyHandler(session: CDPSessionLike): void { + const sessionId = session.id; + if (!sessionId) return; + if (this._domainPolicySessionListeners.has(sessionId)) return; + + const handler = (evt: Protocol.Fetch.RequestPausedEvent) => { + void this.handleDomainPolicyRequestPaused(session, evt); + }; + this._domainPolicySessionListeners.set(sessionId, handler); + session.on( + "Fetch.requestPaused", + handler, + ); + } + + private uninstallDomainPolicyHandler(session: CDPSessionLike): void { + const sessionId = session.id; + if (!sessionId) return; + + const handler = this._domainPolicySessionListeners.get(sessionId); + if (!handler) return; + + session.off( + "Fetch.requestPaused", + handler, + ); + this._domainPolicySessionListeners.delete(sessionId); + } + + private async handleDomainPolicyRequestPaused( + session: CDPSessionLike, + evt: Protocol.Fetch.RequestPausedEvent, + ): Promise { + const policy = this.domainPolicy; + + if (!policy || !shouldBlockUrl(evt.request.url, policy)) { + await session + .send("Fetch.continueRequest", { requestId: evt.requestId }) + .catch(() => {}); + return; + } + + let hostname = ""; + try { + hostname = new URL(evt.request.url).hostname.toLowerCase(); + } catch { + // ignore malformed URLs for logging + } + + v3Logger({ + category: "network", + message: "Blocked request by domain policy", + level: 2, + auxiliary: { + hostname: { value: hostname, type: "string" }, + ruleType: { value: "blockedDomains", type: "string" }, + }, + }); + + await session + .send("Fetch.failRequest", { + requestId: evt.requestId, + errorReason: "BlockedByClient", + }) + .catch(() => {}); + } + public get clipboard(): BrowserClipboard { return (this._clipboard ??= new ContextClipboard({ context: this, @@ -503,6 +650,13 @@ export class V3Context { const deadline = Date.now() + 5000; while (Date.now() < deadline) { + const failure = this.pageCreationFailures.get(targetId); + if (failure) { + this.pageCreationFailures.delete(targetId); + this.pendingCreatedTargetUrl.delete(targetId); + throw failure; + } + const page = this.pagesByTarget.get(targetId); if (page) { // we created at about:blank; navigate only after attach so init scripts run @@ -518,6 +672,7 @@ export class V3Context { } await new Promise((r) => setTimeout(r, 25)); } + this.pendingCreatedTargetUrl.delete(targetId); throw new TimeoutError(`newPage: target not attached (${targetId})`, 5000); } @@ -534,6 +689,7 @@ export class V3Context { this.createdAtByTarget.clear(); this.typeByTarget.clear(); this.pendingCreatedTargetUrl.clear(); + this.pageCreationFailures.clear(); } /** @@ -668,6 +824,24 @@ export class V3Context { .catch(() => false); return { dispatched, response }; }; + const queueFetchEnablePreResume = (params: object) => { + let error: unknown; + const dispatched = this.conn + .waitForSessionDispatch(sessionId, "Fetch.enable") + .then(() => true) + .catch((err) => { + error = err; + return false; + }); + const response = session + .send("Fetch.enable", params) + .then(() => true) + .catch((err) => { + error = err; + return false; + }); + return { dispatched, response, getError: () => error }; + }; const initScriptOps: Array<{ dispatched: Promise; response: Promise; @@ -697,6 +871,19 @@ export class V3Context { queuePreResume("Network.setExtraHTTPHeaders", { headers }), ); } + const fetchPreResumeOps: Array<{ + dispatched: Promise; + response: Promise; + getError: () => unknown; + }> = []; + if (this.domainPolicy) { + this.installDomainPolicyHandler(session); + fetchPreResumeOps.push( + queueFetchEnablePreResume({ + patterns: this.domainPolicy.fetchPatterns, + }), + ); + } // Send init scripts only after auto-attach has been queued. if (this.initScripts.length) { for (const source of this.initScripts) { @@ -728,6 +915,7 @@ export class V3Context { await Promise.all([ ...corePreResumeOps.map((op) => op.dispatched), ...headerPreResumeOps.map((op) => op.dispatched), + ...fetchPreResumeOps.map((op) => op.dispatched), ...initScriptOps.map((op) => op.dispatched), piercerPreloadOp.dispatched, ]) @@ -741,11 +929,13 @@ export class V3Context { const [ coreResults, headerResults, + fetchResults, initScriptResults, piercerPreRegistered, ] = await Promise.all([ Promise.all(corePreResumeOps.map((op) => op.response)), Promise.all(headerPreResumeOps.map((op) => op.response)), + Promise.all(fetchPreResumeOps.map((op) => op.response)), Promise.all(initScriptOps.map((op) => op.response)), piercerPreloadOp.response, ]); @@ -778,6 +968,46 @@ export class V3Context { return; } resumed = true; + + if (fetchPreResumeOps.length > 0 && !fetchResults.every(Boolean)) { + this.uninstallDomainPolicyHandler(session); + const fetchError = fetchPreResumeOps + .map((op) => op.getError()) + .find((error) => error !== undefined); + const fetchErrorMessage = fetchError + ? fetchError instanceof Error + ? fetchError.message + : String(fetchError) + : "Fetch.enable failed during target attach"; + const policyFailureMessage = + "Fetch.enable failed during target attach; closing target because " + + "Stagehand cannot guarantee domain policy enforcement"; + if (this.pendingCreatedTargetUrl.has(info.targetId)) { + this.pageCreationFailures.set( + info.targetId, + new StagehandSetDomainPolicyError([ + `session=${sessionId} target=${info.targetId} error=${policyFailureMessage} cdpError=${fetchErrorMessage}`, + ]), + ); + } + v3Logger({ + category: "ctx", + message: "Closing target because domain policy could not be guaranteed", + level: 0, + auxiliary: { + targetId: { value: String(info.targetId), type: "string" }, + targetType: { value: String(info.type), type: "string" }, + targetUrl: { value: String(info.url ?? ""), type: "string" }, + sessionId: { value: sessionId, type: "string" }, + cdpError: { value: fetchErrorMessage, type: "string" }, + }, + }); + await this.conn + .send("Target.closeTarget", { targetId: info.targetId }) + .catch(() => {}); + return; + } + const scriptsInstalled = coreResults.every(Boolean) && initScriptResults.every(Boolean); @@ -933,6 +1163,12 @@ export class V3Context { } this._targetSessionListeners.delete(sessionId); + const session = this.conn.getSession(sessionId); + if (session) { + this.uninstallDomainPolicyHandler(session); + } else { + this._domainPolicySessionListeners.delete(sessionId); + } this._sessionInit.delete(sessionId); this._piercerInstalled.delete(sessionId); } @@ -944,6 +1180,7 @@ export class V3Context { const page = this.pagesByTarget.get(targetId); if (!page) return; + this.pageCreationFailures.delete(targetId); const mainId = page.mainFrameId(); this.mainFrameToTarget.delete(mainId); this.frameOwnerPage.delete(mainId); diff --git a/packages/core/lib/v3/understudy/domainPolicy.ts b/packages/core/lib/v3/understudy/domainPolicy.ts new file mode 100644 index 000000000..b7268b383 --- /dev/null +++ b/packages/core/lib/v3/understudy/domainPolicy.ts @@ -0,0 +1,163 @@ +import type { Protocol } from "devtools-protocol"; +import { StagehandInvalidArgumentError } from "../types/public/sdkErrors.js"; +import type { DomainPolicy } from "../types/public/context.js"; + +type BlockedDomainRule = + | { type: "exact"; hostname: string } + | { type: "wildcard"; hostname: string }; + +export type NormalizedDomainPolicy = { + blockedDomains: string[]; + blockedDomainRules: BlockedDomainRule[]; + fetchPatterns: Protocol.Fetch.RequestPattern[]; +}; + +const DOMAIN_LABEL_RE = /^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$/; + +function canonicalizeHostname(hostname: string): string { + return hostname.toLowerCase().replace(/\.+$/, ""); +} + +function validateHostname(hostname: string, original: string): void { + if (!hostname || hostname.length > 253) { + throw new StagehandInvalidArgumentError( + `Invalid blocked domain pattern: "${original}"`, + ); + } + + const labels = hostname.split("."); + if ( + labels.length < 2 || + labels.some((label) => !DOMAIN_LABEL_RE.test(label)) + ) { + throw new StagehandInvalidArgumentError( + `Invalid blocked domain pattern: "${original}"`, + ); + } +} + +function normalizeBlockedDomainPattern(pattern: unknown): BlockedDomainRule { + if (typeof pattern !== "string") { + throw new StagehandInvalidArgumentError( + `Blocked domain patterns must be strings`, + ); + } + + const original = pattern; + const normalized = canonicalizeHostname(pattern.trim()); + + if (!normalized) { + throw new StagehandInvalidArgumentError( + `Invalid blocked domain pattern: "${original}"`, + ); + } + + if ( + normalized.includes("://") || + normalized.includes("/") || + normalized.includes(":") || + normalized.includes("?") || + normalized.includes("#") + ) { + throw new StagehandInvalidArgumentError( + `Blocked domain patterns must be domain-only values: "${original}"`, + ); + } + + if (normalized.startsWith("*.")) { + const hostname = normalized.slice(2); + validateHostname(hostname, original); + return { type: "wildcard", hostname }; + } + + if (normalized.includes("*")) { + throw new StagehandInvalidArgumentError( + `Wildcards are only supported as a leading "*.": "${original}"`, + ); + } + + validateHostname(normalized, original); + return { type: "exact", hostname: normalized }; +} + +function patternHost(rule: BlockedDomainRule): string { + return rule.type === "wildcard" ? `*.${rule.hostname}` : rule.hostname; +} + +function fetchPatternHosts(rule: BlockedDomainRule): string[] { + const host = patternHost(rule); + return [host, `${host}.`]; +} + +function toFetchPatterns( + rules: BlockedDomainRule[], +): Protocol.Fetch.RequestPattern[] { + const patterns: Protocol.Fetch.RequestPattern[] = []; + + for (const rule of rules) { + for (const host of fetchPatternHosts(rule)) { + for (const scheme of ["http", "https"]) { + patterns.push({ + urlPattern: `${scheme}://${host}/*`, + requestStage: "Request", + }); + patterns.push({ + urlPattern: `${scheme}://${host}:*/*`, + requestStage: "Request", + }); + } + } + } + + return patterns; +} + +export function normalizeDomainPolicy( + policy: DomainPolicy | null, +): NormalizedDomainPolicy | null { + if (!policy?.blockedDomains?.length) return null; + + const rulesByKey = new Map(); + for (const domain of policy.blockedDomains) { + const rule = normalizeBlockedDomainPattern(domain); + rulesByKey.set(`${rule.type}:${rule.hostname}`, rule); + } + + const blockedDomainRules = Array.from(rulesByKey.values()); + if (!blockedDomainRules.length) return null; + + return { + blockedDomains: blockedDomainRules.map(patternHost), + blockedDomainRules, + fetchPatterns: toFetchPatterns(blockedDomainRules), + }; +} + +function hostnameFromHttpUrl(url: string): string | null { + try { + const parsed = new URL(url); + if (parsed.protocol !== "http:" && parsed.protocol !== "https:") { + return null; + } + return canonicalizeHostname(parsed.hostname); + } catch { + return null; + } +} + +export function shouldBlockUrl( + url: string, + policy: NormalizedDomainPolicy | null, +): boolean { + if (!policy) return false; + + const hostname = hostnameFromHttpUrl(url); + if (!hostname) return false; + + return policy.blockedDomainRules.some((rule) => { + if (rule.type === "exact") { + return hostname === rule.hostname; + } + return hostname.endsWith(`.${rule.hostname}`); + }); +} diff --git a/packages/core/tests/integration/context-domain-policy.spec.ts b/packages/core/tests/integration/context-domain-policy.spec.ts new file mode 100644 index 000000000..e57299ae1 --- /dev/null +++ b/packages/core/tests/integration/context-domain-policy.spec.ts @@ -0,0 +1,116 @@ +import { test, expect } from "@playwright/test"; +import type { Protocol } from "devtools-protocol"; +import { V3 } from "../../lib/v3/v3.js"; +import { v3TestConfig } from "./v3.config.js"; +import { closeV3 } from "./testUtils.js"; + +const BLOCKED_HOST = "example.com"; +const BLOCKED_URL = `https://${BLOCKED_HOST}/stagehand-domain-policy.png`; + +type InternalPage = { + mainSession: { + send: (method: string, params?: unknown) => Promise; + on: (event: string, handler: (params: unknown) => void) => void; + off: (event: string, handler: (params: unknown) => void) => void; + }; + goto: ( + url: string, + options?: { waitUntil?: "load" | "domcontentloaded"; timeoutMs?: number }, + ) => Promise; +}; + +function pageWithBlockedImage(): string { + return `data:text/html,${encodeURIComponent( + ``, + )}`; +} + +async function waitForBlockedRequest(page: InternalPage): Promise { + await page.mainSession.send("Network.enable"); + + await new Promise((resolve, reject) => { + const requestUrls = new Map(); + let settled = false; + const timeout = setTimeout(() => { + finish(() => reject(new Error("Timed out waiting for blocked request"))); + }, 5000); + + const cleanup = () => { + clearTimeout(timeout); + page.mainSession.off("Network.requestWillBeSent", onRequest); + page.mainSession.off("Network.loadingFailed", onLoadingFailed); + }; + + const finish = (settle: () => void) => { + if (settled) return; + settled = true; + cleanup(); + settle(); + }; + + const onRequest = (params: unknown) => { + const evt = params as Protocol.Network.RequestWillBeSentEvent; + requestUrls.set(evt.requestId, String(evt.request?.url ?? "")); + }; + + const onLoadingFailed = (params: unknown) => { + const evt = params as Protocol.Network.LoadingFailedEvent; + const url = requestUrls.get(evt.requestId); + if (url !== BLOCKED_URL) return; + try { + expect(evt.errorText).toContain("ERR_BLOCKED_BY_CLIENT"); + finish(resolve); + } catch (error) { + finish(() => reject(error)); + } + }; + + page.mainSession.on("Network.requestWillBeSent", onRequest); + page.mainSession.on("Network.loadingFailed", onLoadingFailed); + + void page + .goto(pageWithBlockedImage(), { + waitUntil: "load", + timeoutMs: 5000, + }) + .catch((error) => { + finish(() => reject(error)); + }); + }); +} + +test.describe("context.setDomainPolicy", () => { + let v3: V3; + + test.beforeEach(async () => { + v3 = new V3(v3TestConfig); + await v3.init(); + }); + + test.afterEach(async () => { + await closeV3(v3); + }); + + test("blocks matching requests on existing pages", async () => { + const ctx = v3.context; + const page = (await ctx.awaitActivePage()) as unknown as InternalPage; + + await ctx.setDomainPolicy({ + blockedDomains: [BLOCKED_HOST], + }); + + await waitForBlockedRequest(page); + }); + + test("applies to pages created after setting the policy", async () => { + const ctx = v3.context; + + await ctx.setDomainPolicy({ + blockedDomains: [BLOCKED_HOST], + }); + + const page = (await ctx.newPage()) as unknown as InternalPage; + + await waitForBlockedRequest(page); + }); +}); diff --git a/packages/core/tests/unit/context-domain-policy.test.ts b/packages/core/tests/unit/context-domain-policy.test.ts new file mode 100644 index 000000000..78ff79010 --- /dev/null +++ b/packages/core/tests/unit/context-domain-policy.test.ts @@ -0,0 +1,341 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { V3Context } from "../../lib/v3/understudy/context.js"; +import { MockCDPSession } from "./helpers/mockCDPSession.js"; +import { StagehandSetDomainPolicyError } from "../../lib/v3/types/public/sdkErrors.js"; +import type { DomainPolicy } from "../../lib/v3/types/public/context.js"; +import { normalizeDomainPolicy } from "../../lib/v3/understudy/domainPolicy.js"; + +type ContextStub = { + _sessionInit: Set; + _domainPolicySessionListeners: Map; + conn: { + getSession: (id: string) => MockCDPSession | undefined; + }; + domainPolicy: unknown; +}; + +const makeContext = (sessions: MockCDPSession[]): ContextStub => { + const sessionsById = new Map( + sessions.map((session) => [session.id, session]), + ); + return Object.assign(Object.create(V3Context.prototype), { + _sessionInit: new Set(sessions.map((session) => session.id)), + _domainPolicySessionListeners: new Map(), + conn: { + getSession: (id: string) => sessionsById.get(id), + }, + domainPolicy: null, + }) as ContextStub; +}; + +const flushAsyncHandlers = async () => { + await new Promise((resolve) => setTimeout(resolve, 0)); +}; + +describe("V3Context.setDomainPolicy", () => { + afterEach(() => { + vi.useRealTimers(); + }); + + const setDomainPolicy = V3Context.prototype.setDomainPolicy as ( + this: ContextStub, + policy: DomainPolicy | null, + ) => Promise; + const getDomainPolicy = V3Context.prototype.getDomainPolicy as ( + this: ContextStub, + ) => DomainPolicy | null; + + it("sends Fetch.enable with generated patterns to all sessions", async () => { + const sessionA = new MockCDPSession({}, "session-a"); + const sessionB = new MockCDPSession({}, "session-b"); + const ctx = makeContext([sessionA, sessionB]); + + await setDomainPolicy.call(ctx, { + blockedDomains: ["ads.example.com"], + }); + + for (const session of [sessionA, sessionB]) { + expect(session.listenerCount("Fetch.requestPaused")).toBe(1); + expect(session.callsFor("Fetch.enable")[0]?.params).toEqual({ + patterns: [ + { urlPattern: "http://ads.example.com/*", requestStage: "Request" }, + { + urlPattern: "http://ads.example.com:*/*", + requestStage: "Request", + }, + { urlPattern: "https://ads.example.com/*", requestStage: "Request" }, + { + urlPattern: "https://ads.example.com:*/*", + requestStage: "Request", + }, + { urlPattern: "http://ads.example.com./*", requestStage: "Request" }, + { + urlPattern: "http://ads.example.com.:*/*", + requestStage: "Request", + }, + { + urlPattern: "https://ads.example.com./*", + requestStage: "Request", + }, + { + urlPattern: "https://ads.example.com.:*/*", + requestStage: "Request", + }, + ], + }); + } + + expect(getDomainPolicy.call(ctx)).toEqual({ + blockedDomains: ["ads.example.com"], + }); + }); + + it("sends Fetch.disable when policy is null or empty", async () => { + const sessionA = new MockCDPSession({}, "session-a"); + const sessionB = new MockCDPSession({}, "session-b"); + const ctx = makeContext([sessionA, sessionB]); + + await setDomainPolicy.call(ctx, { + blockedDomains: ["ads.example.com"], + }); + await setDomainPolicy.call(ctx, null); + await setDomainPolicy.call(ctx, { blockedDomains: [] }); + + for (const session of [sessionA, sessionB]) { + expect(session.callsFor("Fetch.disable").length).toBe(2); + expect(session.listenerCount("Fetch.requestPaused")).toBe(0); + } + + expect(getDomainPolicy.call(ctx)).toBeNull(); + }); + + it("keeps its requestPaused listener when Fetch.disable fails", async () => { + const session = new MockCDPSession( + { + "Fetch.disable": () => { + throw new Error("disable failed"); + }, + }, + "session-a", + ); + const ctx = makeContext([session]); + + await setDomainPolicy.call(ctx, { + blockedDomains: ["ads.example.com"], + }); + + await expect(setDomainPolicy.call(ctx, null)).rejects.toBeInstanceOf( + StagehandSetDomainPolicyError, + ); + + expect(session.listenerCount("Fetch.requestPaused")).toBe(1); + }); + + it("removes only its own requestPaused listener when disabled", async () => { + const session = new MockCDPSession({}, "session-a"); + const ctx = makeContext([session]); + const userHandler = () => {}; + session.on("Fetch.requestPaused", userHandler); + + await setDomainPolicy.call(ctx, { + blockedDomains: ["ads.example.com"], + }); + + expect(session.listenerCount("Fetch.requestPaused")).toBe(2); + + await setDomainPolicy.call(ctx, null); + + expect(session.listenerCount("Fetch.requestPaused")).toBe(1); + + session.emit("Fetch.requestPaused", { + requestId: "request-1", + request: { url: "https://ads.example.com/script.js" }, + }); + await flushAsyncHandlers(); + + expect(session.callsFor("Fetch.continueRequest").length).toBe(0); + expect(session.callsFor("Fetch.failRequest").length).toBe(0); + }); + + it("throws a custom error with session failure details", async () => { + const sessionA = new MockCDPSession( + { + "Fetch.enable": () => { + throw new Error("boom"); + }, + }, + "session-a", + ); + const sessionB = new MockCDPSession({}, "session-b"); + const ctx = makeContext([sessionA, sessionB]); + + const promise = setDomainPolicy.call(ctx, { + blockedDomains: ["ads.example.com"], + }); + + await expect(promise).rejects.toBeInstanceOf(StagehandSetDomainPolicyError); + + try { + await promise; + } catch (error) { + const err = error as StagehandSetDomainPolicyError; + expect(err.failures).toHaveLength(1); + expect(err.failures[0]).toContain("session=session-a"); + expect(err.failures[0]).toContain("boom"); + } + + expect(sessionA.callsFor("Fetch.enable").length).toBe(1); + expect(sessionB.callsFor("Fetch.enable").length).toBe(1); + expect(sessionA.listenerCount("Fetch.requestPaused")).toBe(0); + expect(sessionB.listenerCount("Fetch.requestPaused")).toBe(1); + }); + + it("fails blocked paused requests", async () => { + const session = new MockCDPSession({}, "session-a"); + const ctx = makeContext([session]); + + await setDomainPolicy.call(ctx, { + blockedDomains: ["ads.example.com"], + }); + + session.emit("Fetch.requestPaused", { + requestId: "request-1", + request: { url: "https://ads.example.com/script.js" }, + }); + await flushAsyncHandlers(); + + expect(session.callsFor("Fetch.failRequest")[0]?.params).toEqual({ + requestId: "request-1", + errorReason: "BlockedByClient", + }); + }); + + it("continues unexpected non-blocked paused requests", async () => { + const session = new MockCDPSession({}, "session-a"); + const ctx = makeContext([session]); + + await setDomainPolicy.call(ctx, { + blockedDomains: ["ads.example.com"], + }); + + session.emit("Fetch.requestPaused", { + requestId: "request-1", + request: { url: "https://example.com/" }, + }); + await flushAsyncHandlers(); + + expect(session.callsFor("Fetch.continueRequest")[0]?.params).toEqual({ + requestId: "request-1", + }); + }); + + it("closes new targets when Fetch.enable fails with an active policy", async () => { + const session = new MockCDPSession( + { + "Fetch.enable": () => { + throw new Error("fetch unavailable"); + }, + }, + "session-a", + ); + const closeTargetCalls: unknown[] = []; + const ctx = Object.assign(Object.create(V3Context.prototype), { + _sessionInit: new Set(), + _targetSessionListeners: new Set(), + _domainPolicySessionListeners: new Map(), + _piercerInstalled: new Set(), + domainPolicy: normalizeDomainPolicy({ + blockedDomains: ["ads.example.com"], + }), + conn: { + getSession: (id: string) => (id === session.id ? session : undefined), + waitForSessionDispatch: () => Promise.resolve(), + send: async (method: string, params?: unknown) => { + if (method === "Target.closeTarget") closeTargetCalls.push(params); + return {}; + }, + }, + pagesByTarget: new Map(), + mainFrameToTarget: new Map(), + sessionOwnerPage: new Map(), + frameOwnerPage: new Map(), + pendingOopifByMainFrame: new Map(), + createdAtByTarget: new Map(), + typeByTarget: new Map(), + pendingCreatedTargetUrl: new Map([["target-a", "about:blank"]]), + pageCreationFailures: new Map(), + initScripts: [], + extraHttpHeaders: null, + localBrowserLaunchOptions: null, + apiClient: null, + env: "LOCAL", + }); + + const onAttachedToTarget = V3Context.prototype[ + "onAttachedToTarget" as keyof V3Context + ] as unknown as ( + this: typeof ctx, + info: { + targetId: string; + type: string; + title: string; + url: string; + attached: boolean; + canAccessOpener: boolean; + }, + sessionId: string, + ) => Promise; + + await onAttachedToTarget.call( + ctx, + { + targetId: "target-a", + type: "page", + title: "", + url: "about:blank", + attached: true, + canAccessOpener: false, + }, + session.id, + ); + + expect(closeTargetCalls).toEqual([{ targetId: "target-a" }]); + expect(session.listenerCount("Fetch.requestPaused")).toBe(0); + expect(session.callsFor("Page.getFrameTree").length).toBe(0); + const failure = ctx.pageCreationFailures.get("target-a"); + expect(failure).toBeInstanceOf(StagehandSetDomainPolicyError); + const failureMessage = (failure as StagehandSetDomainPolicyError) + .failures[0]; + expect(failureMessage).toContain( + "Stagehand cannot guarantee domain policy enforcement", + ); + expect(failureMessage).toContain("cdpError=fetch unavailable"); + }); + + it("newPage throws stored attach failures without waiting for timeout", async () => { + const ctx = Object.assign(Object.create(V3Context.prototype), { + conn: { + send: vi.fn(async (method: string) => { + if (method === "Target.createTarget") { + return { targetId: "target-a" }; + } + return {}; + }), + }, + pendingCreatedTargetUrl: new Map([["target-a", "about:blank"]]), + pageCreationFailures: new Map([ + ["target-a", new StagehandSetDomainPolicyError(["session=session-a"])], + ]), + pagesByTarget: new Map(), + }); + const newPage = V3Context.prototype.newPage as ( + this: typeof ctx, + url?: string, + ) => Promise; + + await expect(newPage.call(ctx)).rejects.toBeInstanceOf( + StagehandSetDomainPolicyError, + ); + expect(ctx.pendingCreatedTargetUrl.has("target-a")).toBe(false); + }); +}); diff --git a/packages/core/tests/unit/domain-policy.test.ts b/packages/core/tests/unit/domain-policy.test.ts new file mode 100644 index 000000000..9f6b623d5 --- /dev/null +++ b/packages/core/tests/unit/domain-policy.test.ts @@ -0,0 +1,135 @@ +import { describe, expect, it } from "vitest"; +import { + normalizeDomainPolicy, + shouldBlockUrl, +} from "../../lib/v3/understudy/domainPolicy.js"; +import { StagehandInvalidArgumentError } from "../../lib/v3/types/public/sdkErrors.js"; + +describe("domain policy helpers", () => { + it("generates HTTP and HTTPS Fetch patterns for exact blocked domains", () => { + const policy = normalizeDomainPolicy({ + blockedDomains: ["ads.example.com"], + }); + + expect(policy?.fetchPatterns).toEqual([ + { urlPattern: "http://ads.example.com/*", requestStage: "Request" }, + { urlPattern: "http://ads.example.com:*/*", requestStage: "Request" }, + { urlPattern: "https://ads.example.com/*", requestStage: "Request" }, + { urlPattern: "https://ads.example.com:*/*", requestStage: "Request" }, + { urlPattern: "http://ads.example.com./*", requestStage: "Request" }, + { urlPattern: "http://ads.example.com.:*/*", requestStage: "Request" }, + { urlPattern: "https://ads.example.com./*", requestStage: "Request" }, + { urlPattern: "https://ads.example.com.:*/*", requestStage: "Request" }, + ]); + }); + + it("generates HTTP and HTTPS Fetch patterns for wildcard blocked domains", () => { + const policy = normalizeDomainPolicy({ + blockedDomains: ["*.tracking.example.com"], + }); + + expect(policy?.fetchPatterns).toEqual([ + { + urlPattern: "http://*.tracking.example.com/*", + requestStage: "Request", + }, + { + urlPattern: "http://*.tracking.example.com:*/*", + requestStage: "Request", + }, + { + urlPattern: "https://*.tracking.example.com/*", + requestStage: "Request", + }, + { + urlPattern: "https://*.tracking.example.com:*/*", + requestStage: "Request", + }, + { + urlPattern: "http://*.tracking.example.com./*", + requestStage: "Request", + }, + { + urlPattern: "http://*.tracking.example.com.:*/*", + requestStage: "Request", + }, + { + urlPattern: "https://*.tracking.example.com./*", + requestStage: "Request", + }, + { + urlPattern: "https://*.tracking.example.com.:*/*", + requestStage: "Request", + }, + ]); + }); + + it("matches exact and wildcard domains without matching unrelated suffixes", () => { + const policy = normalizeDomainPolicy({ + blockedDomains: ["ads.example.com", "*.tracking.example.com"], + }); + + expect(shouldBlockUrl("https://ads.example.com/script.js", policy)).toBe( + true, + ); + expect( + shouldBlockUrl("https://a.tracking.example.com/pixel.gif", policy), + ).toBe(true); + expect( + shouldBlockUrl("https://deep.a.tracking.example.com/pixel.gif", policy), + ).toBe(true); + expect(shouldBlockUrl("https://ads.example.com./script.js", policy)).toBe( + true, + ); + expect( + shouldBlockUrl("https://a.tracking.example.com./pixel.gif", policy), + ).toBe(true); + expect(shouldBlockUrl("https://tracking.example.com/", policy)).toBe(false); + expect(shouldBlockUrl("https://badtracking.example.com/", policy)).toBe( + false, + ); + expect(shouldBlockUrl("https://ads.example.com.evil.test/", policy)).toBe( + false, + ); + }); + + it("matches domains case-insensitively", () => { + const policy = normalizeDomainPolicy({ + blockedDomains: ["ADS.EXAMPLE.COM"], + }); + + expect(shouldBlockUrl("https://ads.example.com/script.js", policy)).toBe( + true, + ); + expect(shouldBlockUrl("https://ADS.EXAMPLE.COM/script.js", policy)).toBe( + true, + ); + }); + + it("continues malformed and non-HTTP URLs", () => { + const policy = normalizeDomainPolicy({ + blockedDomains: ["ads.example.com"], + }); + + expect(shouldBlockUrl("not a url", policy)).toBe(false); + expect(shouldBlockUrl("data:text/plain,hello", policy)).toBe(false); + expect(shouldBlockUrl("file:///tmp/example.html", policy)).toBe(false); + }); + + it("rejects invalid blocked domain patterns", () => { + expect(() => + normalizeDomainPolicy({ blockedDomains: ["https://example.com"] }), + ).toThrow(StagehandInvalidArgumentError); + expect(() => + normalizeDomainPolicy({ blockedDomains: ["*example.com"] }), + ).toThrow(StagehandInvalidArgumentError); + expect(() => + normalizeDomainPolicy({ blockedDomains: ["example"] }), + ).toThrow(StagehandInvalidArgumentError); + expect(() => + normalizeDomainPolicy({ + blockedDomains: [123] as unknown as string[], + }), + ).toThrow(StagehandInvalidArgumentError); + }); +}); diff --git a/packages/core/tests/unit/public-api/public-error-types.test.ts b/packages/core/tests/unit/public-api/public-error-types.test.ts index 9caa4ff8d..e0d7414e3 100644 --- a/packages/core/tests/unit/public-api/public-error-types.test.ts +++ b/packages/core/tests/unit/public-api/public-error-types.test.ts @@ -68,6 +68,7 @@ export const publicErrorTypes = { UnderstudyCommandException: Stagehand.UnderstudyCommandException, StagehandSetExtraHTTPHeadersError: Stagehand.StagehandSetExtraHTTPHeadersError, + StagehandSetDomainPolicyError: Stagehand.StagehandSetDomainPolicyError, } as const; const errorTypes = Object.keys(publicErrorTypes) as Array<