Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/sweet-cameras-act.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": minor
---

add `context.setDomainPolicy({blockedDomains: ["some.domain"]})` which allows users to define a list of domains that will be blocked by stagehand
6 changes: 6 additions & 0 deletions packages/core/lib/v3/types/public/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ export interface ClearCookieOptions {
domain?: string | RegExp;
path?: string | RegExp;
}

/** Context-wide network domain policy. */
export interface DomainPolicy {
Comment thread
seanmcguire12 marked this conversation as resolved.
/** Domain-only block patterns, e.g. "example.com" or "*.example.com". */
blockedDomains?: string[];
}
11 changes: 11 additions & 0 deletions packages/core/lib/v3/types/public/sdkErrors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,17 @@ export class StagehandSetExtraHTTPHeadersError extends StagehandError {
}
}

export class StagehandSetDomainPolicyError extends StagehandError {
public readonly failures: string[];

constructor(failures: string[]) {
super(
`setDomainPolicy failed for ${failures.length} session(s): ${failures.join(", ")}`,
);
this.failures = failures;
}
}

export class StagehandSnapshotError extends StagehandError {
constructor(cause?: unknown) {
const suffix =
Expand Down
237 changes: 237 additions & 0 deletions packages/core/lib/v3/understudy/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
CookieSetError,
PageNotFoundError,
StagehandSetExtraHTTPHeadersError,
StagehandSetDomainPolicyError,
} from "../types/public/sdkErrors.js";
import { getEnvTimeoutMs, withTimeout } from "../timeoutConfig.js";
import {
Expand All @@ -29,7 +30,10 @@ import {
Cookie,
ClearCookieOptions,
CookieParam,
DomainPolicy,
} from "../types/public/context.js";
import { normalizeDomainPolicy, shouldBlockUrl } from "./domainPolicy.js";
import type { NormalizedDomainPolicy } from "./domainPolicy.js";

type TargetId = string;
type SessionId = string;
Expand Down Expand Up @@ -113,6 +117,10 @@ export class V3Context {
// Timestamp for most recent popup/open signal
private _lastPopupSignalAt = 0;
private readonly _targetSessionListeners = new Set<SessionId>();
private readonly _domainPolicySessionListeners = new Map<
SessionId,
(evt: Protocol.Fetch.RequestPausedEvent) => void
>();

private readonly _sessionInit = new Set<SessionId>();
private pagesByTarget = new Map<TargetId, Page>();
Expand All @@ -124,8 +132,10 @@ export class V3Context {
private typeByTarget = new Map<TargetId, TargetType>();
private _pageOrder: TargetId[] = [];
private pendingCreatedTargetUrl = new Map<TargetId, string>();
private pageCreationFailures = new Map<TargetId, Error>();
private readonly initScripts: string[] = [];
private extraHttpHeaders: Record<string, string> | null = null;
private domainPolicy: NormalizedDomainPolicy | null = null;
private _clipboard?: ContextClipboard;

private installTargetSessionListeners(session: CDPSessionLike): void {
Expand Down Expand Up @@ -425,6 +435,143 @@ export class V3Context {
}
}

public getDomainPolicy(): DomainPolicy | null {
if (!this.domainPolicy) return null;
return { blockedDomains: [...this.domainPolicy.blockedDomains] };
}

public async setDomainPolicy(policy: DomainPolicy | null): Promise<void> {
Comment thread
seanmcguire12 marked this conversation as resolved.
const nextPolicy = normalizeDomainPolicy(policy);
this.domainPolicy = nextPolicy;

const sessions: CDPSessionLike[] = [];
for (const sessionId of this._sessionInit) {
const session = this.conn.getSession(sessionId);
if (session) sessions.push(session);
}

if (!sessions.length) return;

const results = await Promise.allSettled(
sessions.map(async (session) => {
if (!nextPolicy) {
try {
await session.send("Fetch.disable");
this.uninstallDomainPolicyHandler(session);
} catch (error) {
throw { action: "disable", error };
}
return;
}

this.installDomainPolicyHandler(session);
try {
await session.send("Fetch.enable", {
patterns: nextPolicy.fetchPatterns,
});
} catch (error) {
throw { action: "enable", error };
}
}),
);

const failures = results
.map((result, index) => ({ result, session: sessions[index] }))
.filter(
(
entry,
): entry is {
result: PromiseRejectedResult;
session: CDPSessionLike;
} => entry.result.status === "rejected",
)
.map((entry) => {
const failure = entry.result.reason as {
action?: "enable" | "disable";
error?: unknown;
};
if (failure?.action === "enable") {
this.uninstallDomainPolicyHandler(entry.session);
}
const reason = failure?.error ?? entry.result.reason;
const sid = entry.session.id ?? "unknown";
const message =
reason instanceof Error ? reason.message : String(reason);
return `session=${sid} error=${message}`;
});

if (failures.length) {
throw new StagehandSetDomainPolicyError(failures);
}
}

private installDomainPolicyHandler(session: CDPSessionLike): void {
const sessionId = session.id;
if (!sessionId) return;
if (this._domainPolicySessionListeners.has(sessionId)) return;

const handler = (evt: Protocol.Fetch.RequestPausedEvent) => {
void this.handleDomainPolicyRequestPaused(session, evt);
};
this._domainPolicySessionListeners.set(sessionId, handler);
session.on<Protocol.Fetch.RequestPausedEvent>(
"Fetch.requestPaused",
handler,
);
}

private uninstallDomainPolicyHandler(session: CDPSessionLike): void {
const sessionId = session.id;
if (!sessionId) return;

const handler = this._domainPolicySessionListeners.get(sessionId);
if (!handler) return;

session.off<Protocol.Fetch.RequestPausedEvent>(
"Fetch.requestPaused",
handler,
);
this._domainPolicySessionListeners.delete(sessionId);
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
}

private async handleDomainPolicyRequestPaused(
session: CDPSessionLike,
evt: Protocol.Fetch.RequestPausedEvent,
): Promise<void> {
const policy = this.domainPolicy;

if (!policy || !shouldBlockUrl(evt.request.url, policy)) {
await session
.send("Fetch.continueRequest", { requestId: evt.requestId })
.catch(() => {});
return;
}

let hostname = "";
try {
hostname = new URL(evt.request.url).hostname.toLowerCase();
} catch {
// ignore malformed URLs for logging
}

v3Logger({
category: "network",
message: "Blocked request by domain policy",
level: 2,
Comment thread
seanmcguire12 marked this conversation as resolved.
auxiliary: {
hostname: { value: hostname, type: "string" },
ruleType: { value: "blockedDomains", type: "string" },
},
});

await session
.send("Fetch.failRequest", {
requestId: evt.requestId,
errorReason: "BlockedByClient",
})
.catch(() => {});
}

public get clipboard(): BrowserClipboard {
return (this._clipboard ??= new ContextClipboard({
context: this,
Expand Down Expand Up @@ -503,6 +650,13 @@ export class V3Context {

const deadline = Date.now() + 5000;
while (Date.now() < deadline) {
const failure = this.pageCreationFailures.get(targetId);
if (failure) {
this.pageCreationFailures.delete(targetId);
this.pendingCreatedTargetUrl.delete(targetId);
throw failure;
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
}

const page = this.pagesByTarget.get(targetId);
if (page) {
// we created at about:blank; navigate only after attach so init scripts run
Expand All @@ -518,6 +672,7 @@ export class V3Context {
}
await new Promise((r) => setTimeout(r, 25));
}
this.pendingCreatedTargetUrl.delete(targetId);
throw new TimeoutError(`newPage: target not attached (${targetId})`, 5000);
}

Expand All @@ -534,6 +689,7 @@ export class V3Context {
this.createdAtByTarget.clear();
this.typeByTarget.clear();
this.pendingCreatedTargetUrl.clear();
this.pageCreationFailures.clear();
}

/**
Expand Down Expand Up @@ -668,6 +824,24 @@ export class V3Context {
.catch(() => false);
return { dispatched, response };
};
const queueFetchEnablePreResume = (params: object) => {
let error: unknown;
const dispatched = this.conn
.waitForSessionDispatch(sessionId, "Fetch.enable")
.then(() => true)
.catch((err) => {
error = err;
return false;
});
const response = session
.send("Fetch.enable", params)
.then(() => true)
.catch((err) => {
error = err;
return false;
});
return { dispatched, response, getError: () => error };
};
const initScriptOps: Array<{
dispatched: Promise<boolean>;
response: Promise<boolean>;
Expand Down Expand Up @@ -697,6 +871,19 @@ export class V3Context {
queuePreResume("Network.setExtraHTTPHeaders", { headers }),
);
}
const fetchPreResumeOps: Array<{
dispatched: Promise<boolean>;
response: Promise<boolean>;
getError: () => unknown;
}> = [];
if (this.domainPolicy) {
this.installDomainPolicyHandler(session);
fetchPreResumeOps.push(
queueFetchEnablePreResume({
patterns: this.domainPolicy.fetchPatterns,
}),
);
}
// Send init scripts only after auto-attach has been queued.
if (this.initScripts.length) {
for (const source of this.initScripts) {
Expand Down Expand Up @@ -728,6 +915,7 @@ export class V3Context {
await Promise.all([
...corePreResumeOps.map((op) => op.dispatched),
...headerPreResumeOps.map((op) => op.dispatched),
...fetchPreResumeOps.map((op) => op.dispatched),
...initScriptOps.map((op) => op.dispatched),
piercerPreloadOp.dispatched,
])
Expand All @@ -741,11 +929,13 @@ export class V3Context {
const [
coreResults,
headerResults,
fetchResults,
initScriptResults,
piercerPreRegistered,
] = await Promise.all([
Promise.all(corePreResumeOps.map((op) => op.response)),
Promise.all(headerPreResumeOps.map((op) => op.response)),
Promise.all(fetchPreResumeOps.map((op) => op.response)),
Promise.all(initScriptOps.map((op) => op.response)),
piercerPreloadOp.response,
]);
Expand Down Expand Up @@ -778,6 +968,46 @@ export class V3Context {
return;
}
resumed = true;

if (fetchPreResumeOps.length > 0 && !fetchResults.every(Boolean)) {
this.uninstallDomainPolicyHandler(session);
const fetchError = fetchPreResumeOps
.map((op) => op.getError())
.find((error) => error !== undefined);
const fetchErrorMessage = fetchError
? fetchError instanceof Error
? fetchError.message
: String(fetchError)
: "Fetch.enable failed during target attach";
const policyFailureMessage =
"Fetch.enable failed during target attach; closing target because " +
"Stagehand cannot guarantee domain policy enforcement";
if (this.pendingCreatedTargetUrl.has(info.targetId)) {
this.pageCreationFailures.set(
info.targetId,
new StagehandSetDomainPolicyError([
`session=${sessionId} target=${info.targetId} error=${policyFailureMessage} cdpError=${fetchErrorMessage}`,
]),
);
}
v3Logger({
category: "ctx",
message: "Closing target because domain policy could not be guaranteed",
level: 0,
auxiliary: {
targetId: { value: String(info.targetId), type: "string" },
targetType: { value: String(info.type), type: "string" },
targetUrl: { value: String(info.url ?? ""), type: "string" },
sessionId: { value: sessionId, type: "string" },
cdpError: { value: fetchErrorMessage, type: "string" },
},
});
await this.conn
.send("Target.closeTarget", { targetId: info.targetId })
.catch(() => {});
return;
}

const scriptsInstalled =
coreResults.every(Boolean) && initScriptResults.every(Boolean);

Expand Down Expand Up @@ -933,6 +1163,12 @@ export class V3Context {
}

this._targetSessionListeners.delete(sessionId);
const session = this.conn.getSession(sessionId);
if (session) {
this.uninstallDomainPolicyHandler(session);
} else {
this._domainPolicySessionListeners.delete(sessionId);
}
this._sessionInit.delete(sessionId);
this._piercerInstalled.delete(sessionId);
}
Expand All @@ -944,6 +1180,7 @@ export class V3Context {
const page = this.pagesByTarget.get(targetId);
if (!page) return;

this.pageCreationFailures.delete(targetId);
const mainId = page.mainFrameId();
this.mainFrameToTarget.delete(mainId);
this.frameOwnerPage.delete(mainId);
Expand Down
Loading
Loading