Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/5737.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Forward the client's `Date` header through `Webserver` proxy by accepting per-call header overrides on `Request.fetch()`, fixing signature-based authentication that binds the signature to the `Date` header.
52 changes: 49 additions & 3 deletions src/ai/backend/client/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import appdirs
import attrs
from aiohttp.client import _RequestContextManager, _WSRequestContextManager
from dateutil.parser import parse as parse_datetime
from dateutil.tz import tzutc
from multidict import CIMultiDict
from yarl import URL
Expand Down Expand Up @@ -277,6 +278,20 @@ def _pack_content(self) -> RequestContent | aiohttp.FormData:
return data
return self._content

def _apply_header_overrides(self, overrides: Mapping[str, str]) -> None:
"""
Apply caller-supplied header overrides on top of the auto-set headers.

If ``Date`` is overridden, parse it back into ``self.date`` so that
request signing uses the same value the upstream server will see.
Required for use cases like proxying a pre-signed request where the
original ``Date`` must be preserved end-to-end.
"""
for key, value in overrides.items():
self.headers[key] = value
if key.lower() == "date":
self.date = parse_datetime(value)
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overriding the on-the-wire Date header can break HMAC signing when the override is not already in ISO-8601 format: generate_signature() signs with date.isoformat(), while the Manager verifies signatures using the raw Date header string (see src/ai/backend/manager/api/rest/middleware/auth.py:396-405, which uses request["raw_date"]). With the current override behavior, a value like RFC1123 (Tue, 02 Sep ... GMT) will produce a different signed string than what the server reconstructs. Consider adjusting the signing path to use the exact header value that will be sent (e.g., thread the raw Date string into signing), or normalize the overridden header to the canonical format expected by the signature scheme.

Suggested change
If ``Date`` is overridden, parse it back into ``self.date`` so that
request signing uses the same value the upstream server will see.
Required for use cases like proxying a pre-signed request where the
original ``Date`` must be preserved end-to-end.
"""
for key, value in overrides.items():
self.headers[key] = value
if key.lower() == "date":
self.date = parse_datetime(value)
If ``Date`` is overridden, normalize it to the canonical ISO-8601
representation used by request signing so the on-the-wire header value
exactly matches what is included in the HMAC signature.
"""
for key, value in overrides.items():
if key.lower() == "date":
self.date = parse_datetime(value)
self.headers[key] = self.date.isoformat()
else:
self.headers[key] = value

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parse_datetime() can return a naive datetime when the overridden Date value has no timezone info. Elsewhere in the codebase the server normalizes missing tzinfo to UTC before using the date (e.g., manager/api/rest/middleware/auth.py:356-359). To keep client-side signing and date handling consistent, consider normalizing self.date to tzutc() when parse_datetime(value).tzinfo is None, and (optionally) failing fast with a clearer error if the Date override is unparsable.

Suggested change
self.date = parse_datetime(value)
try:
parsed_date = parse_datetime(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Invalid Date header override: {value!r}") from exc
if parsed_date.tzinfo is None:
parsed_date = parsed_date.replace(tzinfo=tzutc())
self.date = parsed_date

Copilot uses AI. Check for mistakes.

def _build_url(self) -> URL:
base_url = self.config.endpoint.path.rstrip("/")
query_path = self.path.lstrip("/") if self.path is not None and len(self.path) > 0 else ""
Expand All @@ -291,7 +306,12 @@ def _build_url(self) -> URL:

# TODO: attach rate-limit information

def fetch(self, **kwargs: Any) -> FetchContextManager:
def fetch(
self,
*,
headers: Mapping[str, str] | None = None,
**kwargs: Any,
) -> FetchContextManager:
"""
Sends the request to the server and reads the response.

Expand All @@ -307,6 +327,11 @@ def fetch(self, **kwargs: Any) -> FetchContextManager:
rqst = Request('GET', ...)
async with rqst.fetch() as resp:
print(await resp.text())

:param headers: Header overrides applied after the auto-populated
headers. Useful for proxy use cases that need to forward the
client's original ``Date`` header (and matching signature)
instead of having it refreshed.
"""
if self.method not in self._allowed_methods:
raise ValueError(f"Disallowed HTTP method: {self.method}")
Expand All @@ -316,6 +341,8 @@ def fetch(self, **kwargs: Any) -> FetchContextManager:
self.headers["Date"] = self.date.isoformat()
if self.content_type is not None and "Content-Type" not in self.headers:
self.headers["Content-Type"] = self.content_type
if headers:
self._apply_header_overrides(headers)
force_anonymous = kwargs.pop("anonymous", False)

def _rqst_ctx_builder() -> _RequestContextManager:
Expand Down Expand Up @@ -345,7 +372,11 @@ def _rqst_ctx_builder() -> _RequestContextManager:
return FetchContextManager(self.session, _rqst_ctx_builder, self._session_mode, **kwargs)

def connect_websocket(
self, protocols: Iterable[str] = tuple(), **kwargs: Any
self,
protocols: Iterable[str] = tuple(),
*,
headers: Mapping[str, str] | None = None,
**kwargs: Any,
) -> WebSocketContextManager:
"""
Creates a WebSocket connection.
Expand All @@ -354,6 +385,9 @@ def connect_websocket(

This method only works with
:class:`~ai.backend.client.session.AsyncSession`.

:param headers: Header overrides applied after the auto-populated
headers. See :meth:`fetch` for details.
"""
if not isinstance(self.session, AsyncSession):
raise RuntimeError("Cannot use websockets with sessions in the synchronous mode")
Expand All @@ -365,6 +399,8 @@ def connect_websocket(
self.headers["Date"] = self.date.isoformat()
# websocket is always a "binary" stream.
self.content_type = "application/octet-stream"
if headers:
self._apply_header_overrides(headers)

def _ws_ctx_builder() -> _WSRequestContextManager:
full_url = self._build_url()
Expand All @@ -385,14 +421,22 @@ def _ws_ctx_builder() -> _WSRequestContextManager:

return WebSocketContextManager(self.session, _ws_ctx_builder, **kwargs)

def connect_events(self, **kwargs: Any) -> SSEContextManager:
def connect_events(
self,
*,
headers: Mapping[str, str] | None = None,
**kwargs: Any,
) -> SSEContextManager:
"""
Creates a Server-Sent Events connection.

.. warning::

This method only works with
:class:`~ai.backend.client.session.AsyncSession`.

:param headers: Header overrides applied after the auto-populated
headers. See :meth:`fetch` for details.
"""
if not isinstance(self.session, AsyncSession):
raise RuntimeError("Cannot use event streams with sessions in the synchronous mode")
Expand All @@ -403,6 +447,8 @@ def connect_events(self, **kwargs: Any) -> SSEContextManager:
raise RuntimeError("Failed to set request date")
self.headers["Date"] = self.date.isoformat()
self.content_type = "application/octet-stream"
if headers:
self._apply_header_overrides(headers)
Comment on lines 379 to +456
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New behavior was added for connect_websocket(..., headers=...) and connect_events(..., headers=...), but the unit tests added here only cover fetch(). Since these are public entry points and apply overrides (including special-casing Date), consider adding small unit tests that assert header overrides are applied and self.date is updated (or not) for WebSocket/SSE as well.

Copilot uses AI. Check for mistakes.

def _rqst_ctx_builder() -> _RequestContextManager:
timeout_config = aiohttp.ClientTimeout(
Expand Down
24 changes: 21 additions & 3 deletions src/ai/backend/web/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,13 @@ async def web_handler(
continue
if (value := frontend_rqst.headers.get(key)) is not None:
backend_rqst.headers[key] = value
async with backend_rqst.fetch() as backend_resp:
# Preserve the client's original Date header so that signature-based
# auth schemes that bind the signature to Date keep working through
# the proxy. fetch() otherwise refreshes Date unconditionally.
fetch_header_overrides: dict[str, str] = {}
if (client_date := frontend_rqst.headers.get("Date")) is not None:
fetch_header_overrides["Date"] = client_date
async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp:
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This forwards an unvalidated client-supplied Date header into Request.fetch(), which now parses it. If the header is malformed, parse_datetime() will raise and fall into the generic except Exception path here, turning a client input issue into a 500. Consider validating/parsing the header in the proxy layer (and returning a 400 on failure), or only applying the override when you explicitly expect a pre-signed request (otherwise ignore the client Date).

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same fetch_header_overrides construction is duplicated in web_handler, web_handler_with_jwt, and web_plugin_handler. Consider factoring this into a small helper (e.g., build a per-request override dict from frontend_rqst.headers) to reduce repetition and keep future header-forwarding changes consistent across handlers.

Suggested change
# Preserve the client's original Date header so that signature-based
# auth schemes that bind the signature to Date keep working through
# the proxy. fetch() otherwise refreshes Date unconditionally.
fetch_header_overrides: dict[str, str] = {}
if (client_date := frontend_rqst.headers.get("Date")) is not None:
fetch_header_overrides["Date"] = client_date
async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp:
def _build_fetch_header_overrides(headers: CIMultiDict[str]) -> dict[str, str]:
# Preserve the client's original Date header so that signature-based
# auth schemes that bind the signature to Date keep working through
# the proxy. fetch() otherwise refreshes Date unconditionally.
fetch_header_overrides: dict[str, str] = {}
if (client_date := headers.get("Date")) is not None:
fetch_header_overrides["Date"] = client_date
return fetch_header_overrides
async with backend_rqst.fetch(
headers=_build_fetch_header_overrides(frontend_rqst.headers)
) as backend_resp:

Copilot uses AI. Check for mistakes.
frontend_resp_hdrs = {
key: value
for key, value in backend_resp.headers.items()
Expand Down Expand Up @@ -427,8 +433,14 @@ async def web_handler_with_jwt(
if (value := frontend_rqst.headers.get(key)) is not None:
backend_rqst.headers[key] = value

# Preserve the client's original Date header so that signature-based
# auth schemes that bind the signature to Date keep working through
# the proxy. fetch() otherwise refreshes Date unconditionally.
fetch_header_overrides: dict[str, str] = {}
if (client_date := frontend_rqst.headers.get("Date")) is not None:
fetch_header_overrides["Date"] = client_date
# Fetch from backend and stream response
async with backend_rqst.fetch() as backend_resp:
async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp:
frontend_resp_hdrs = {
key: value
for key, value in backend_resp.headers.items()
Expand Down Expand Up @@ -530,7 +542,13 @@ async def web_plugin_handler(
for key in HTTP_HEADERS_TO_FORWARD:
if (value := frontend_rqst.headers.get(key)) is not None:
backend_rqst.headers[key] = value
async with backend_rqst.fetch() as backend_resp:
# Preserve the client's original Date header so that signature-based
# auth schemes that bind the signature to Date keep working through
# the proxy. fetch() otherwise refreshes Date unconditionally.
fetch_header_overrides: dict[str, str] = {}
if (client_date := frontend_rqst.headers.get("Date")) is not None:
fetch_header_overrides["Date"] = client_date
async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp:
frontend_resp_hdrs = {
key: value
for key, value in backend_resp.headers.items()
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/client/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,49 @@ async def test_response_async(defconfig: APIConfig, dummy_endpoint: str) -> None
async with rqst.fetch() as resp:
assert await resp.text() == '{"test": 5678}'
assert await resp.json() == {"test": 5678}


async def test_fetch_preserves_date_header_override(dummy_endpoint: str) -> None:
fixed_date = "Tue, 02 Sep 2025 08:00:00 GMT"
with aioresponses() as m:
m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"")
async with AsyncSession():
rqst = Request("POST", "function")
async with rqst.fetch(headers={"Date": fixed_date}):
pass

sent_kwargs = next(iter(m.requests.values()))[0].kwargs
assert sent_kwargs["headers"]["Date"] == fixed_date

# The internal datetime must also reflect the override so that signing
# done with `self.date` matches the Date header sent on the wire.
assert rqst.date is not None
assert rqst.date.year == 2025
assert rqst.date.month == 9
assert rqst.date.day == 2


async def test_fetch_default_date_header_when_no_override(dummy_endpoint: str) -> None:
with aioresponses() as m:
m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"")
async with AsyncSession():
rqst = Request("POST", "function")
async with rqst.fetch():
pass

sent_kwargs = next(iter(m.requests.values()))[0].kwargs
# Without override, fetch() auto-populates Date with self.date.isoformat().
assert rqst.date is not None
assert sent_kwargs["headers"]["Date"] == rqst.date.isoformat()


async def test_fetch_passes_through_arbitrary_header_overrides(dummy_endpoint: str) -> None:
with aioresponses() as m:
m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"")
async with AsyncSession():
rqst = Request("POST", "function")
async with rqst.fetch(headers={"X-Custom-Header": "custom-value"}):
pass

sent_kwargs = next(iter(m.requests.values()))[0].kwargs
assert sent_kwargs["headers"]["X-Custom-Header"] == "custom-value"
Loading