Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/sweet-cameras-act.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": minor
---

Add `context.setDomainPolicy()`, which lets users define a list of domains that Stagehand will block.
6 changes: 6 additions & 0 deletions packages/core/lib/v3/types/public/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ export interface ClearCookieOptions {
domain?: string | RegExp;
path?: string | RegExp;
}

/**
 * Context-wide network domain policy.
 *
 * Every field is optional, so an empty object is a valid policy value.
 */
export interface DomainPolicy {
  /**
   * Domain-only block patterns, e.g. "example.com" or "*.example.com".
   */
  blockedDomains?: string[];
}
11 changes: 11 additions & 0 deletions packages/core/lib/v3/types/public/sdkErrors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,17 @@ export class StagehandSetExtraHTTPHeadersError extends StagehandError {
}
}

/**
 * Error raised when `setDomainPolicy` fails for one or more sessions.
 *
 * The message reports how many sessions failed followed by the individual
 * failure descriptions joined with ", ".
 */
export class StagehandSetDomainPolicyError extends StagehandError {
  /** One description per failed session, snapshotted at construction time. */
  public readonly failures: string[];

  /**
   * @param failures - Human-readable descriptions of each failed session.
   */
  constructor(failures: string[]) {
    super(
      `setDomainPolicy failed for ${failures.length} session(s): ${failures.join(", ")}`,
    );
    // Copy rather than alias: if the caller later mutates its array, the
    // stored list must still agree with the message built above.
    this.failures = [...failures];
  }
}

export class StagehandSnapshotError extends StagehandError {
constructor(cause?: unknown) {
const suffix =
Expand Down
186 changes: 186 additions & 0 deletions packages/core/lib/v3/understudy/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
CookieSetError,
PageNotFoundError,
StagehandSetExtraHTTPHeadersError,
StagehandSetDomainPolicyError,
} from "../types/public/sdkErrors.js";
import { getEnvTimeoutMs, withTimeout } from "../timeoutConfig.js";
import {
Expand All @@ -29,7 +30,10 @@ import {
Cookie,
ClearCookieOptions,
CookieParam,
DomainPolicy,
} from "../types/public/context.js";
import { normalizeDomainPolicy, shouldBlockUrl } from "./domainPolicy.js";
import type { NormalizedDomainPolicy } from "./domainPolicy.js";

type TargetId = string;
type SessionId = string;
Expand Down Expand Up @@ -113,6 +117,10 @@ export class V3Context {
// Timestamp for most recent popup/open signal
private _lastPopupSignalAt = 0;
private readonly _targetSessionListeners = new Set<SessionId>();
private readonly _domainPolicySessionListeners = new Map<
SessionId,
(evt: Protocol.Fetch.RequestPausedEvent) => void
>();

private readonly _sessionInit = new Set<SessionId>();
private pagesByTarget = new Map<TargetId, Page>();
Expand All @@ -124,8 +132,10 @@ export class V3Context {
private typeByTarget = new Map<TargetId, TargetType>();
private _pageOrder: TargetId[] = [];
private pendingCreatedTargetUrl = new Map<TargetId, string>();
private pageCreationFailures = new Map<TargetId, Error>();
private readonly initScripts: string[] = [];
private extraHttpHeaders: Record<string, string> | null = null;
private domainPolicy: NormalizedDomainPolicy | null = null;
private _clipboard?: ContextClipboard;

private installTargetSessionListeners(session: CDPSessionLike): void {
Expand Down Expand Up @@ -425,6 +435,127 @@ export class V3Context {
}
}

/**
 * Return the domain policy currently stored on this context.
 *
 * @returns `null` when no policy is set; otherwise a new object whose
 * `blockedDomains` is a fresh copy of the stored list, so callers cannot
 * mutate the context's internal state through it.
 */
public getDomainPolicy(): DomainPolicy | null {
  const active = this.domainPolicy;
  return active ? { blockedDomains: [...active.blockedDomains] } : null;
}

/**
 * Store a new context-wide domain policy and apply it to every initialized
 * session that the connection can still resolve.
 *
 * With a policy, each session gets the `Fetch.requestPaused` handler installed
 * and then `Fetch.enable` with the policy's fetch patterns. With `null`, each
 * session gets `Fetch.disable` and then has its handler removed. The stored
 * policy is updated before any session is touched and is not rolled back when
 * a session fails.
 *
 * @param policy - The policy to apply, or `null` to turn interception off.
 * @throws StagehandSetDomainPolicyError when the command rejects for at least
 * one session; every session is attempted before the error is thrown.
 */
public async setDomainPolicy(policy: DomainPolicy | null): Promise<void> {
  const normalized = normalizeDomainPolicy(policy);
  this.domainPolicy = normalized;

  const liveSessions: CDPSessionLike[] = [];
  for (const id of this._sessionInit) {
    const candidate = this.conn.getSession(id);
    if (candidate) liveSessions.push(candidate);
  }
  if (liveSessions.length === 0) return;

  // Applies (or removes) interception on a single session.
  const applyTo = async (session: CDPSessionLike): Promise<void> => {
    if (normalized) {
      // The listener is registered before Fetch.enable is sent.
      this.installDomainPolicyHandler(session);
      await session.send("Fetch.enable", {
        patterns: normalized.fetchPatterns,
      });
      return;
    }
    // Disable first; the listener is only dropped once that has succeeded.
    await session.send("Fetch.disable");
    this.uninstallDomainPolicyHandler(session);
  };

  const outcomes = await Promise.allSettled(
    liveSessions.map((session) => applyTo(session)),
  );

  // outcomes[i] corresponds to liveSessions[i]; collect rejections in order.
  const failures: string[] = [];
  liveSessions.forEach((session, i) => {
    const outcome = outcomes[i];
    if (outcome?.status !== "rejected") return;
    const reason = outcome.reason as Error;
    const sid = session.id ?? "unknown";
    const message = reason?.message ?? String(reason);
    failures.push(`session=${sid} error=${message}`);
  });

  if (failures.length > 0) {
    throw new StagehandSetDomainPolicyError(failures);
  }
}

/**
 * Register this context's `Fetch.requestPaused` listener on a session.
 *
 * Does nothing when the session has no id or already has a listener recorded
 * in `_domainPolicySessionListeners`, so repeated calls are safe.
 */
private installDomainPolicyHandler(session: CDPSessionLike): void {
  const id = session.id;
  if (!id || this._domainPolicySessionListeners.has(id)) return;

  // Fire-and-forget: the async handler's promise is deliberately not awaited.
  const onRequestPaused = (evt: Protocol.Fetch.RequestPausedEvent): void => {
    void this.handleDomainPolicyRequestPaused(session, evt);
  };

  // Remember the exact function reference so it can be passed to off() later.
  this._domainPolicySessionListeners.set(id, onRequestPaused);
  session.on<Protocol.Fetch.RequestPausedEvent>(
    "Fetch.requestPaused",
    onRequestPaused,
  );
}

/**
 * Remove the `Fetch.requestPaused` listener previously registered for a session.
 *
 * Does nothing when the session has no id or no listener is recorded for it.
 */
private uninstallDomainPolicyHandler(session: CDPSessionLike): void {
  const id = session.id;
  const registered = id
    ? this._domainPolicySessionListeners.get(id)
    : undefined;
  if (!id || !registered) return;

  // Detach from the session first, then forget the bookkeeping entry.
  session.off<Protocol.Fetch.RequestPausedEvent>(
    "Fetch.requestPaused",
    registered,
  );
  this._domainPolicySessionListeners.delete(id);
}

/**
 * Resolve one paused Fetch request according to the current domain policy.
 *
 * The request is continued when no policy is set or `shouldBlockUrl` does not
 * flag its URL. Otherwise the block is logged and the request is failed with
 * `BlockedByClient`. Rejections from either CDP command are swallowed.
 *
 * @param session - Session that emitted the event and receives the reply.
 * @param evt - The `Fetch.requestPaused` payload.
 */
private async handleDomainPolicyRequestPaused(
  session: CDPSessionLike,
  evt: Protocol.Fetch.RequestPausedEvent,
): Promise<void> {
  const activePolicy = this.domainPolicy;
  // The URL is only inspected when a policy is actually in force.
  const mustBlock = activePolicy
    ? shouldBlockUrl(evt.request.url, activePolicy)
    : false;

  if (!mustBlock) {
    await session
      .send("Fetch.continueRequest", { requestId: evt.requestId })
      .catch(() => {});
    return;
  }

  // Hostname is for the log entry only; a malformed URL is logged as "".
  const hostname = ((): string => {
    try {
      return new URL(evt.request.url).hostname.toLowerCase();
    } catch {
      return "";
    }
  })();

  v3Logger({
    category: "network",
    message: "Blocked request by domain policy",
    level: 2,
    auxiliary: {
      hostname: { value: hostname, type: "string" },
      ruleType: { value: "blockedDomains", type: "string" },
    },
  });

  await session
    .send("Fetch.failRequest", {
      requestId: evt.requestId,
      errorReason: "BlockedByClient",
    })
    .catch(() => {});
}

public get clipboard(): BrowserClipboard {
return (this._clipboard ??= new ContextClipboard({
context: this,
Expand Down Expand Up @@ -503,6 +634,12 @@ export class V3Context {

const deadline = Date.now() + 5000;
while (Date.now() < deadline) {
const failure = this.pageCreationFailures.get(targetId);
if (failure) {
this.pageCreationFailures.delete(targetId);
throw failure;
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
}

const page = this.pagesByTarget.get(targetId);
if (page) {
// we created at about:blank; navigate only after attach so init scripts run
Expand Down Expand Up @@ -534,6 +671,7 @@ export class V3Context {
this.createdAtByTarget.clear();
this.typeByTarget.clear();
this.pendingCreatedTargetUrl.clear();
this.pageCreationFailures.clear();
}

/**
Expand Down Expand Up @@ -697,6 +835,18 @@ export class V3Context {
queuePreResume("Network.setExtraHTTPHeaders", { headers }),
);
}
const fetchPreResumeOps: Array<{
dispatched: Promise<boolean>;
response: Promise<boolean>;
}> = [];
if (this.domainPolicy) {
this.installDomainPolicyHandler(session);
fetchPreResumeOps.push(
queuePreResume("Fetch.enable", {
patterns: this.domainPolicy.fetchPatterns,
}),
);
}
// Send init scripts only after auto-attach has been queued.
if (this.initScripts.length) {
for (const source of this.initScripts) {
Expand Down Expand Up @@ -728,6 +878,7 @@ export class V3Context {
await Promise.all([
...corePreResumeOps.map((op) => op.dispatched),
...headerPreResumeOps.map((op) => op.dispatched),
...fetchPreResumeOps.map((op) => op.dispatched),
...initScriptOps.map((op) => op.dispatched),
piercerPreloadOp.dispatched,
])
Expand All @@ -741,11 +892,13 @@ export class V3Context {
const [
coreResults,
headerResults,
fetchResults,
initScriptResults,
piercerPreRegistered,
] = await Promise.all([
Promise.all(corePreResumeOps.map((op) => op.response)),
Promise.all(headerPreResumeOps.map((op) => op.response)),
Promise.all(fetchPreResumeOps.map((op) => op.response)),
Promise.all(initScriptOps.map((op) => op.response)),
piercerPreloadOp.response,
]);
Expand Down Expand Up @@ -778,6 +931,32 @@ export class V3Context {
return;
}
resumed = true;

if (fetchPreResumeOps.length > 0 && !fetchResults.every(Boolean)) {
this.uninstallDomainPolicyHandler(session);
if (this.pendingCreatedTargetUrl.has(info.targetId)) {
this.pageCreationFailures.set(
info.targetId,
new StagehandSetDomainPolicyError([
`session=${sessionId} target=${info.targetId} error=Fetch.enable failed during target attach`,
]),
);
}
v3Logger({
category: "ctx",
message: "Failed to enable domain policy for target",
level: 1,
Comment thread
seanmcguire12 marked this conversation as resolved.
Outdated
Comment thread
seanmcguire12 marked this conversation as resolved.
Outdated
auxiliary: {
targetId: { value: String(info.targetId), type: "string" },
targetType: { value: String(info.type), type: "string" },
},
});
await this.conn
.send("Target.closeTarget", { targetId: info.targetId })
.catch(() => {});
return;
}

const scriptsInstalled =
coreResults.every(Boolean) && initScriptResults.every(Boolean);

Expand Down Expand Up @@ -933,6 +1112,12 @@ export class V3Context {
}

this._targetSessionListeners.delete(sessionId);
const session = this.conn.getSession(sessionId);
if (session) {
this.uninstallDomainPolicyHandler(session);
} else {
this._domainPolicySessionListeners.delete(sessionId);
}
this._sessionInit.delete(sessionId);
this._piercerInstalled.delete(sessionId);
}
Expand All @@ -944,6 +1129,7 @@ export class V3Context {
const page = this.pagesByTarget.get(targetId);
if (!page) return;

this.pageCreationFailures.delete(targetId);
const mainId = page.mainFrameId();
this.mainFrameToTarget.delete(mainId);
this.frameOwnerPage.delete(mainId);
Expand Down
Loading
Loading