diff --git a/README.md b/README.md index 9450c0bc51..fcfa870989 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ so that your API key is not stored in source control. ### Workload Identity Authentication -For secure, automated environments like cloud-managed Kubernetes, Azure, and Google Cloud Platform, you can use workload identity authentication with short-lived tokens from cloud identity providers instead of long-lived API keys. +For secure, automated environments like cloud-managed Kubernetes, Azure, Google Cloud Platform, and AWS Bedrock, you can use workload identity authentication with short-lived tokens from cloud identity providers instead of long-lived API keys. #### Kubernetes (service account tokens) @@ -134,6 +134,43 @@ client = OpenAI( ) ``` +#### AWS Bedrock + +Requires `botocore` (`pip install 'openai[bedrock]'`). Credentials are resolved from the [standard AWS credential chain](https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html). + +```python +from openai import OpenAI +from openai.auth import aws_bedrock_token_provider + +client = OpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/v1", # region must match the token provider + api_key=aws_bedrock_token_provider( + region="us-east-1", + profile="my-profile", # optional — defaults to the standard AWS credential chain + ), +) + +# List models supported by the OpenAI-compatible endpoint +for model in client.models.list().data: + print(model.id) +``` + +For `AsyncOpenAI`, use `async_aws_bedrock_token_provider`: + +```python +from openai import AsyncOpenAI +from openai.auth import async_aws_bedrock_token_provider + +client = AsyncOpenAI( + base_url="https://bedrock-mantle.us-east-1.api.aws/v1", # region must match the token provider + api_key=async_aws_bedrock_token_provider( + region="us-east-1", + ), +) +``` + +> **Note:** The OpenAI SDK works only with Bedrock models that have the [OpenAI-compatible API](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-mantle.html) enabled. Use `client.models.list()` to see which models are available on your endpoint. + #### Custom subject token provider ```python diff --git a/pyproject.toml b/pyproject.toml index d0d533e8a6..cf2e441913 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] realtime = ["websockets >= 13, < 16"] datalib = ["numpy >= 1", "pandas >= 1.2.3", "pandas-stubs >= 1.1.0.11"] voice_helpers = ["sounddevice>=0.5.1", "numpy>=2.0.2"] +bedrock = ["botocore>=1.29.0"] [tool.rye] managed = true diff --git a/src/openai/auth/__init__.py b/src/openai/auth/__init__.py index 367aa86b72..ef4600c575 100644 --- a/src/openai/auth/__init__.py +++ b/src/openai/auth/__init__.py @@ -5,6 +5,8 @@ SubjectTokenProvider as SubjectTokenProvider, WorkloadIdentityAuth as WorkloadIdentityAuth, gcp_id_token_provider as gcp_id_token_provider, + aws_bedrock_token_provider as aws_bedrock_token_provider, + async_aws_bedrock_token_provider as async_aws_bedrock_token_provider, k8s_service_account_token_provider as k8s_service_account_token_provider, azure_managed_identity_token_provider as azure_managed_identity_token_provider, ) @@ -16,4 +18,6 @@ "k8s_service_account_token_provider", "azure_managed_identity_token_provider", "gcp_id_token_provider", + "aws_bedrock_token_provider", + "async_aws_bedrock_token_provider", ] diff --git a/src/openai/auth/_workload.py b/src/openai/auth/_workload.py index e3f6f7fb75..0844b1be52 100644 --- a/src/openai/auth/_workload.py +++ b/src/openai/auth/_workload.py @@ -1,8 +1,10 @@ from __future__ import annotations +import os import time +import base64 import threading -from typing import Any, Callable, TypedDict, cast +from typing import Any, Callable, Awaitable, TypedDict, cast from pathlib import Path from typing_extensions import Literal, NotRequired @@ -173,6 +175,123 @@ def get_token() -> str: return {"token_type": "id", "get_token": get_token} +def _make_bedrock_token_generator( + *, + region: str | None = None, + profile: str | None = None, + token_duration: int = 3600, +) -> Callable[[], str]: + _session: list[Any] = [None] + + def get_token() -> str: + try: + import botocore.session + from botocore.auth import SigV4QueryAuth + from botocore.awsrequest import AWSRequest + except ImportError as e: + raise ImportError( + "botocore is required for AWS Bedrock token generation. Install it with: pip install 'openai[bedrock]'" + ) from e + + try: + resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") + if not resolved_region: + raise SubjectTokenProviderError( + "AWS region must be provided via the 'region' parameter, " + "or the AWS_REGION / AWS_DEFAULT_REGION environment variable." + ) + + if _session[0] is None: + _session[0] = botocore.session.Session(profile=profile) + + credentials = _session[0].get_credentials() + if credentials is None: + raise SubjectTokenProviderError("No AWS credentials found. Ensure your AWS credentials are configured.") + frozen_credentials = credentials.get_frozen_credentials() + + request = AWSRequest( + method="POST", + url="https://bedrock.amazonaws.com/", + headers={"host": "bedrock.amazonaws.com"}, + params={"Action": "CallWithBearerToken"}, + ) + + signer = SigV4QueryAuth(frozen_credentials, "bedrock", resolved_region, expires=token_duration) + signer.add_auth(request) + + signed_url = request.url + # Strip the https:// prefix before encoding + url_without_scheme = signed_url[len("https://") :] + encoded_token = base64.b64encode(f"{url_without_scheme}&Version=1".encode()).decode() + + return f"bedrock-api-key-{encoded_token}" + except (ImportError, SubjectTokenProviderError): + raise + except Exception as e: + raise SubjectTokenProviderError(f"Failed to generate AWS Bedrock token: {e}") from e + + return get_token + + +def aws_bedrock_token_provider( + *, + region: str | None = None, + profile: str | None = None, + token_duration: int = 3600, +) -> Callable[[], str]: + """ + Get a sync token provider for AWS Bedrock. Use with ``OpenAI``. + + Returns a callable that generates a bearer token from a SigV4 presigned URL. + Pass it directly to ``api_key`` when creating an OpenAI client pointed at a + Bedrock runtime endpoint. Credentials are resolved from the standard AWS credential chain: + https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html + + The botocore session is cached so credential resolution is efficient, while + the token itself is regenerated on each call to ensure it always reflects + the latest valid credentials (important for short-lived STS/assumed-role sessions). + + For ``AsyncOpenAI``, use :func:`async_aws_bedrock_token_provider` instead. + + Args: + region: AWS region. Must match the region in the ``base_url`` passed to the client. + Defaults to ``AWS_REGION`` or ``AWS_DEFAULT_REGION`` environment variable. + profile: AWS profile name. If not set, botocore resolves credentials from the standard chain. + token_duration: Presigned URL expiry in seconds. Defaults to 3600 (1 hour). + """ + return _make_bedrock_token_generator(region=region, profile=profile, token_duration=token_duration) + + +def async_aws_bedrock_token_provider( + *, + region: str | None = None, + profile: str | None = None, + token_duration: int = 3600, +) -> Callable[[], Awaitable[str]]: + """ + Get an async token provider for AWS Bedrock. Use with ``AsyncOpenAI``. + + Returns an async callable that generates a bearer token from a SigV4 presigned URL. + Pass it directly to ``api_key`` when creating an AsyncOpenAI client pointed at a + Bedrock runtime endpoint. Credentials are resolved from the standard AWS credential chain: + https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html + + For ``OpenAI`` (sync), use :func:`aws_bedrock_token_provider` instead. + + Args: + region: AWS region. Must match the region in the ``base_url`` passed to the client. + Defaults to ``AWS_REGION`` or ``AWS_DEFAULT_REGION`` environment variable. + profile: AWS profile name. If not set, botocore resolves credentials from the standard chain. + token_duration: Presigned URL expiry in seconds. Defaults to 3600 (1 hour). + """ + _sync = _make_bedrock_token_generator(region=region, profile=profile, token_duration=token_duration) + + async def get_token() -> str: + return await to_thread(_sync) + + return get_token + + class WorkloadIdentityAuth: def __init__( self, diff --git a/tests/test_auth.py b/tests/test_auth.py index c8f4f2d7cf..deeae18308 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import json +import base64 from typing import cast from pathlib import Path +from unittest.mock import MagicMock, patch import httpx import respx @@ -9,8 +13,11 @@ from inline_snapshot import snapshot from openai import OpenAI, OAuthError +from openai._exceptions import SubjectTokenProviderError from openai.auth._workload import ( gcp_id_token_provider, + aws_bedrock_token_provider, + async_aws_bedrock_token_provider, k8s_service_account_token_provider, azure_managed_identity_token_provider, ) @@ -188,3 +195,69 @@ def test_gcp_id_token_provider() -> None: assert provider["token_type"] == "id" assert provider["get_token"]() == "gcp-token" + + +def _mock_botocore() -> MagicMock: + """Create a minimal mock botocore that stubs SigV4 signing.""" + mock = MagicMock() + mock.session.Session.return_value.get_credentials.return_value.get_frozen_credentials.return_value = MagicMock() + + def _fake_add_auth(request: MagicMock) -> None: + request.url += "&X-Amz-Credential=FAKE&X-Amz-Signature=FAKE" + + mock.auth.SigV4QueryAuth.return_value.add_auth = _fake_add_auth + mock.awsrequest.AWSRequest.return_value = MagicMock(url="https://bedrock.amazonaws.com/?Action=CallWithBearerToken") + + return mock + + +def _patch_botocore(mock: MagicMock): # type: ignore[type-arg] + return patch.dict( + "sys.modules", + { + "botocore": mock, + "botocore.session": mock.session, + "botocore.auth": mock.auth, + "botocore.awsrequest": mock.awsrequest, + }, + ) + + +def test_aws_bedrock_token_provider() -> None: + mock = _mock_botocore() + + with _patch_botocore(mock): + token = aws_bedrock_token_provider(region="us-east-1")() + assert token.startswith("bedrock-api-key-") + + decoded = base64.b64decode(token[len("bedrock-api-key-") :]).decode() + assert "bedrock.amazonaws.com" in decoded + assert "X-Amz-Signature=" in decoded + assert "Action=CallWithBearerToken" in decoded + assert "&Version=1" in decoded + + +def test_aws_bedrock_token_provider_no_credentials() -> None: + mock = MagicMock() + mock.session.Session.return_value.get_credentials.return_value = None + + with _patch_botocore(mock): + with pytest.raises(SubjectTokenProviderError, match="No AWS credentials found"): + aws_bedrock_token_provider(region="us-east-1")() + + +def test_aws_bedrock_token_provider_no_botocore() -> None: + with patch.dict( + "sys.modules", {"botocore": None, "botocore.session": None, "botocore.auth": None, "botocore.awsrequest": None} + ): + with pytest.raises(ImportError, match="botocore is required.*openai\\[bedrock\\]"): + aws_bedrock_token_provider(region="us-east-1")() + + +@pytest.mark.asyncio +async def test_async_aws_bedrock_token_provider() -> None: + mock = _mock_botocore() + + with _patch_botocore(mock): + token = await async_aws_bedrock_token_provider(region="us-east-1")() + assert token.startswith("bedrock-api-key-")