Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ so that your API key is not stored in source control.

### Workload Identity Authentication

For secure, automated environments like cloud-managed Kubernetes, Azure, and Google Cloud Platform, you can use workload identity authentication with short-lived tokens from cloud identity providers instead of long-lived API keys.
For secure, automated environments like cloud-managed Kubernetes, Azure, Google Cloud Platform, and AWS Bedrock, you can use workload identity authentication with short-lived tokens from cloud identity providers instead of long-lived API keys.

#### Kubernetes (service account tokens)

Expand Down Expand Up @@ -134,6 +134,29 @@ client = OpenAI(
)
```

#### AWS Bedrock

Requires `botocore` (`pip install 'openai[bedrock]'`). Credentials are resolved from the [standard AWS credential chain](https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html).

```python
from openai import OpenAI
from openai.auth import aws_bedrock_token_provider

client = OpenAI(
base_url="https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1",
api_key=aws_bedrock_token_provider(
region="us-east-1",
profile="my-profile", # optional — defaults to the standard AWS credential chain
),
)

# List models supported by the OpenAI-compatible endpoint
for model in client.models.list().data:
print(model.id)
```

> **Note:** The OpenAI SDK works only with Bedrock models that have the [OpenAI-compatible API](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-mantle.html) enabled. Use `client.models.list()` to see which models are available on your endpoint.

#### Custom subject token provider

```python
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"]
realtime = ["websockets >= 13, < 16"]
datalib = ["numpy >= 1", "pandas >= 1.2.3", "pandas-stubs >= 1.1.0.11"]
voice_helpers = ["sounddevice>=0.5.1", "numpy>=2.0.2"]
bedrock = ["botocore>=1.29.0"]

[tool.rye]
managed = true
Expand Down
2 changes: 2 additions & 0 deletions src/openai/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
SubjectTokenProvider as SubjectTokenProvider,
WorkloadIdentityAuth as WorkloadIdentityAuth,
gcp_id_token_provider as gcp_id_token_provider,
aws_bedrock_token_provider as aws_bedrock_token_provider,
k8s_service_account_token_provider as k8s_service_account_token_provider,
azure_managed_identity_token_provider as azure_managed_identity_token_provider,
)
Expand All @@ -16,4 +17,5 @@
"k8s_service_account_token_provider",
"azure_managed_identity_token_provider",
"gcp_id_token_provider",
"aws_bedrock_token_provider",
]
85 changes: 85 additions & 0 deletions src/openai/auth/_workload.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import os
import time
import base64
import threading
from typing import Any, Callable, TypedDict, cast
from pathlib import Path
Expand Down Expand Up @@ -173,6 +175,89 @@ def get_token() -> str:
return {"token_type": "id", "get_token": get_token}


def aws_bedrock_token_provider(
*,
region: str | None = None,
profile: str | None = None,
token_duration: int = 3600,
) -> Callable[[], str]:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Add async-compatible Bedrock token provider

aws_bedrock_token_provider() currently returns a synchronous Callable[[], str], but AsyncOpenAI unconditionally awaits api_key providers in AsyncOpenAI._refresh_api_key. If an async user passes this new helper (the same way as the sync example), requests fail at runtime with a TypeError because str is not awaitable. This makes the new Bedrock auth path unusable for async clients unless users write their own wrapper.

Useful? React with 👍 / 👎.

"""
Get a token provider for AWS Bedrock using IAM credentials.

Returns a callable that generates a bearer token from a SigV4 presigned URL.
Pass it directly to ``api_key`` when creating an OpenAI client pointed at a
Bedrock runtime endpoint. Credentials are resolved from the standard AWS credential chain:
https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html

The token is cached and automatically refreshed before it expires.

Args:
region: AWS region. Defaults to ``AWS_REGION`` or ``AWS_DEFAULT_REGION`` environment variable.
profile: AWS profile name. If not set, botocore resolves credentials from the standard chain.
token_duration: Token expiry in seconds. Defaults to 3600 (1 hour).
"""
_cached_token: list[str | None] = [None]
_refresh_at: list[float] = [0.0]

def _generate_token() -> str:
try:
import botocore.session
from botocore.auth import SigV4QueryAuth
from botocore.awsrequest import AWSRequest
except ImportError as e:
raise ImportError(
"botocore is required for AWS Bedrock token generation. "
"Install it with: pip install 'openai[bedrock]'"
) from e

try:
resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")
if not resolved_region:
raise SubjectTokenProviderError(
"AWS region must be provided via the 'region' parameter, "
"or the AWS_REGION / AWS_DEFAULT_REGION environment variable."
)

session = botocore.session.Session(profile=profile)
credentials = session.get_credentials()
if credentials is None:
raise SubjectTokenProviderError(
"No AWS credentials found. "
"Ensure your AWS credentials are configured."
)
frozen_credentials = credentials.get_frozen_credentials()

request = AWSRequest(
method="POST",
url="https://bedrock.amazonaws.com/",
headers={"host": "bedrock.amazonaws.com"},
params={"Action": "CallWithBearerToken"},
)

signer = SigV4QueryAuth(frozen_credentials, "bedrock", resolved_region, expires=token_duration)
signer.add_auth(request)

signed_url = request.url
# Strip the https:// prefix before encoding
url_without_scheme = signed_url[len("https://") :]
encoded_token = base64.b64encode(f"{url_without_scheme}&Version=1".encode()).decode()

return f"bedrock-api-key-{encoded_token}"
except (ImportError, SubjectTokenProviderError):
raise
except Exception as e:
raise SubjectTokenProviderError(f"Failed to generate AWS Bedrock token: {e}") from e

def get_token() -> str:
now = time.monotonic()
if _cached_token[0] is None or now >= _refresh_at[0]:
_cached_token[0] = _generate_token()
_refresh_at[0] = now + max(token_duration - 60, token_duration * 0.9)
Comment thread
tbuatois marked this conversation as resolved.
Outdated
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Account for AWS credential expiry in token cache

The refresh schedule is based only on token_duration, but Bedrock bearer tokens signed with temporary AWS credentials stop working when those credentials expire, which can be much earlier than the requested token duration (for example, 15-minute STS credentials). With the current logic, a token can be reused for nearly an hour after its signing credentials have expired, causing intermittent 401s until _refresh_at is reached. The cache policy should incorporate credential expiration (or avoid caching) so expired signed tokens are not returned.

Useful? React with 👍 / 👎.

return _cached_token[0]

return get_token


class WorkloadIdentityAuth:
def __init__(
self,
Expand Down
90 changes: 90 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import json
import base64
from typing import cast
from pathlib import Path
from unittest.mock import MagicMock, patch

import httpx
import respx
Expand All @@ -9,8 +13,10 @@
from inline_snapshot import snapshot

from openai import OpenAI, OAuthError
from openai._exceptions import SubjectTokenProviderError
from openai.auth._workload import (
gcp_id_token_provider,
aws_bedrock_token_provider,
k8s_service_account_token_provider,
azure_managed_identity_token_provider,
)
Expand Down Expand Up @@ -188,3 +194,87 @@ def test_gcp_id_token_provider() -> None:

assert provider["token_type"] == "id"
assert provider["get_token"]() == "gcp-token"


def _make_mock_botocore(
access_key: str = "AKIAIOSFODNN7EXAMPLE",
secret_key: str = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
token: str | None = None,
) -> MagicMock:
"""Create a mock botocore module with fake credentials."""
mock_botocore = MagicMock()

frozen = MagicMock()
frozen.access_key = access_key
frozen.secret_key = secret_key
frozen.token = token

creds = MagicMock()
creds.get_frozen_credentials.return_value = frozen

session_instance = MagicMock()
session_instance.get_credentials.return_value = creds

mock_botocore.session.Session.return_value = session_instance

# Use real SigV4QueryAuth and AWSRequest from botocore
import botocore.auth
import botocore.awsrequest
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Guard botocore imports in Bedrock auth tests

These tests import real botocore modules unconditionally, but botocore is only declared as an optional extra and is not part of the default dev lockfile used by the standard nox test session. In a normal test environment without openai[bedrock], this helper raises ModuleNotFoundError before the new tests run, which can fail CI/test runs unrelated to Bedrock support. The tests should skip when botocore is unavailable or fully mock these modules.

Useful? React with 👍 / 👎.


mock_botocore.auth.SigV4QueryAuth = botocore.auth.SigV4QueryAuth
mock_botocore.awsrequest.AWSRequest = botocore.awsrequest.AWSRequest

return mock_botocore


def test_aws_bedrock_token_provider() -> None:
mock_botocore = _make_mock_botocore()

with patch.dict("sys.modules", {"botocore": mock_botocore, "botocore.session": mock_botocore.session, "botocore.auth": mock_botocore.auth, "botocore.awsrequest": mock_botocore.awsrequest}):
get_token = aws_bedrock_token_provider(region="us-east-1")

token = get_token()
assert token.startswith("bedrock-api-key-")

encoded_part = token[len("bedrock-api-key-"):]
decoded_url = base64.b64decode(encoded_part).decode()

assert "bedrock.amazonaws.com" in decoded_url
assert "X-Amz-Signature=" in decoded_url
assert "X-Amz-Credential=" in decoded_url
assert "Action=CallWithBearerToken" in decoded_url
assert "&Version=1" in decoded_url


def test_aws_bedrock_token_provider_custom_region() -> None:
mock_botocore = _make_mock_botocore()

with patch.dict("sys.modules", {"botocore": mock_botocore, "botocore.session": mock_botocore.session, "botocore.auth": mock_botocore.auth, "botocore.awsrequest": mock_botocore.awsrequest}):
get_token = aws_bedrock_token_provider(region="eu-west-1")
token = get_token()

encoded_part = token[len("bedrock-api-key-"):]
decoded_url = base64.b64decode(encoded_part).decode()

assert "eu-west-1" in decoded_url


def test_aws_bedrock_token_provider_no_credentials() -> None:
mock_botocore = MagicMock()
session_instance = MagicMock()
session_instance.get_credentials.return_value = None
mock_botocore.session.Session.return_value = session_instance

with patch.dict("sys.modules", {"botocore": mock_botocore, "botocore.session": mock_botocore.session, "botocore.auth": mock_botocore.auth, "botocore.awsrequest": mock_botocore.awsrequest}):
get_token = aws_bedrock_token_provider(region="us-east-1")

with pytest.raises(SubjectTokenProviderError, match="No AWS credentials found"):
get_token()


def test_aws_bedrock_token_provider_no_botocore() -> None:
with patch.dict("sys.modules", {"botocore": None, "botocore.session": None, "botocore.auth": None, "botocore.awsrequest": None}):
get_token = aws_bedrock_token_provider(region="us-east-1")

with pytest.raises(ImportError, match="botocore is required.*openai\\[bedrock\\]"):
get_token()