diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index e5e595052..ef554b1b4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -42,6 +42,13 @@ permissions: actions: write contents: write +# Cancel in-progress runs on the same PR when a new push arrives. +# Push-to-main, tag, and workflow_call events are not cancelled so that +# release and deploy jobs always run to completion. +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + jobs: compatibility-checks: name: Compatibility Checks diff --git a/transformer_lens/utilities/hf_utils.py b/transformer_lens/utilities/hf_utils.py index 580255b1a..b39bb08c7 100644 --- a/transformer_lens/utilities/hf_utils.py +++ b/transformer_lens/utilities/hf_utils.py @@ -11,6 +11,7 @@ import os import shutil import stat +import time from typing import Any, Callable, Dict import torch @@ -75,13 +76,27 @@ def download_file_from_hf( If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object. """ - file_path = hf_hub_download( - repo_id=repo_name, - filename=file_name, - subfolder=subfolder, - cache_dir=cache_dir, - **select_compatible_kwargs(kwargs, hf_hub_download), - ) + max_retries = 3 + for attempt in range(max_retries + 1): + try: + file_path = hf_hub_download( + repo_id=repo_name, + filename=file_name, + subfolder=subfolder, + cache_dir=cache_dir, + **select_compatible_kwargs(kwargs, hf_hub_download), + ) + break + except Exception as exc: + if "429" in str(exc) and attempt < max_retries: + wait = 10 * (attempt + 1) + print( + f"HuggingFace rate limited (429). Retrying in {wait}s " + f"({attempt + 1}/{max_retries})..." + ) + time.sleep(wait) + else: + raise if file_path.endswith(".pth") or force_is_torch: return torch.load(file_path, map_location="cpu", weights_only=False)