diff --git a/changes/5737.feature.md b/changes/5737.feature.md new file mode 100644 index 00000000000..884fc0ae248 --- /dev/null +++ b/changes/5737.feature.md @@ -0,0 +1 @@ +Forward the client's `Date` header through the `Webserver` proxy on the anonymous (signature pass-through) path so signature schemes that bind the upstream signature to `Date` keep working end-to-end; `Request.fetch()` now accepts a `headers=` override for this. The re-signing path keeps refreshing `Date` so a client cannot dictate the timestamp signed by the proxy. diff --git a/src/ai/backend/client/request.py b/src/ai/backend/client/request.py index 5f3b4fec679..dda0d697a69 100644 --- a/src/ai/backend/client/request.py +++ b/src/ai/backend/client/request.py @@ -23,6 +23,7 @@ import appdirs import attrs from aiohttp.client import _RequestContextManager, _WSRequestContextManager +from dateutil.parser import parse as parse_datetime from dateutil.tz import tzutc from multidict import CIMultiDict from yarl import URL @@ -277,6 +278,25 @@ def _pack_content(self) -> RequestContent | aiohttp.FormData: return data return self._content + def _apply_header_overrides(self, overrides: Mapping[str, str]) -> None: + """ + Apply caller-supplied header overrides on top of the auto-set headers. + + If ``Date`` is overridden, also parse it back into ``self.date`` so + that request signing uses the same value the upstream server will + see. If the override value cannot be parsed, the header is still + forwarded but ``self.date`` is left untouched; signing should not + be enabled in that case (e.g., the proxy gates this on anonymous + sessions where ``_sign()`` is skipped). + """ + for key, value in overrides.items(): + self.headers[key] = value + if key.lower() == "date": + try: + self.date = parse_datetime(value) + except (ValueError, OverflowError): + pass + def _build_url(self) -> URL: base_url = self.config.endpoint.path.rstrip("/") query_path = self.path.lstrip("/") if self.path is not None and len(self.path) > 0 else "" @@ -291,7 +311,12 @@ def _build_url(self) -> URL: # TODO: attach rate-limit information - def fetch(self, **kwargs: Any) -> FetchContextManager: + def fetch( + self, + *, + headers: Mapping[str, str] | None = None, + **kwargs: Any, + ) -> FetchContextManager: """ Sends the request to the server and reads the response. @@ -307,6 +332,11 @@ def fetch(self, **kwargs: Any) -> FetchContextManager: rqst = Request('GET', ...) async with rqst.fetch() as resp: print(await resp.text()) + + :param headers: Header overrides applied after the auto-populated + headers. Useful for proxy use cases that need to forward the + client's original ``Date`` header (and matching signature) + instead of having it refreshed. """ if self.method not in self._allowed_methods: raise ValueError(f"Disallowed HTTP method: {self.method}") @@ -316,6 +346,8 @@ def fetch(self, **kwargs: Any) -> FetchContextManager: self.headers["Date"] = self.date.isoformat() if self.content_type is not None and "Content-Type" not in self.headers: self.headers["Content-Type"] = self.content_type + if headers: + self._apply_header_overrides(headers) force_anonymous = kwargs.pop("anonymous", False) def _rqst_ctx_builder() -> _RequestContextManager: @@ -345,7 +377,11 @@ def _rqst_ctx_builder() -> _RequestContextManager: return FetchContextManager(self.session, _rqst_ctx_builder, self._session_mode, **kwargs) def connect_websocket( - self, protocols: Iterable[str] = tuple(), **kwargs: Any + self, + protocols: Iterable[str] = tuple(), + *, + headers: Mapping[str, str] | None = None, + **kwargs: Any, ) -> WebSocketContextManager: """ Creates a WebSocket connection. @@ -354,6 +390,9 @@ def connect_websocket( This method only works with :class:`~ai.backend.client.session.AsyncSession`. + + :param headers: Header overrides applied after the auto-populated + headers. See :meth:`fetch` for details. """ if not isinstance(self.session, AsyncSession): raise RuntimeError("Cannot use websockets with sessions in the synchronous mode") @@ -365,6 +404,8 @@ def connect_websocket( self.headers["Date"] = self.date.isoformat() # websocket is always a "binary" stream. self.content_type = "application/octet-stream" + if headers: + self._apply_header_overrides(headers) def _ws_ctx_builder() -> _WSRequestContextManager: full_url = self._build_url() @@ -385,7 +426,12 @@ def _ws_ctx_builder() -> _WSRequestContextManager: return WebSocketContextManager(self.session, _ws_ctx_builder, **kwargs) - def connect_events(self, **kwargs: Any) -> SSEContextManager: + def connect_events( + self, + *, + headers: Mapping[str, str] | None = None, + **kwargs: Any, + ) -> SSEContextManager: """ Creates a Server-Sent Events connection. @@ -393,6 +439,9 @@ def connect_events(self, **kwargs: Any) -> SSEContextManager: This method only works with :class:`~ai.backend.client.session.AsyncSession`. + + :param headers: Header overrides applied after the auto-populated + headers. See :meth:`fetch` for details. """ if not isinstance(self.session, AsyncSession): raise RuntimeError("Cannot use event streams with sessions in the synchronous mode") @@ -403,6 +452,8 @@ def connect_events(self, **kwargs: Any) -> SSEContextManager: raise RuntimeError("Failed to set request date") self.headers["Date"] = self.date.isoformat() self.content_type = "application/octet-stream" + if headers: + self._apply_header_overrides(headers) def _rqst_ctx_builder() -> _RequestContextManager: timeout_config = aiohttp.ClientTimeout( diff --git a/src/ai/backend/web/proxy.py b/src/ai/backend/web/proxy.py index 3c9595c2484..9032102b0a8 100644 --- a/src/ai/backend/web/proxy.py +++ b/src/ai/backend/web/proxy.py @@ -17,6 +17,7 @@ from ai.backend.client.exceptions import BackendAPIError, BackendClientError from ai.backend.client.request import Request, RequestContent, SessionMode +from ai.backend.client.session import AsyncSession as APISession from ai.backend.common.exception import InvalidAPIParameters from ai.backend.common.web.session import STORAGE_KEY, extra_config_headers, get_session from ai.backend.logging import BraceStyleAdapter @@ -141,6 +142,30 @@ async def close_upstream(self) -> None: await self.up_conn.close() +def _pass_through_date_header( + frontend_rqst: web.Request, + api_session: APISession, +) -> dict[str, str]: + """ + Build the ``headers`` override for ``Request.fetch()`` that forwards the + client's original ``Date`` header when (and only when) the proxy is in + pure pass-through mode. + + Forwarding ``Date`` is only safe for anonymous sessions, where the client + owns the upstream signature and the ``Date`` it bound the signature to + must survive end-to-end. In the re-signing path the proxy signs with its + own keypair, so a client-supplied ``Date`` would let the caller dictate + the timestamp signed on the wire — keep ``fetch()``'s auto-refreshed + ``Date`` there instead. + """ + if not api_session.config.is_anonymous: + return {} + client_date = frontend_rqst.headers.get("Date") + if client_date is None: + return {} + return {"Date": client_date} + + def _decrypt_payload(endpoint: str, payload: bytes) -> bytes: iv, real_payload = payload.split(b":") key = (base64.b64encode(endpoint.encode("ascii")) + iv + iv)[:32] @@ -255,7 +280,14 @@ async def web_handler( continue if (value := frontend_rqst.headers.get(key)) is not None: backend_rqst.headers[key] = value - async with backend_rqst.fetch() as backend_resp: + # When the proxy is in pure pass-through mode (anonymous session), + # the client owns the upstream signature and the Date it was bound + # to must survive end-to-end. In the re-signing path the proxy + # signs with its own keypair, so a client-supplied Date would let + # the caller dictate the timestamp signed on the wire — keep the + # auto-refreshed Date there. + fetch_header_overrides = _pass_through_date_header(frontend_rqst, api_session) + async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp: frontend_resp_hdrs = { key: value for key, value in backend_resp.headers.items() @@ -427,8 +459,11 @@ async def web_handler_with_jwt( if (value := frontend_rqst.headers.get(key)) is not None: backend_rqst.headers[key] = value + # See `_pass_through_date_header` for why this is gated on + # anonymous sessions only. + fetch_header_overrides = _pass_through_date_header(frontend_rqst, api_session) # Fetch from backend and stream response - async with backend_rqst.fetch() as backend_resp: + async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp: frontend_resp_hdrs = { key: value for key, value in backend_resp.headers.items() @@ -530,7 +565,10 @@ async def web_plugin_handler( for key in HTTP_HEADERS_TO_FORWARD: if (value := frontend_rqst.headers.get(key)) is not None: backend_rqst.headers[key] = value - async with backend_rqst.fetch() as backend_resp: + # See `_pass_through_date_header` for why this is gated on + # anonymous sessions only. + fetch_header_overrides = _pass_through_date_header(frontend_rqst, api_session) + async with backend_rqst.fetch(headers=fetch_header_overrides) as backend_resp: frontend_resp_hdrs = { key: value for key, value in backend_resp.headers.items() diff --git a/tests/unit/client/test_request.py b/tests/unit/client/test_request.py index c2acd7c700c..7324f0293e4 100644 --- a/tests/unit/client/test_request.py +++ b/tests/unit/client/test_request.py @@ -267,3 +267,67 @@ async def test_response_async(defconfig: APIConfig, dummy_endpoint: str) -> None async with rqst.fetch() as resp: assert await resp.text() == '{"test": 5678}' assert await resp.json() == {"test": 5678} + + +async def test_fetch_preserves_date_header_override(dummy_endpoint: str) -> None: + fixed_date = "Tue, 02 Sep 2025 08:00:00 GMT" + with aioresponses() as m: + m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"") + async with AsyncSession(): + rqst = Request("POST", "function") + async with rqst.fetch(headers={"Date": fixed_date}): + pass + + sent_kwargs = next(iter(m.requests.values()))[0].kwargs + assert sent_kwargs["headers"]["Date"] == fixed_date + + # The internal datetime must also reflect the override so that signing + # done with `self.date` matches the Date header sent on the wire. + assert rqst.date is not None + assert rqst.date.year == 2025 + assert rqst.date.month == 9 + assert rqst.date.day == 2 + + +async def test_fetch_default_date_header_when_no_override(dummy_endpoint: str) -> None: + with aioresponses() as m: + m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"") + async with AsyncSession(): + rqst = Request("POST", "function") + async with rqst.fetch(): + pass + + sent_kwargs = next(iter(m.requests.values()))[0].kwargs + # Without override, fetch() auto-populates Date with self.date.isoformat(). + assert rqst.date is not None + assert sent_kwargs["headers"]["Date"] == rqst.date.isoformat() + + +async def test_fetch_passes_through_arbitrary_header_overrides(dummy_endpoint: str) -> None: + with aioresponses() as m: + m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"") + async with AsyncSession(): + rqst = Request("POST", "function") + async with rqst.fetch(headers={"X-Custom-Header": "custom-value"}): + pass + + sent_kwargs = next(iter(m.requests.values()))[0].kwargs + assert sent_kwargs["headers"]["X-Custom-Header"] == "custom-value" + + +async def test_fetch_unparseable_date_override_does_not_crash( + dummy_endpoint: str, +) -> None: + # An unparseable Date must not raise — the header is still forwarded + # (upstream can reject) but `self.date` falls back to the auto-set value + # so any signing path stays internally consistent. + with aioresponses() as m: + m.post(dummy_endpoint + "function", status=HTTPStatus.OK, body=b"") + async with AsyncSession(): + rqst = Request("POST", "function") + async with rqst.fetch(headers={"Date": "not a date"}): + pass + + sent_kwargs = next(iter(m.requests.values()))[0].kwargs + assert sent_kwargs["headers"]["Date"] == "not a date" + assert rqst.date is not None # untouched by the parse failure diff --git a/tests/unit/webserver/test_proxy.py b/tests/unit/webserver/test_proxy.py new file mode 100644 index 00000000000..fb46c75dfc2 --- /dev/null +++ b/tests/unit/webserver/test_proxy.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import cast + +from aiohttp import web +from multidict import CIMultiDict + +from ai.backend.client.session import AsyncSession as APISession +from ai.backend.web.proxy import _pass_through_date_header + + +def _frontend_request(headers: dict[str, str]) -> web.Request: + """Cheap stand-in for `aiohttp.web.Request` — only `.headers` is read.""" + return cast(web.Request, SimpleNamespace(headers=CIMultiDict(headers))) + + +def _api_session(*, is_anonymous: bool) -> APISession: + """Stand-in exposing only the fields `_pass_through_date_header` reads.""" + return cast( + APISession, + SimpleNamespace(config=SimpleNamespace(is_anonymous=is_anonymous)), + ) + + +CLIENT_DATE = "Tue, 02 Sep 2025 08:00:00 GMT" + + +def test_anonymous_session_with_date_forwards_it() -> None: + overrides = _pass_through_date_header( + _frontend_request({"Date": CLIENT_DATE}), + _api_session(is_anonymous=True), + ) + assert overrides == {"Date": CLIENT_DATE} + + +def test_anonymous_session_without_date_returns_empty() -> None: + overrides = _pass_through_date_header( + _frontend_request({}), + _api_session(is_anonymous=True), + ) + assert overrides == {} + + +def test_authenticated_session_does_not_forward_date() -> None: + # Critical: in the re-signing path the proxy signs with its own keypair, + # so the client must NOT be allowed to dictate the timestamp on the wire. + overrides = _pass_through_date_header( + _frontend_request({"Date": CLIENT_DATE}), + _api_session(is_anonymous=False), + ) + assert overrides == {}