Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion airbyte_cdk/sources/declarative/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,18 @@

from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeOauth2Authenticator
from airbyte_cdk.sources.declarative.auth.token_pool_authenticator import TokenPoolAuthenticator
from airbyte_cdk.sources.declarative.auth.token_rotation_strategies import (
RateLimitAwareRotation,
RoundRobinRotation,
TokenRotationStrategy,
)

# Public names re-exported by the declarative auth package.
__all__ = [
    "DeclarativeOauth2Authenticator",
    "JwtAuthenticator",
    "RateLimitAwareRotation",
    "RoundRobinRotation",
    "TokenPoolAuthenticator",
    "TokenRotationStrategy",
]
97 changes: 97 additions & 0 deletions airbyte_cdk/sources/declarative/auth/token_pool_authenticator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
#

from dataclasses import InitVar, dataclass, field
from typing import Any, List, Mapping, Optional, Union

import requests

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.auth.token_rotation_strategies import (
RateLimitAwareRotation,
RoundRobinRotation,
TokenRotationStrategy,
)
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.types import Config


@dataclass
class TokenPoolAuthenticator(DeclarativeAuthenticator):
    """Authenticator that rotates through multiple API tokens.

    Accepts a list of tokens (or a delimited string) and rotates through them
    using a configurable strategy (round-robin or rate-limit-aware). This enables
    distributing rate-limit consumption across multiple credentials.

    Attributes:
        tokens: Interpolated string resolving to the token(s). Can be a single token,
            multiple tokens joined by `token_separator`, or a value that evaluates
            to a list/tuple of tokens.
        config: The user-provided configuration.
        parameters: Additional runtime parameters for string interpolation.
        token_separator: Delimiter used to split `tokens` into individual values.
            An empty separator means "do not split".
        auth_method: Prefix for the token value in the header (e.g., "Bearer", "token", "").
        header: HTTP header name to set.
        rotation_strategy: Strategy controlling how tokens are rotated. Defaults to
            round-robin over the resolved tokens when omitted.

    Raises:
        ValueError: If no non-empty token can be resolved from `tokens`.
    """

    tokens: Union[InterpolatedString, str]
    config: Config
    parameters: InitVar[Mapping[str, Any]]
    token_separator: str = ","
    auth_method: str = "Bearer"
    header: str = "Authorization"
    rotation_strategy: Optional[TokenRotationStrategy] = None

    _token_list: List[str] = field(default_factory=list, init=False, repr=False)
    _strategy: TokenRotationStrategy = field(init=False, repr=False)

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        """Resolve the token list and bind the rotation strategy."""
        tokens_interpolated = InterpolatedString.create(self.tokens, parameters=parameters)
        self._token_list = self._split_tokens(tokens_interpolated.eval(self.config))

        if not self._token_list:
            raise ValueError("TokenPoolAuthenticator requires at least one token.")

        if self.rotation_strategy is None:
            # Default to round-robin
            self._strategy = RoundRobinRotation(tokens=self._token_list, parameters=parameters)
            return

        self._strategy = self.rotation_strategy
        # Inject the token list only when the strategy has none of its own, and
        # re-initialise only in that case: a strategy that already carries tokens
        # keeps whatever rotation / rate-limit state it has accumulated.
        if hasattr(self._strategy, "tokens") and not self._strategy.tokens:
            self._strategy.tokens = self._token_list
            if isinstance(self._strategy, (RoundRobinRotation, RateLimitAwareRotation)):
                self._strategy.__post_init__(parameters)

    def _split_tokens(self, evaluated: Any) -> List[str]:
        """Normalise the evaluated `tokens` value into a list of non-empty strings."""
        if evaluated is None:
            return []
        if isinstance(evaluated, (list, tuple)):
            # NOTE(review): assumes interpolation may hand back a real sequence
            # (e.g. a config value that is already a list) — confirm against
            # InterpolatedString.eval. str() on a sequence would otherwise leak
            # brackets and quotes into the tokens.
            candidates = [str(item) for item in evaluated if item is not None]
        elif self.token_separator:
            candidates = str(evaluated).split(self.token_separator)
        else:
            # str.split("") raises ValueError; treat the value as a single token.
            candidates = [str(evaluated)]
        stripped = (candidate.strip() for candidate in candidates)
        return [candidate for candidate in stripped if candidate]

    @property
    def auth_header(self) -> str:
        """Name of the HTTP header that carries the token."""
        return self.header

    @property
    def token(self) -> str:
        """Header value for the strategy's active token, prefixed with `auth_method` if set."""
        raw_token = self._strategy.get_active_token()
        if self.auth_method:
            return f"{self.auth_method} {raw_token}"
        return raw_token

    def on_http_response(self, response: requests.Response) -> None:
        """Called after each HTTP response to update per-token rate-limit state."""
        self._strategy.update_from_response(response)

    def update_token(self) -> None:
        """Force rotation to the next token.

        Provided for compatibility with imperative-style connectors that call
        `authenticator.update_token()` from backoff strategies. Strategies other
        than the two built-in ones are left untouched.
        """
        if isinstance(self._strategy, RateLimitAwareRotation):
            self._strategy._rotate()
        elif isinstance(self._strategy, RoundRobinRotation):
            # RoundRobinRotation advances on each get_active_token() call, so
            # calling get_active_token() once consumes the rotation.
            self._strategy.get_active_token()
210 changes: 210 additions & 0 deletions airbyte_cdk/sources/declarative/auth/token_rotation_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
#

import logging
import time
from abc import ABC, abstractmethod
from dataclasses import InitVar, dataclass, field
from itertools import cycle
from typing import Any, Dict, List, Mapping, Optional

import requests

from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

logger = logging.getLogger("airbyte")


@dataclass
class TokenState:
    """Rate-limit bookkeeping kept for one token."""

    # Calls left for this token; -1 is the sentinel for "unknown".
    remaining: int = -1
    # Time at which the quota resets; None while no reset time is recorded.
    reset_at: Optional[AirbyteDateTime] = None


class TokenRotationStrategy(ABC):
    """Abstract interface implemented by every token rotation strategy."""

    @abstractmethod
    def get_active_token(self) -> str:
        """Return the token that should be used right now."""
        raise NotImplementedError

    def update_from_response(self, response: requests.Response) -> None:
        """Hook for learning from an HTTP response.

        The base implementation ignores the response; subclasses that track
        state override it.
        """
        return None


@dataclass
class RoundRobinRotation(TokenRotationStrategy):
    """Cycle through tokens on each `get_active_token()` call.

    Attributes:
        tokens: Tokens to hand out in order, wrapping around at the end. May be
            empty at construction time if the tokens are injected later.
        parameters: Runtime parameters; accepted for interface parity and not used.
    """

    tokens: List[str]
    parameters: InitVar[Mapping[str, Any]]

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        """Prime the cycle from the current `tokens` list."""
        self._start_cycle()

    def _start_cycle(self) -> None:
        """(Re)build the iterator and pre-fetch the first token, if there is one."""
        self._iter = cycle(self.tokens)
        # next(..., None) rather than next(...): an empty list must not surface
        # as a bare StopIteration from the constructor.
        self._current: Optional[str] = next(self._iter, None)

    def get_active_token(self) -> str:
        """Return the current token and advance to the next one.

        Raises:
            ValueError: If the strategy has no tokens to rotate through.
        """
        if self._current is None:
            # Tokens may have been assigned after construction; pick them up now.
            self._start_cycle()
            if self._current is None:
                raise ValueError("RoundRobinRotation requires at least one token.")
        token = self._current
        self._current = next(self._iter)
        return token


@dataclass
class RateLimitAwareRotation(TokenRotationStrategy):
    """Track per-token quota from response headers and rotate when exhausted.

    When a token's remaining quota hits zero, rotate to the next token. When all
    tokens are exhausted, sleep until the earliest reset time. Proactive throttling
    spreads remaining calls over the reset window to avoid hitting the wall.

    Attributes:
        tokens: Tokens to rotate through.
        parameters: Runtime parameters; accepted for interface parity and not used.
        ratelimit_remaining_header: Response header carrying the remaining call count.
        ratelimit_reset_header: Response header carrying the reset time as a numeric
            timestamp.
        max_wait_seconds: Longest sleep allowed while waiting for a reset; a longer
            wait raises instead.
        budget_reserve_fraction: Fraction of the estimated limit treated as "running low".
        budget_min_reserve: Lower bound for the "running low" threshold.
        HEARTBEAT_INTERVAL: Maximum length, in seconds, of one uninterrupted sleep
            between progress log lines.
    """

    tokens: List[str]
    parameters: InitVar[Mapping[str, Any]]
    ratelimit_remaining_header: str = "x-ratelimit-remaining"
    ratelimit_reset_header: str = "x-ratelimit-reset"
    max_wait_seconds: int = 7200
    budget_reserve_fraction: float = 0.1
    budget_min_reserve: int = 50

    _token_state: Dict[str, TokenState] = field(default_factory=dict, init=False)
    _active_index: int = field(default=0, init=False)
    _budget_logged: bool = field(default=False, init=False)

    # NOTE: annotated without ClassVar, so the dataclass treats this as a regular
    # init field despite the constant-style name. Kept as-is so the constructor
    # signature does not change.
    HEARTBEAT_INTERVAL: float = 60.0

    def __post_init__(self, parameters: Mapping[str, Any]) -> None:
        """Start every configured token with an unknown quota."""
        self._token_state = {t: TokenState() for t in self.tokens}

    def get_active_token(self) -> str:
        """Return the active token, rotating if the current one is exhausted.

        Raises:
            RuntimeError: If no tokens are configured, or if every token is
                exhausted and waiting for a reset is not possible.
        """
        if not self.tokens:
            raise RuntimeError("RateLimitAwareRotation has no tokens configured.")

        # Give each token one chance, starting from the active one.
        for _ in range(len(self.tokens)):
            token = self.tokens[self._active_index]
            # If remaining is unknown (-1) or > 0, use this token
            if self._token_state[token].remaining != 0:
                return token
            # Current token is exhausted, try next
            self._rotate()

        # All tokens are exhausted — sleep until earliest reset
        self._sleep_until_reset()
        return self.tokens[self._active_index]

    def update_from_response(self, response: requests.Response) -> None:
        """Update the active token's state from response headers."""
        if not self.tokens:
            return

        token = self.tokens[self._active_index]
        state = self._token_state[token]

        remaining_header = response.headers.get(self.ratelimit_remaining_header)
        reset_header = response.headers.get(self.ratelimit_reset_header)

        if remaining_header is not None:
            try:
                state.remaining = int(remaining_header)
            except (ValueError, TypeError):
                logger.debug(
                    "Could not parse ratelimit-remaining header value: %s", remaining_header
                )

        if reset_header is not None:
            try:
                reset_ts = float(reset_header)
                state.reset_at = ab_datetime_parse(str(int(reset_ts)))
            except (ValueError, TypeError, OverflowError):
                # OverflowError: int() of an infinite float, e.g. a header of "inf".
                logger.debug("Could not parse ratelimit-reset header value: %s", reset_header)

        # Once the quota is known: slow down when it is at or below the reserve,
        # and move to the next token only when it is fully used up.
        if state.remaining >= 0:
            reserve = self._get_budget_reserve(state)
            if state.remaining <= reserve:
                self._maybe_throttle(state)
            if state.remaining == 0:
                self._rotate()

    def _rotate(self) -> None:
        """Advance the active index to the next token, wrapping around."""
        if not self.tokens:
            return
        self._active_index = (self._active_index + 1) % len(self.tokens)

    def _get_budget_reserve(self, state: TokenState) -> int:
        """Return the minimum number of calls to keep in reserve for a token."""
        # The true limit is not tracked, so it is estimated as at least 5000 or
        # the largest remaining value seen on this state, whichever is greater.
        limit_estimate = max(5000, state.remaining)
        return max(self.budget_min_reserve, int(limit_estimate * self.budget_reserve_fraction))

    def _maybe_throttle(self, state: TokenState) -> None:
        """Inject a small delay when all tokens are running low.

        A token whose quota is still unknown is not considered low: it may have
        its full budget left, so throttling would only waste time.

        Args:
            state: State of the active token; its reset time sizes the delay.
        """
        all_running_low = all(
            0 <= s.remaining <= self._get_budget_reserve(s) for s in self._token_state.values()
        )
        if not all_running_low:
            return

        if state.reset_at is None:
            return

        seconds_to_reset = max((state.reset_at - ab_datetime_now()).total_seconds(), 0)
        total_remaining = sum(max(s.remaining, 0) for s in self._token_state.values())
        if total_remaining <= 0 or seconds_to_reset <= 0:
            return

        # Spread the remaining calls evenly over the window, capped at 10s per call.
        delay = min(seconds_to_reset / total_remaining, 10.0)
        if delay >= 0.1:
            if not self._budget_logged:
                logger.info(
                    "API budget: throttling requests (%.1fs delay). %d calls remaining across %d token(s), "
                    "%.0fs until reset.",
                    delay,
                    total_remaining,
                    len(self.tokens),
                    seconds_to_reset,
                )
                self._budget_logged = True
            time.sleep(delay)

    def _sleep_until_reset(self) -> None:
        """Sleep until the earliest token reset time, or raise if too long.

        Raises:
            RuntimeError: If no reset time is known, or the wait would exceed
                `max_wait_seconds`.
        """
        reset_times = [s.reset_at for s in self._token_state.values() if s.reset_at is not None]
        if not reset_times:
            raise RuntimeError(
                "All tokens in the pool are exhausted and no reset time is available."
            )

        earliest_reset = min(reset_times)
        wait_seconds = max((earliest_reset - ab_datetime_now()).total_seconds(), 0)

        if wait_seconds > self.max_wait_seconds:
            raise RuntimeError(
                f"All tokens in the pool are exhausted. Earliest reset in {wait_seconds:.0f}s "
                f"exceeds max_wait_seconds ({self.max_wait_seconds}s)."
            )

        logger.info(
            "All tokens exhausted. Sleeping %.0fs until rate limit resets.",
            wait_seconds,
        )
        self._sleep_with_heartbeat(wait_seconds)

        # Reset state for all tokens after sleeping
        for state in self._token_state.values():
            state.remaining = -1
            state.reset_at = None
        self._budget_logged = False

    def _sleep_with_heartbeat(self, total_seconds: float) -> None:
        """Sleep with periodic log messages to keep the heartbeat alive."""
        remaining = total_seconds
        while remaining > 0:
            # A non-positive interval would never make progress (or would make
            # time.sleep raise), so fall back to one uninterrupted sleep.
            interval = self.HEARTBEAT_INTERVAL if self.HEARTBEAT_INTERVAL > 0 else remaining
            chunk = min(remaining, interval)
            time.sleep(chunk)
            remaining -= chunk
            if remaining > 0:
                logger.info(
                    "Rate limit exhausted. Waiting for reset — %.0fs remaining.",
                    remaining,
                )
Loading
Loading