Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/5737.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Forward the client's `Date` header through the `Webserver` proxy on the anonymous (signature pass-through) path so signature schemes that bind the upstream signature to `Date` keep working end-to-end; `Request.fetch()` now accepts a `headers=` override for this. The re-signing path keeps refreshing `Date` so a client cannot dictate the timestamp signed by the proxy.
57 changes: 54 additions & 3 deletions src/ai/backend/client/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import appdirs
import attrs
from aiohttp.client import _RequestContextManager, _WSRequestContextManager
from dateutil.parser import parse as parse_datetime
from dateutil.tz import tzutc
from multidict import CIMultiDict
from yarl import URL
Expand Down Expand Up @@ -277,6 +278,25 @@ def _pack_content(self) -> RequestContent | aiohttp.FormData:
return data
return self._content

def _apply_header_overrides(self, overrides: Mapping[str, str]) -> None:
"""
Apply caller-supplied header overrides on top of the auto-set headers.

If ``Date`` is overridden, also parse it back into ``self.date`` so
that request signing uses the same value the upstream server will
see. If the override value cannot be parsed, the header is still
forwarded but ``self.date`` is left untouched; signing should not
be enabled in that case (e.g., the proxy gates this on anonymous
sessions where ``_sign()`` is skipped).
"""
for key, value in overrides.items():
self.headers[key] = value
if key.lower() == "date":
try:
self.date = parse_datetime(value)
except (ValueError, OverflowError):
pass

def _build_url(self) -> URL:
base_url = self.config.endpoint.path.rstrip("/")
query_path = self.path.lstrip("/") if self.path is not None and len(self.path) > 0 else ""
Expand All @@ -291,7 +311,12 @@ def _build_url(self) -> URL:

# TODO: attach rate-limit information

def fetch(self, **kwargs: Any) -> FetchContextManager:
def fetch(
self,
*,
headers: Mapping[str, str] | None = None,
**kwargs: Any,
) -> FetchContextManager:
"""
Sends the request to the server and reads the response.

Expand All @@ -307,6 +332,11 @@ def fetch(self, **kwargs: Any) -> FetchContextManager:
rqst = Request('GET', ...)
async with rqst.fetch() as resp:
print(await resp.text())

:param headers: Header overrides applied after the auto-populated
headers. Useful for proxy use cases that need to forward the
client's original ``Date`` header (and matching signature)
instead of having it refreshed.
"""
if self.method not in self._allowed_methods:
raise ValueError(f"Disallowed HTTP method: {self.method}")
Expand All @@ -316,6 +346,8 @@ def fetch(self, **kwargs: Any) -> FetchContextManager:
self.headers["Date"] = self.date.isoformat()
if self.content_type is not None and "Content-Type" not in self.headers:
self.headers["Content-Type"] = self.content_type
if headers:
self._apply_header_overrides(headers)
force_anonymous = kwargs.pop("anonymous", False)

def _rqst_ctx_builder() -> _RequestContextManager:
Expand Down Expand Up @@ -345,7 +377,11 @@ def _rqst_ctx_builder() -> _RequestContextManager:
return FetchContextManager(self.session, _rqst_ctx_builder, self._session_mode, **kwargs)

def connect_websocket(
self, protocols: Iterable[str] = tuple(), **kwargs: Any
self,
protocols: Iterable[str] = tuple(),
*,
headers: Mapping[str, str] | None = None,
**kwargs: Any,
) -> WebSocketContextManager:
"""
Creates a WebSocket connection.
Expand All @@ -354,6 +390,9 @@ def connect_websocket(

This method only works with
:class:`~ai.backend.client.session.AsyncSession`.

:param headers: Header overrides applied after the auto-populated
headers. See :meth:`fetch` for details.
"""
if not isinstance(self.session, AsyncSession):
raise RuntimeError("Cannot use websockets with sessions in the synchronous mode")
Expand All @@ -365,6 +404,8 @@ def connect_websocket(
self.headers["Date"] = self.date.isoformat()
# websocket is always a "binary" stream.
self.content_type = "application/octet-stream"
if headers:
self._apply_header_overrides(headers)

def _ws_ctx_builder() -> _WSRequestContextManager:
full_url = self._build_url()
Expand All @@ -385,14 +426,22 @@ def _ws_ctx_builder() -> _WSRequestContextManager:

return WebSocketContextManager(self.session, _ws_ctx_builder, **kwargs)

def connect_events(self, **kwargs: Any) -> SSEContextManager:
def connect_events(
self,
*,
headers: Mapping[str, str] | None = None,
**kwargs: Any,
) -> SSEContextManager:
"""
Creates a Server-Sent Events connection.

.. warning::

This method only works with
:class:`~ai.backend.client.session.AsyncSession`.

:param headers: Header overrides applied after the auto-populated
headers. See :meth:`fetch` for details.
"""
if not isinstance(self.session, AsyncSession):
raise RuntimeError("Cannot use event streams with sessions in the synchronous mode")
Expand All @@ -403,6 +452,8 @@ def connect_events(self, **kwargs: Any) -> SSEContextManager:
raise RuntimeError("Failed to set request date")
self.headers["Date"] = self.date.isoformat()
self.content_type = "application/octet-stream"
if headers:
self._apply_header_overrides(headers)
Comment on lines 379 to +456
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New behavior was added for connect_websocket(..., headers=...) and connect_events(..., headers=...), but the unit tests added here only cover fetch(). Since these are public entry points and apply overrides (including special-casing Date), consider adding small unit tests that assert header overrides are applied and self.date is updated (or not) for WebSocket/SSE as well.

Copilot uses AI. Check for mistakes.

def _rqst_ctx_builder() -> _RequestContextManager:
timeout_config = aiohttp.ClientTimeout(
Expand Down
44 changes: 41 additions & 3 deletions src/ai/backend/web/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ai.backend.client.exceptions import BackendAPIError, BackendClientError
from ai.backend.client.request import Request, RequestContent, SessionMode
from ai.backend.client.session import AsyncSession as APISession
from ai.backend.common.exception import InvalidAPIParameters
from ai.backend.common.web.session import STORAGE_KEY, extra_config_headers, get_session
from ai.backend.logging import BraceStyleAdapter
Expand Down Expand Up @@ -141,6 +142,30 @@ async def close_upstream(self) -> None:
await self.up_conn.close()


def _pass_through_date_header(
frontend_rqst: web.Request,
api_session: APISession,
) -> dict[str, str]:
"""
Build the ``headers`` override for ``Request.fetch()`` that forwards the
client's original ``Date`` header when (and only when) the proxy is in
pure pass-through mode.

Forwarding ``Date`` is only safe for anonymous sessions, where the client
owns the upstream signature and the ``Date`` it bound the signature to
must survive end-to-end. In the re-signing path the proxy signs with its
own keypair, so a client-supplied ``Date`` would let the caller dictate
the timestamp signed on the wire — keep ``fetch()``'s auto-refreshed
``Date`` there instead.
"""
if not api_session.config.is_anonymous:
return {}
client_date = frontend_rqst.headers.get("Date")
if client_date is None:
return {}
return {"Date": client_date}


def _decrypt_payload(endpoint: str, payload: bytes) -> bytes:
iv, real_payload = payload.split(b":")
key = (base64.b64encode(endpoint.encode("ascii")) + iv + iv)[:32]
Expand Down Expand Up @@ -255,7 +280,14 @@ async def web_handler(
continue
if (value := frontend_rqst.headers.get(key)) is not None:
backend_rqst.headers[key] = value
async with backend_rqst.fetch() as backend_resp:
# When the proxy is in pure pass-through mode (anonymous session),
# the client owns the upstream signature and the Date it was bound
# to must survive end-to-end. In the re-signing path the proxy
# signs with its own keypair, so a client-supplied Date would let
# the caller dictate the timestamp signed on the wire — keep the
# auto-refreshed Date there.
fetch_header_overrides = _pass_through_date_header(frontend_rqst, api_session)
async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp:
frontend_resp_hdrs = {
key: value
for key, value in backend_resp.headers.items()
Expand Down Expand Up @@ -427,8 +459,11 @@ async def web_handler_with_jwt(
if (value := frontend_rqst.headers.get(key)) is not None:
backend_rqst.headers[key] = value

# See `_pass_through_date_header` for why this is gated on
# anonymous sessions only.
fetch_header_overrides = _pass_through_date_header(frontend_rqst, api_session)
# Fetch from backend and stream response
async with backend_rqst.fetch() as backend_resp:
async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp:
frontend_resp_hdrs = {
key: value
for key, value in backend_resp.headers.items()
Expand Down Expand Up @@ -530,7 +565,10 @@ async def web_plugin_handler(
for key in HTTP_HEADERS_TO_FORWARD:
if (value := frontend_rqst.headers.get(key)) is not None:
backend_rqst.headers[key] = value
async with backend_rqst.fetch() as backend_resp:
# See `_pass_through_date_header` for why this is gated on
# anonymous sessions only.
fetch_header_overrides = _pass_through_date_header(frontend_rqst, api_session)
async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp:
frontend_resp_hdrs = {
key: value
for key, value in backend_resp.headers.items()
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/client/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,67 @@ async def test_response_async(defconfig: APIConfig, dummy_endpoint: str) -> None
async with rqst.fetch() as resp:
assert await resp.text() == '{"test": 5678}'
assert await resp.json() == {"test": 5678}


async def test_fetch_preserves_date_header_override(dummy_endpoint: str) -> None:
fixed_date = "Tue, 02 Sep 2025 08:00:00 GMT"
with aioresponses() as m:
m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"")
async with AsyncSession():
rqst = Request("POST", "function")
async with rqst.fetch(headers={"Date": fixed_date}):
pass

sent_kwargs = next(iter(m.requests.values()))[0].kwargs
assert sent_kwargs["headers"]["Date"] == fixed_date

# The internal datetime must also reflect the override so that signing
# done with `self.date` matches the Date header sent on the wire.
assert rqst.date is not None
assert rqst.date.year == 2025
assert rqst.date.month == 9
assert rqst.date.day == 2


async def test_fetch_default_date_header_when_no_override(dummy_endpoint: str) -> None:
with aioresponses() as m:
m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"")
async with AsyncSession():
rqst = Request("POST", "function")
async with rqst.fetch():
pass

sent_kwargs = next(iter(m.requests.values()))[0].kwargs
# Without override, fetch() auto-populates Date with self.date.isoformat().
assert rqst.date is not None
assert sent_kwargs["headers"]["Date"] == rqst.date.isoformat()


async def test_fetch_passes_through_arbitrary_header_overrides(dummy_endpoint: str) -> None:
with aioresponses() as m:
m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"")
async with AsyncSession():
rqst = Request("POST", "function")
async with rqst.fetch(headers={"X-Custom-Header": "custom-value"}):
pass

sent_kwargs = next(iter(m.requests.values()))[0].kwargs
assert sent_kwargs["headers"]["X-Custom-Header"] == "custom-value"


async def test_fetch_unparseable_date_override_does_not_crash(
dummy_endpoint: str,
) -> None:
# An unparseable Date must not raise — the header is still forwarded
# (upstream can reject) but `self.date` falls back to the auto-set value
# so any signing path stays internally consistent.
with aioresponses() as m:
m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"")
async with AsyncSession():
rqst = Request("POST", "function")
async with rqst.fetch(headers={"Date": "not a date"}):
pass

sent_kwargs = next(iter(m.requests.values()))[0].kwargs
assert sent_kwargs["headers"]["Date"] == "not a date"
assert rqst.date is not None # untouched by the parse failure
52 changes: 52 additions & 0 deletions tests/unit/webserver/test_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from types import SimpleNamespace
from typing import cast

from aiohttp import web
from multidict import CIMultiDict

from ai.backend.client.session import AsyncSession as APISession
from ai.backend.web.proxy import _pass_through_date_header


def _frontend_request(headers: dict[str, str]) -> web.Request:
"""Cheap stand-in for `aiohttp.web.Request` — only `.headers` is read."""
return cast(web.Request, SimpleNamespace(headers=CIMultiDict(headers)))


def _api_session(*, is_anonymous: bool) -> APISession:
"""Stand-in exposing only the fields `_pass_through_date_header` reads."""
return cast(
APISession,
SimpleNamespace(config=SimpleNamespace(is_anonymous=is_anonymous)),
)


CLIENT_DATE = "Tue, 02 Sep 2025 08:00:00 GMT"


def test_anonymous_session_with_date_forwards_it() -> None:
overrides = _pass_through_date_header(
_frontend_request({"Date": CLIENT_DATE}),
_api_session(is_anonymous=True),
)
assert overrides == {"Date": CLIENT_DATE}


def test_anonymous_session_without_date_returns_empty() -> None:
overrides = _pass_through_date_header(
_frontend_request({}),
_api_session(is_anonymous=True),
)
assert overrides == {}


def test_authenticated_session_does_not_forward_date() -> None:
# Critical: in the re-signing path the proxy signs with its own keypair,
# so the client must NOT be allowed to dictate the timestamp on the wire.
overrides = _pass_through_date_header(
_frontend_request({"Date": CLIENT_DATE}),
_api_session(is_anonymous=False),
)
assert overrides == {}
Loading