7 changes: 7 additions & 0 deletions .github/workflows/checks.yml
@@ -42,6 +42,13 @@ permissions:
   actions: write
   contents: write
 
+# Cancel in-progress runs on the same PR when a new push arrives.
+# Push-to-main, tag, and workflow_call events are not cancelled so that
+# release and deploy jobs always run to completion.
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+
 jobs:
   compatibility-checks:
     name: Compatibility Checks
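As a usage note, the `cancel-in-progress` expression above evaluates per event: it is true only for `pull_request` runs, so pushes to `main`, tags, and `workflow_call` invocations keep the default behaviour of never being cancelled. A common variant (a sketch of an alternative, not what this PR uses) scopes the group to the PR head branch explicitly:

```yaml
# Hypothetical alternative, not part of this PR: group by the PR head branch
# when one exists, falling back to the ref for push and tag events.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: ${{ github.event_name == 'pull_request' }}
```

Both forms cancel superseded PR runs; the variant also keeps runs from different PRs on the same workflow in separate groups even when their refs collide.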
29 changes: 22 additions & 7 deletions transformer_lens/utilities/hf_utils.py
@@ -11,6 +11,7 @@
 import os
 import shutil
 import stat
+import time
 from typing import Any, Callable, Dict
 
 import torch
@@ -75,13 +76,27 @@ def download_file_from_hf(

     If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.
     """
-    file_path = hf_hub_download(
-        repo_id=repo_name,
-        filename=file_name,
-        subfolder=subfolder,
-        cache_dir=cache_dir,
-        **select_compatible_kwargs(kwargs, hf_hub_download),
-    )
+    max_retries = 3
+    for attempt in range(max_retries + 1):
+        try:
+            file_path = hf_hub_download(
+                repo_id=repo_name,
+                filename=file_name,
+                subfolder=subfolder,
+                cache_dir=cache_dir,
+                **select_compatible_kwargs(kwargs, hf_hub_download),
+            )
+            break
+        except Exception as exc:
+            if "429" in str(exc) and attempt < max_retries:
+                wait = 10 * (attempt + 1)
+                print(
+                    f"HuggingFace rate limited (429). Retrying in {wait}s "
+                    f"({attempt + 1}/{max_retries})..."
+                )
+                time.sleep(wait)
+            else:
+                raise
 
     if file_path.endswith(".pth") or force_is_torch:
         return torch.load(file_path, map_location="cpu", weights_only=False)

Review comment (Collaborator), on the line `wait = 10 * (attempt + 1)`:

If many workers hit a 429 at the same time, all will retry at the same time, creating a repeated loop of failures. Can we add some random variance to the wait duration?

Additionally, we might want to increase the base wait duration in this situation. HF's API timeout window is 5 minutes; if we have a large volume of requests at the same time, we are going to want to spread out the calls on each retry using longer waits plus variance to increase our chances of success.
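The reviewer's suggestion amounts to exponential backoff with jitter. A minimal sketch of that pattern follows; the function name, defaults, and structure are illustrative assumptions, not code from this PR:

```python
import random
import time


def download_with_backoff(download_fn, max_retries=5, base_wait=15.0, max_wait=240.0):
    """Call download_fn, retrying on HTTP 429 with exponential backoff plus jitter.

    Hypothetical helper sketching the reviewer's suggestion; not part of the PR.
    download_fn is any zero-argument callable that performs the download.
    """
    for attempt in range(max_retries + 1):
        try:
            return download_fn()
        except Exception as exc:
            if "429" in str(exc) and attempt < max_retries:
                # Exponential growth capped at max_wait, then "full jitter":
                # sleep a uniform random fraction of the ceiling so that
                # concurrent workers desynchronize instead of retrying in
                # lockstep and re-triggering the rate limit together.
                ceiling = min(max_wait, base_wait * (2 ** attempt))
                time.sleep(random.uniform(0, ceiling))
            else:
                raise
```

Capping the ceiling keeps the worst-case total wait bounded while still spreading retries across a window wide enough to ride out a 5-minute rate-limit burst.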