-
Notifications
You must be signed in to change notification settings - Fork 4.8k
feat(auth): add AWS Bedrock token provider #3135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
a9f9f2c
cc64a8b
aa1add5
71df8f4
0d4fe24
f7c8cf0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| import time | ||
| import base64 | ||
| import threading | ||
| from typing import Any, Callable, TypedDict, cast | ||
| from pathlib import Path | ||
|
|
@@ -173,6 +175,89 @@ def get_token() -> str: | |
| return {"token_type": "id", "get_token": get_token} | ||
|
|
||
|
|
||
| def aws_bedrock_token_provider( | ||
| *, | ||
| region: str | None = None, | ||
| profile: str | None = None, | ||
| token_duration: int = 3600, | ||
| ) -> Callable[[], str]: | ||
| """ | ||
| Get a token provider for AWS Bedrock using IAM credentials. | ||
|
|
||
| Returns a callable that generates a bearer token from a SigV4 presigned URL. | ||
| Pass it directly to ``api_key`` when creating an OpenAI client pointed at a | ||
| Bedrock runtime endpoint. Credentials are resolved from the standard AWS credential chain: | ||
| https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html | ||
|
|
||
| The token is cached and automatically refreshed before it expires. | ||
|
|
||
| Args: | ||
| region: AWS region. Defaults to ``AWS_REGION`` or ``AWS_DEFAULT_REGION`` environment variable. | ||
| profile: AWS profile name. If not set, botocore resolves credentials from the standard chain. | ||
| token_duration: Token expiry in seconds. Defaults to 3600 (1 hour). | ||
| """ | ||
| _cached_token: list[str | None] = [None] | ||
| _refresh_at: list[float] = [0.0] | ||
|
|
||
| def _generate_token() -> str: | ||
| try: | ||
| import botocore.session | ||
| from botocore.auth import SigV4QueryAuth | ||
| from botocore.awsrequest import AWSRequest | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "botocore is required for AWS Bedrock token generation. " | ||
| "Install it with: pip install 'openai[bedrock]'" | ||
| ) from e | ||
|
|
||
| try: | ||
| resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") | ||
| if not resolved_region: | ||
| raise SubjectTokenProviderError( | ||
| "AWS region must be provided via the 'region' parameter, " | ||
| "or the AWS_REGION / AWS_DEFAULT_REGION environment variable." | ||
| ) | ||
|
|
||
| session = botocore.session.Session(profile=profile) | ||
| credentials = session.get_credentials() | ||
| if credentials is None: | ||
| raise SubjectTokenProviderError( | ||
| "No AWS credentials found. " | ||
| "Ensure your AWS credentials are configured." | ||
| ) | ||
| frozen_credentials = credentials.get_frozen_credentials() | ||
|
|
||
| request = AWSRequest( | ||
| method="POST", | ||
| url="https://bedrock.amazonaws.com/", | ||
| headers={"host": "bedrock.amazonaws.com"}, | ||
| params={"Action": "CallWithBearerToken"}, | ||
| ) | ||
|
|
||
| signer = SigV4QueryAuth(frozen_credentials, "bedrock", resolved_region, expires=token_duration) | ||
| signer.add_auth(request) | ||
|
|
||
| signed_url = request.url | ||
| # Strip the https:// prefix before encoding | ||
| url_without_scheme = signed_url[len("https://") :] | ||
| encoded_token = base64.b64encode(f"{url_without_scheme}&Version=1".encode()).decode() | ||
|
|
||
| return f"bedrock-api-key-{encoded_token}" | ||
| except (ImportError, SubjectTokenProviderError): | ||
| raise | ||
| except Exception as e: | ||
| raise SubjectTokenProviderError(f"Failed to generate AWS Bedrock token: {e}") from e | ||
|
|
||
| def get_token() -> str: | ||
| now = time.monotonic() | ||
| if _cached_token[0] is None or now >= _refresh_at[0]: | ||
| _cached_token[0] = _generate_token() | ||
| _refresh_at[0] = now + max(token_duration - 60, token_duration * 0.9) | ||
|
tbuatois marked this conversation as resolved.
Outdated
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The refresh schedule is based only on Useful? React with 👍 / 👎. |
||
| return _cached_token[0] | ||
|
|
||
| return get_token | ||
|
|
||
|
|
||
| class WorkloadIdentityAuth: | ||
| def __init__( | ||
| self, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,10 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import base64 | ||
| from typing import cast | ||
| from pathlib import Path | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import httpx | ||
| import respx | ||
|
|
@@ -9,8 +13,10 @@ | |
| from inline_snapshot import snapshot | ||
|
|
||
| from openai import OpenAI, OAuthError | ||
| from openai._exceptions import SubjectTokenProviderError | ||
| from openai.auth._workload import ( | ||
| gcp_id_token_provider, | ||
| aws_bedrock_token_provider, | ||
| k8s_service_account_token_provider, | ||
| azure_managed_identity_token_provider, | ||
| ) | ||
|
|
@@ -188,3 +194,87 @@ def test_gcp_id_token_provider() -> None: | |
|
|
||
| assert provider["token_type"] == "id" | ||
| assert provider["get_token"]() == "gcp-token" | ||
|
|
||
|
|
||
| def _make_mock_botocore( | ||
| access_key: str = "AKIAIOSFODNN7EXAMPLE", | ||
| secret_key: str = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", | ||
| token: str | None = None, | ||
| ) -> MagicMock: | ||
| """Create a mock botocore module with fake credentials.""" | ||
| mock_botocore = MagicMock() | ||
|
|
||
| frozen = MagicMock() | ||
| frozen.access_key = access_key | ||
| frozen.secret_key = secret_key | ||
| frozen.token = token | ||
|
|
||
| creds = MagicMock() | ||
| creds.get_frozen_credentials.return_value = frozen | ||
|
|
||
| session_instance = MagicMock() | ||
| session_instance.get_credentials.return_value = creds | ||
|
|
||
| mock_botocore.session.Session.return_value = session_instance | ||
|
|
||
| # Use real SigV4QueryAuth and AWSRequest from botocore | ||
| import botocore.auth | ||
| import botocore.awsrequest | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
These tests import real Useful? React with 👍 / 👎. |
||
|
|
||
| mock_botocore.auth.SigV4QueryAuth = botocore.auth.SigV4QueryAuth | ||
| mock_botocore.awsrequest.AWSRequest = botocore.awsrequest.AWSRequest | ||
|
|
||
| return mock_botocore | ||
|
|
||
|
|
||
| def test_aws_bedrock_token_provider() -> None: | ||
| mock_botocore = _make_mock_botocore() | ||
|
|
||
| with patch.dict("sys.modules", {"botocore": mock_botocore, "botocore.session": mock_botocore.session, "botocore.auth": mock_botocore.auth, "botocore.awsrequest": mock_botocore.awsrequest}): | ||
| get_token = aws_bedrock_token_provider(region="us-east-1") | ||
|
|
||
| token = get_token() | ||
| assert token.startswith("bedrock-api-key-") | ||
|
|
||
| encoded_part = token[len("bedrock-api-key-"):] | ||
| decoded_url = base64.b64decode(encoded_part).decode() | ||
|
|
||
| assert "bedrock.amazonaws.com" in decoded_url | ||
| assert "X-Amz-Signature=" in decoded_url | ||
| assert "X-Amz-Credential=" in decoded_url | ||
| assert "Action=CallWithBearerToken" in decoded_url | ||
| assert "&Version=1" in decoded_url | ||
|
|
||
|
|
||
| def test_aws_bedrock_token_provider_custom_region() -> None: | ||
| mock_botocore = _make_mock_botocore() | ||
|
|
||
| with patch.dict("sys.modules", {"botocore": mock_botocore, "botocore.session": mock_botocore.session, "botocore.auth": mock_botocore.auth, "botocore.awsrequest": mock_botocore.awsrequest}): | ||
| get_token = aws_bedrock_token_provider(region="eu-west-1") | ||
| token = get_token() | ||
|
|
||
| encoded_part = token[len("bedrock-api-key-"):] | ||
| decoded_url = base64.b64decode(encoded_part).decode() | ||
|
|
||
| assert "eu-west-1" in decoded_url | ||
|
|
||
|
|
||
| def test_aws_bedrock_token_provider_no_credentials() -> None: | ||
| mock_botocore = MagicMock() | ||
| session_instance = MagicMock() | ||
| session_instance.get_credentials.return_value = None | ||
| mock_botocore.session.Session.return_value = session_instance | ||
|
|
||
| with patch.dict("sys.modules", {"botocore": mock_botocore, "botocore.session": mock_botocore.session, "botocore.auth": mock_botocore.auth, "botocore.awsrequest": mock_botocore.awsrequest}): | ||
| get_token = aws_bedrock_token_provider(region="us-east-1") | ||
|
|
||
| with pytest.raises(SubjectTokenProviderError, match="No AWS credentials found"): | ||
| get_token() | ||
|
|
||
|
|
||
| def test_aws_bedrock_token_provider_no_botocore() -> None: | ||
| with patch.dict("sys.modules", {"botocore": None, "botocore.session": None, "botocore.auth": None, "botocore.awsrequest": None}): | ||
| get_token = aws_bedrock_token_provider(region="us-east-1") | ||
|
|
||
| with pytest.raises(ImportError, match="botocore is required.*openai\\[bedrock\\]"): | ||
| get_token() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aws_bedrock_token_provider()currently returns a synchronousCallable[[], str], butAsyncOpenAIunconditionally awaitsapi_keyproviders inAsyncOpenAI._refresh_api_key. If an async user passes this new helper (the same way as the sync example), requests fail at runtime with aTypeErrorbecausestris not awaitable. This makes the new Bedrock auth path unusable for async clients unless users write their own wrapper.Useful? React with 👍 / 👎.