#
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
#

from dataclasses import InitVar, dataclass, field
from typing import Any, List, Mapping, Optional, Union

import requests

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.auth.token_rotation_strategies import (
    RateLimitAwareRotation,
    RoundRobinRotation,
    TokenRotationStrategy,
)
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.types import Config


@dataclass
class TokenPoolAuthenticator(DeclarativeAuthenticator):
    """Authenticator that rotates through multiple API tokens.

    Accepts a list of tokens (or a delimited string) and rotates through them
    using a configurable strategy (round-robin or rate-limit-aware). This enables
    distributing rate-limit consumption across multiple credentials.

    Attributes:
        tokens: Interpolated string resolving to the token(s). The resolved value may
            be a single token, several tokens joined by `token_separator`, or a
            list/tuple of tokens.
        config: The user-provided configuration.
        parameters: Additional runtime parameters for string interpolation.
        token_separator: Delimiter used to split `tokens` into individual values.
        auth_method: Prefix for the token value in the header (e.g., "Bearer", "token", "").
        header: HTTP header name to set.
        rotation_strategy: Strategy controlling how tokens are rotated. When omitted,
            a `RoundRobinRotation` over the resolved tokens is used.

    Raises:
        ValueError: If no non-empty token remains after resolving `tokens`.
    """

    tokens: Union[InterpolatedString, str]
    config: Config
    parameters: InitVar[Mapping[str, Any]]
    token_separator: str = ","
    auth_method: str = "Bearer"
    header: str = "Authorization"
    rotation_strategy: Optional[TokenRotationStrategy] = None

    _token_list: List[str] = field(default_factory=list, init=False, repr=False)
    _strategy: TokenRotationStrategy = field(init=False, repr=False)

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        """Resolve the token pool from config and wire up the rotation strategy."""
        tokens_interpolated = InterpolatedString.create(self.tokens, parameters=parameters)
        self._token_list = self._split_tokens(tokens_interpolated.eval(self.config))

        if not self._token_list:
            raise ValueError("TokenPoolAuthenticator requires at least one token.")

        if self.rotation_strategy is None:
            # Default to round-robin
            self._strategy = RoundRobinRotation(tokens=self._token_list, parameters=parameters)
            return

        self._strategy = self.rotation_strategy
        # Inject the token list into the strategy only if it was built without one.
        if hasattr(self._strategy, "tokens") and not self._strategy.tokens:
            self._strategy.tokens = self._token_list
            # The built-in strategies derive internal state from `tokens` in
            # __post_init__, so that state must be rebuilt after the injection.
            if isinstance(self._strategy, (RoundRobinRotation, RateLimitAwareRotation)):
                self._strategy.__post_init__(parameters)

    def _split_tokens(self, evaluated: Any) -> List[str]:
        """Normalize the evaluated `tokens` value into a list of non-empty tokens.

        A list/tuple is taken element-wise instead of being stringified (which
        would yield fragments such as "['a'"); `None` yields no tokens; any other
        value is converted to `str` and split on `token_separator`.
        """
        if evaluated is None:
            return []
        if isinstance(evaluated, (list, tuple)):
            candidates = [str(item) for item in evaluated]
        else:
            candidates = str(evaluated).split(self.token_separator)
        stripped = (candidate.strip() for candidate in candidates)
        return [token for token in stripped if token]

    @property
    def auth_header(self) -> str:
        """Name of the HTTP header that carries the token."""
        return self.header

    @property
    def token(self) -> str:
        """Header value for the strategy's active token, prefixed with `auth_method` if set.

        Each read asks the strategy for its active token, so with
        `RoundRobinRotation` every access advances the rotation.
        """
        raw_token = self._strategy.get_active_token()
        if self.auth_method:
            return f"{self.auth_method} {raw_token}"
        return raw_token

    def on_http_response(self, response: requests.Response) -> None:
        """Called after each HTTP response to update per-token rate-limit state."""
        self._strategy.update_from_response(response)

    def update_token(self) -> None:
        """Force rotation to the next token.

        Provided for compatibility with imperative-style connectors that call
        `authenticator.update_token()` from backoff strategies. Strategies other
        than the two built-in ones are left untouched.
        """
        if isinstance(self._strategy, RateLimitAwareRotation):
            self._strategy._rotate()
        elif isinstance(self._strategy, RoundRobinRotation):
            # RoundRobinRotation advances on each get_active_token() call, so
            # calling get_active_token() once consumes the rotation.
            self._strategy.get_active_token()
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import InitVar, dataclass, field
from itertools import cycle
from typing import Any, Dict, List, Mapping, Optional

import requests

from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

logger = logging.getLogger("airbyte")

# Quota assumed for a token when sizing its reserve; the real per-token limit is
# never read from the response, so this is only an estimate.
_DEFAULT_LIMIT_ESTIMATE = 5000


@dataclass
class TokenState:
    """Tracks rate-limit state for an individual token."""

    remaining: int = -1  # -1 means unknown
    reset_at: Optional[AirbyteDateTime] = None


class TokenRotationStrategy(ABC):
    """Base class for token rotation strategies."""

    @abstractmethod
    def get_active_token(self) -> str:
        """Return the currently active token."""
        raise NotImplementedError

    def update_from_response(self, response: requests.Response) -> None:
        """Update internal state from an HTTP response. Override in subclasses."""
        pass


@dataclass
class RoundRobinRotation(TokenRotationStrategy):
    """Cycle through tokens on each `get_active_token()` call.

    The strategy may be constructed with an empty `tokens` list (the component
    factory does so and lets `TokenPoolAuthenticator` inject the pool afterwards);
    the cycle is then built lazily on first use.
    """

    tokens: List[str]
    parameters: InitVar[Mapping[str, Any]]

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        self._start_cycle()

    def _start_cycle(self) -> None:
        """(Re)build the cycle from the current `tokens`; `_current` is None if empty."""
        self._iter = cycle(self.tokens)
        # next() with a default avoids StopIteration escaping when tokens is empty.
        self._current: Optional[str] = next(self._iter, None)

    def get_active_token(self) -> str:
        """Return the current token and advance to the next one.

        Raises:
            ValueError: If no tokens are available to rotate through.
        """
        if self._current is None:
            # Tokens may have been assigned after construction; try once more.
            self._start_cycle()
            if self._current is None:
                raise ValueError("RoundRobinRotation has no tokens to rotate through.")
        token = self._current
        self._current = next(self._iter)
        return token


@dataclass
class RateLimitAwareRotation(TokenRotationStrategy):
    """Track per-token quota from response headers and rotate when exhausted.

    When a token's remaining quota hits zero, rotate to the next token. When all
    tokens are exhausted, sleep until the earliest reset time. Proactive throttling
    spreads remaining calls over the reset window to avoid hitting the wall.

    Attributes:
        tokens: The pool of tokens to rotate through.
        parameters: Runtime parameters (accepted for interface parity; unused here).
        ratelimit_remaining_header: Response header holding the remaining call count.
        ratelimit_reset_header: Response header holding the reset time; the value is
            parsed as a number and handed to `ab_datetime_parse`.
        max_wait_seconds: Upper bound on how long to sleep for a reset before raising.
        budget_reserve_fraction: Fraction of the estimated limit treated as reserve.
        budget_min_reserve: Lower bound for the per-token reserve.
    """

    tokens: List[str]
    parameters: InitVar[Mapping[str, Any]]
    ratelimit_remaining_header: str = "x-ratelimit-remaining"
    ratelimit_reset_header: str = "x-ratelimit-reset"
    max_wait_seconds: int = 7200
    budget_reserve_fraction: float = 0.1
    budget_min_reserve: int = 50

    _token_state: Dict[str, TokenState] = field(default_factory=dict, init=False)
    _active_index: int = field(default=0, init=False)
    _budget_logged: bool = field(default=False, init=False)

    # Annotated, so the dataclass treats this as a regular (overridable) field.
    HEARTBEAT_INTERVAL: float = 60.0

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        self._token_state = {t: TokenState() for t in self.tokens}

    def get_active_token(self) -> str:
        """Return the active token, rotating if the current one is exhausted.

        If every token is exhausted, blocks until the earliest reset time.

        Raises:
            RuntimeError: If all tokens are exhausted and either no reset time is
                known or the wait would exceed `max_wait_seconds`.
        """
        if not self._select_usable_token():
            # All tokens are exhausted — sleep until earliest reset
            self._sleep_until_reset()
            # Land on a token that was actually refreshed by the wait.
            self._select_usable_token()
        return self.tokens[self._active_index]

    def _select_usable_token(self) -> bool:
        """Advance `_active_index` to a token with quota; return False if none has any."""
        for _ in range(len(self.tokens)):
            state = self._token_state[self.tokens[self._active_index]]
            self._clear_if_reset_passed(state)
            # Unknown (-1) or positive remaining means the token can be used.
            if state.remaining != 0:
                return True
            self._rotate()
        return False

    @staticmethod
    def _clear_if_reset_passed(state: TokenState) -> None:
        """Mark an exhausted token as unknown again once its reset time has passed."""
        if (
            state.remaining == 0
            and state.reset_at is not None
            and state.reset_at <= ab_datetime_now()
        ):
            state.remaining = -1
            state.reset_at = None

    def update_from_response(self, response: requests.Response) -> None:
        """Update the active token's state from response headers.

        May sleep briefly (see `_maybe_throttle`) and rotates to the next token
        when the active token reports zero remaining calls.
        """
        token = self.tokens[self._active_index]
        state = self._token_state[token]

        remaining_header = response.headers.get(self.ratelimit_remaining_header)
        reset_header = response.headers.get(self.ratelimit_reset_header)

        if remaining_header is not None:
            try:
                state.remaining = int(remaining_header)
            except (ValueError, TypeError):
                logger.debug(
                    "Could not parse ratelimit-remaining header value: %s", remaining_header
                )

        if reset_header is not None:
            try:
                reset_ts = float(reset_header)
                state.reset_at = ab_datetime_parse(str(int(reset_ts)))
            except (ValueError, TypeError, OverflowError):
                # OverflowError: int(float("inf")) for a malformed header value.
                logger.debug("Could not parse ratelimit-reset header value: %s", reset_header)

        if state.remaining >= 0:
            # Throttle once the token dips into its reserve; rotate only when it is
            # fully exhausted.
            if state.remaining <= self._get_budget_reserve(state):
                self._maybe_throttle(state)
            if state.remaining == 0:
                self._rotate()

    def _rotate(self) -> None:
        """Move the active index to the next token, wrapping around."""
        self._active_index = (self._active_index + 1) % len(self.tokens)

    def _get_budget_reserve(self, state: TokenState) -> int:
        """Return the minimum number of calls to keep in reserve for a token."""
        limit_estimate = max(_DEFAULT_LIMIT_ESTIMATE, state.remaining)
        return max(self.budget_min_reserve, int(limit_estimate * self.budget_reserve_fraction))

    def _maybe_throttle(self, state: TokenState) -> None:
        """Inject a small delay when all tokens are running low.

        A token whose quota is still unknown is not considered "running low", so
        no delay is added while untouched tokens remain in the pool.
        """
        if any(
            s.remaining < 0 or s.remaining > self._get_budget_reserve(s)
            for s in self._token_state.values()
        ):
            return

        if state.reset_at is None:
            return

        seconds_to_reset = max((state.reset_at - ab_datetime_now()).total_seconds(), 0)
        total_remaining = sum(max(s.remaining, 0) for s in self._token_state.values())
        if total_remaining <= 0 or seconds_to_reset <= 0:
            return

        # Spread the remaining calls evenly over the window, capped at 10s per call.
        delay = min(seconds_to_reset / total_remaining, 10.0)
        if delay >= 0.1:
            if not self._budget_logged:
                logger.info(
                    "API budget: throttling requests (%.1fs delay). %d calls remaining across %d token(s), "
                    "%.0fs until reset.",
                    delay,
                    total_remaining,
                    len(self.tokens),
                    seconds_to_reset,
                )
                self._budget_logged = True
            time.sleep(delay)

    def _sleep_until_reset(self) -> None:
        """Sleep until the earliest token reset time, or raise if too long.

        After the wait, only tokens whose reset time has been reached (or whose
        reset time is unknown) are marked usable again; tokens that reset later
        stay exhausted.

        Raises:
            RuntimeError: If no reset time is known, or the wait exceeds
                `max_wait_seconds`.
        """
        reset_times = [s.reset_at for s in self._token_state.values() if s.reset_at is not None]
        if not reset_times:
            raise RuntimeError(
                "All tokens in the pool are exhausted and no reset time is available."
            )

        earliest_reset = min(reset_times)
        wait_seconds = max((earliest_reset - ab_datetime_now()).total_seconds(), 0)

        if wait_seconds > self.max_wait_seconds:
            raise RuntimeError(
                f"All tokens in the pool are exhausted. Earliest reset in {wait_seconds:.0f}s "
                f"exceeds max_wait_seconds ({self.max_wait_seconds}s)."
            )

        logger.info(
            "All tokens exhausted. Sleeping %.0fs until rate limit resets.",
            wait_seconds,
        )
        self._sleep_with_heartbeat(wait_seconds)

        # Using max() guarantees the earliest-resetting token is refreshed even if
        # the clock reads marginally before its reset time after waking up.
        refreshed_until = max(ab_datetime_now(), earliest_reset)
        for state in self._token_state.values():
            if state.reset_at is None or state.reset_at <= refreshed_until:
                state.remaining = -1
                state.reset_at = None
        self._budget_logged = False

    def _sleep_with_heartbeat(self, total_seconds: float) -> None:
        """Sleep with periodic log messages to keep the heartbeat alive."""
        remaining = total_seconds
        while remaining > 0:
            chunk = min(remaining, self.HEARTBEAT_INTERVAL)
            time.sleep(chunk)
            remaining -= chunk
            if remaining > 0:
                logger.info(
                    "Rate limit exhausted. Waiting for reset — %.0fs remaining.",
                    remaining,
                )
Resolved from config via interpolation. + type: string + interpolation_context: + - config + examples: + - "{{ config['api_tokens'] }}" + - "{{ config['credentials']['personal_access_token'] }}" + token_separator: + title: Token Separator + description: Separator used to split the tokens string into individual tokens. + type: string + default: "," + auth_method: + title: Auth Method + description: Prefix for the token value in the header (e.g., "Bearer", "token"). Set to empty string for no prefix. + type: string + default: "Bearer" + header: + title: Header Name + description: The HTTP header name to set with the active token. + type: string + default: "Authorization" + rotation_strategy: + title: Rotation Strategy + description: Strategy controlling how tokens are rotated. Defaults to round-robin. + anyOf: + - "$ref": "#/definitions/RoundRobinRotation" + - "$ref": "#/definitions/RateLimitAwareRotation" + $parameters: + type: object + additionalProperties: true + RoundRobinRotation: + title: Round Robin Rotation + description: Rotates through tokens sequentially on each request. No rate-limit awareness. + type: object + required: + - type + properties: + type: + type: string + enum: [RoundRobinRotation] + $parameters: + type: object + additionalProperties: true + RateLimitAwareRotation: + title: Rate Limit Aware Rotation + description: Rotates to the next token when the current token's rate limit is exhausted, based on response headers. Sleeps when all tokens are exhausted until the earliest reset time. + type: object + required: + - type + properties: + type: + type: string + enum: [RateLimitAwareRotation] + ratelimit_remaining_header: + title: Rate Limit Remaining Header + description: Response header indicating remaining API calls for the current token. + type: string + default: "x-ratelimit-remaining" + ratelimit_reset_header: + title: Rate Limit Reset Header + description: Response header indicating when the rate limit window resets (Unix timestamp). 
# Manifest model for the round-robin rotation strategy; it carries no options
# beyond the component type and optional $parameters.
class RoundRobinRotation(BaseModel):
    class Config:
        extra = Extra.allow

    type: Literal["RoundRobinRotation"]
    parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


# Manifest model for the rate-limit-aware rotation strategy. Defaults here mirror
# the defaults declared on the runtime RateLimitAwareRotation dataclass.
class RateLimitAwareRotation(BaseModel):
    class Config:
        extra = Extra.allow

    type: Literal["RateLimitAwareRotation"]
    ratelimit_remaining_header: Optional[str] = Field(
        "x-ratelimit-remaining",
        description="Response header indicating remaining API calls for the current token.",
        title="Rate Limit Remaining Header",
    )
    ratelimit_reset_header: Optional[str] = Field(
        "x-ratelimit-reset",
        description="Response header indicating when the rate limit window resets (Unix timestamp).",
        title="Rate Limit Reset Header",
    )
    # NOTE(review): the description promises a "transient error", but the runtime
    # strategy raises a plain RuntimeError when the wait exceeds this value —
    # confirm which behavior is intended.
    max_wait_seconds: Optional[int] = Field(
        7200,
        description="Maximum seconds to wait for any token's rate limit to reset. If all tokens are exhausted and the earliest reset exceeds this, the sync fails with a transient error.",
        title="Max Wait Seconds",
    )
    budget_reserve_fraction: Optional[float] = Field(
        0.1,
        description="Start proactive throttling when a token's remaining quota drops below this fraction. Set to 0 to disable.",
        title="Budget Reserve Fraction",
    )
    budget_min_reserve: Optional[int] = Field(
        50,
        description="Minimum number of calls to keep in reserve per token before triggering rotation.",
        title="Budget Min Reserve",
    )
    parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


# Manifest model for the token-pool authenticator. `tokens` is the only required
# option; every other field falls back to the default given in its Field(...).
class TokenPoolAuthenticator(BaseModel):
    class Config:
        extra = Extra.allow

    type: Literal["TokenPoolAuthenticator"]
    tokens: str = Field(
        ...,
        description="The API token(s) to rotate through. Can be a single token or multiple tokens joined by the token_separator. Resolved from config via interpolation.",
        examples=[
            "{{ config['api_tokens'] }}",
            "{{ config['credentials']['personal_access_token'] }}",
        ],
        title="API Tokens",
    )
    token_separator: Optional[str] = Field(
        ",",
        description="Separator used to split the tokens string into individual tokens.",
        title="Token Separator",
    )
    auth_method: Optional[str] = Field(
        "Bearer",
        description='Prefix for the token value in the header (e.g., "Bearer", "token"). Set to empty string for no prefix.',
        title="Auth Method",
    )
    header: Optional[str] = Field(
        "Authorization",
        description="The HTTP header name to set with the active token.",
        title="Header Name",
    )
    # None means no strategy was configured; the runtime authenticator then
    # falls back to round-robin.
    rotation_strategy: Optional[Union[RoundRobinRotation, RateLimitAwareRotation]] = Field(
        None,
        description="Strategy controlling how tokens are rotated. Defaults to round-robin.",
        title="Rotation Strategy",
    )
    parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
def create_token_pool_authenticator(
    self, model: TokenPoolAuthenticatorModel, config: Config, **kwargs: Any
) -> TokenPoolAuthenticator:
    """Build a `TokenPoolAuthenticator` from its manifest model.

    The optional rotation strategy is built first through the generic component
    factory; unset model options fall back to the documented defaults.
    """
    strategy_model = model.rotation_strategy
    strategy = (
        None
        if strategy_model is None
        else self._create_component_from_model(model=strategy_model, config=config)
    )
    # An empty string is a valid "no prefix" value, so only None falls back.
    prefix = "Bearer" if model.auth_method is None else model.auth_method
    return TokenPoolAuthenticator(
        tokens=model.tokens,
        config=config,
        parameters=model.parameters or {},
        token_separator=model.token_separator or ",",
        auth_method=prefix,
        header=model.header or "Authorization",
        rotation_strategy=strategy,
    )


def create_round_robin_rotation(
    self, model: RoundRobinRotationModel, config: Config, **kwargs: Any
) -> RoundRobinRotation:
    """Build a `RoundRobinRotation` with an empty pool.

    The tokens are filled in later by `TokenPoolAuthenticator.__post_init__`.
    """
    runtime_parameters = model.parameters or {}
    return RoundRobinRotation(tokens=[], parameters=runtime_parameters)


def create_rate_limit_aware_rotation(
    self, model: RateLimitAwareRotationModel, config: Config, **kwargs: Any
) -> RateLimitAwareRotation:
    """Build a `RateLimitAwareRotation` with an empty pool.

    The tokens are filled in later by `TokenPoolAuthenticator.__post_init__`.
    Numeric options fall back to their defaults only when they are None, so an
    explicit 0 is preserved.
    """
    max_wait = 7200 if model.max_wait_seconds is None else model.max_wait_seconds
    reserve_fraction = (
        0.1 if model.budget_reserve_fraction is None else model.budget_reserve_fraction
    )
    min_reserve = 50 if model.budget_min_reserve is None else model.budget_min_reserve
    return RateLimitAwareRotation(
        tokens=[],
        parameters=model.parameters or {},
        ratelimit_remaining_header=model.ratelimit_remaining_header or "x-ratelimit-remaining",
        ratelimit_reset_header=model.ratelimit_reset_header or "x-ratelimit-reset",
        max_wait_seconds=max_wait,
        budget_reserve_fraction=reserve_fraction,
        budget_min_reserve=min_reserve,
    )
#
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
#

import time
from unittest.mock import MagicMock

import pytest

from airbyte_cdk.sources.declarative.auth.token_pool_authenticator import TokenPoolAuthenticator
from airbyte_cdk.sources.declarative.auth.token_rotation_strategies import (
    RateLimitAwareRotation,
    RoundRobinRotation,
)


def _rate_limit_response(remaining: str, reset_in_seconds: int) -> MagicMock:
    """Build a mocked response whose rate-limit headers report the given state."""
    mocked = MagicMock()
    mocked.headers = {
        "x-ratelimit-remaining": remaining,
        "x-ratelimit-reset": str(int(time.time()) + reset_in_seconds),
    }
    return mocked


_BASIC_CASES = [
    pytest.param(
        "token1,token2,token3",
        ",",
        "Bearer",
        "Authorization",
        "Bearer token1",
        "Bearer token2",
        id="default_bearer_comma_separated",
    ),
    pytest.param(
        "tok_a|tok_b",
        "|",
        "token",
        "Authorization",
        "token tok_a",
        "token tok_b",
        id="pipe_separator_with_token_prefix",
    ),
    pytest.param(
        "single_token",
        ",",
        "",
        "X-Api-Key",
        "single_token",
        "single_token",
        id="single_token_no_prefix",
    ),
]


@pytest.mark.parametrize(
    "tokens_str,separator,auth_method,header,expected_first,expected_second",
    _BASIC_CASES,
)
def test_token_pool_authenticator_basic(
    tokens_str, separator, auth_method, header, expected_first, expected_second
):
    """Header name and the first two emitted tokens follow the configuration."""
    authenticator = TokenPoolAuthenticator(
        tokens=tokens_str,
        config={},
        parameters={},
        token_separator=separator,
        auth_method=auth_method,
        header=header,
    )
    assert authenticator.auth_header == header
    first_value = authenticator.token
    assert first_value == expected_first
    second_value = authenticator.token
    assert second_value == expected_second


def test_token_pool_authenticator_round_robin_cycles():
    """The default strategy wraps around after the last token."""
    authenticator = TokenPoolAuthenticator(tokens="a,b,c", config={}, parameters={})
    observed = []
    for _ in range(6):
        observed.append(authenticator.token)
    assert observed == [
        "Bearer a",
        "Bearer b",
        "Bearer c",
        "Bearer a",
        "Bearer b",
        "Bearer c",
    ]


def test_token_pool_authenticator_interpolated_config():
    """Tokens are resolved from config through interpolation."""
    authenticator = TokenPoolAuthenticator(
        tokens="{{ config['api_tokens'] }}",
        config={"api_tokens": "key1,key2"},
        parameters={},
    )
    assert authenticator.token == "Bearer key1"
    assert authenticator.token == "Bearer key2"


def test_token_pool_authenticator_empty_tokens_raises():
    """An empty token string is rejected at construction time."""
    with pytest.raises(ValueError, match="at least one token"):
        TokenPoolAuthenticator(tokens="", config={}, parameters={})


def test_round_robin_rotation():
    """The round-robin strategy yields tokens in order and wraps around."""
    rotation = RoundRobinRotation(tokens=["x", "y", "z"], parameters={})
    observed = []
    for _ in range(6):
        observed.append(rotation.get_active_token())
    assert observed == ["x", "y", "z", "x", "y", "z"]


def test_rate_limit_aware_rotation_rotates_on_exhaustion():
    """A token reporting zero remaining calls is replaced by the next one."""
    rotation = RateLimitAwareRotation(tokens=["tok1", "tok2"], parameters={})
    # Initially tok1
    assert rotation.get_active_token() == "tok1"

    # Simulate tok1 exhausted
    rotation.update_from_response(_rate_limit_response("0", 3600))

    # Should now return tok2
    assert rotation.get_active_token() == "tok2"


def test_rate_limit_aware_rotation_uses_token_until_exhausted():
    """With no reserve configured, a token with quota left stays active."""
    rotation = RateLimitAwareRotation(
        tokens=["tok1", "tok2"],
        parameters={},
        budget_min_reserve=0,
        budget_reserve_fraction=0.0,
    )

    rotation.update_from_response(_rate_limit_response("100", 3600))

    # Still on tok1 since remaining > 0
    assert rotation.get_active_token() == "tok1"


def test_rate_limit_aware_rotation_raises_when_all_exhausted_and_max_wait_exceeded():
    """Exhausting every token with a far-away reset raises instead of sleeping."""
    rotation = RateLimitAwareRotation(
        tokens=["tok1", "tok2"],
        parameters={},
        max_wait_seconds=1,
    )

    # The first response exhausts tok1 (index 0) and rotates to tok2; the second
    # response exhausts tok2 (index 1).
    rotation.update_from_response(_rate_limit_response("0", 9999))
    rotation.update_from_response(_rate_limit_response("0", 9999))

    # All tokens exhausted, reset time exceeds max_wait_seconds
    with pytest.raises(RuntimeError, match="exceeds max_wait_seconds"):
        rotation.get_active_token()


def test_on_http_response_called_on_authenticator():
    """`on_http_response` forwards the response to the rotation strategy."""
    rotation = RateLimitAwareRotation(tokens=["tok1", "tok2"], parameters={})
    authenticator = TokenPoolAuthenticator(
        tokens="tok1,tok2",
        config={},
        parameters={},
        rotation_strategy=rotation,
    )

    authenticator.on_http_response(_rate_limit_response("0", 3600))

    # After exhaustion of tok1, should get tok2
    assert authenticator.token == "Bearer tok2"