diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py index 5c55970fdf..2ad7e5bdc0 100644 --- a/src/exo/download/coordinator.py +++ b/src/exo/download/coordinator.py @@ -10,6 +10,7 @@ map_repo_download_progress_to_download_progress_data, resolve_model_in_path, ) +from exo.download.peer_shard_downloader import PeerAwareShardDownloader from exo.download.shard_downloader import ShardDownloader from exo.shared.constants import EXO_MODELS_DIR, EXO_MODELS_PATH from exo.shared.models.model_cards import ModelId, get_model_cards @@ -115,7 +116,15 @@ async def _command_processor(self) -> None: continue match cmd.command: - case StartDownload(shard_metadata=shard): + case StartDownload(shard_metadata=shard, available_peers=peers): + # Pass peer endpoints to the shard downloader if it supports it + if isinstance(self.shard_downloader, PeerAwareShardDownloader): + self.shard_downloader.set_available_peers(peers) + elif hasattr(self.shard_downloader, "shard_downloader") and isinstance( + self.shard_downloader.shard_downloader, PeerAwareShardDownloader # type: ignore[union-attr] + ): + # Unwrap SingletonShardDownloader + self.shard_downloader.shard_downloader.set_available_peers(peers) # type: ignore[union-attr] await self._start_download(shard) case DeleteDownload(model_id=model_id): await self._delete_download(model_id) diff --git a/src/exo/download/download_utils.py b/src/exo/download/download_utils.py index 3f6f1dc9dc..de1e82efad 100644 --- a/src/exo/download/download_utils.py +++ b/src/exo/download/download_utils.py @@ -1,5 +1,6 @@ import asyncio import hashlib +import json import os import shutil import ssl @@ -589,6 +590,9 @@ async def _download_file( ) as f: while chunk := await r.content.read(8 * 1024 * 1024): n_read = n_read + (await f.write(chunk)) + await f.flush() + # Write companion metadata for peer download streaming + await _write_partial_meta(partial_path, n_read, length, remote_hash) on_progress(n_read, length, False) final_hash = await 
calc_hash(
@@ -604,10 +608,31 @@ async def _download_file(
             f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}"
         )
     await aios.rename(partial_path, target_dir / path)
+    # Clean up companion metadata file
+    meta_path = Path(f"{partial_path}.meta")
+    if await aios.path.exists(meta_path):
+        await aios.remove(meta_path)
     on_progress(length, length, True)
     return target_dir / path
 
 
+async def _write_partial_meta(
+    partial_path: Path, safe_bytes: int, total: int, etag: str
+) -> None:
+    """Write companion .partial.meta file for peer download streaming.
+
+    This small JSON file tells the peer file server how many bytes of the
+    .partial file have been safely flushed to disk and are safe to serve.
+    """
+    meta_path = Path(f"{partial_path}.meta")
+    meta = json.dumps({"safe_bytes": safe_bytes, "total": total, "etag": etag})
+    # Write to temp then rename for atomicity
+    tmp_path = Path(f"{partial_path}.meta.tmp")
+    async with aiofiles.open(tmp_path, "w") as f:
+        await f.write(meta)
+    # replace (not rename): this runs once per downloaded chunk, so the target
+    # meta file already exists on every call after the first; os.rename over an
+    # existing target raises FileExistsError on Windows, os.replace is atomic.
+    await aios.replace(tmp_path, meta_path)
+
+
 def calculate_repo_progress(
     shard: ShardMetadata,
     model_id: ModelId,
diff --git a/src/exo/download/impl_shard_downloader.py b/src/exo/download/impl_shard_downloader.py
index d87da8eeff..3bf52a0e71 100644
--- a/src/exo/download/impl_shard_downloader.py
+++ b/src/exo/download/impl_shard_downloader.py
@@ -7,6 +7,7 @@
 from loguru import logger
 
 from exo.download.download_utils import RepoDownloadProgress, download_shard
+from exo.download.peer_shard_downloader import PeerAwareShardDownloader
 from exo.download.shard_downloader import ShardDownloader
 from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
 from exo.shared.types.worker.shards import (
@@ -16,11 +17,16 @@
 
 
 def exo_shard_downloader(
-    max_parallel_downloads: int = 8, offline: bool = False
+    max_parallel_downloads: int = 8,
+    offline: bool = False,
+    peer_download_enabled: bool = False,
 ) -> ShardDownloader:
-    return 
SingletonShardDownloader( - ResumableShardDownloader(max_parallel_downloads, offline=offline) + inner: ShardDownloader = ResumableShardDownloader( + max_parallel_downloads, offline=offline ) + if peer_download_enabled: + inner = PeerAwareShardDownloader(inner) + return SingletonShardDownloader(inner) async def build_base_shard(model_id: ModelId) -> ShardMetadata: diff --git a/src/exo/download/peer_download.py b/src/exo/download/peer_download.py new file mode 100644 index 0000000000..1fab3657d6 --- /dev/null +++ b/src/exo/download/peer_download.py @@ -0,0 +1,169 @@ +"""HTTP client for downloading model files from peer nodes. + +Instead of downloading from HuggingFace, nodes can fetch model files from +peers on the same LAN that already have them (or are still downloading them). +Falls back gracefully if the peer is unreachable or the transfer fails. +""" + +import asyncio +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +import aiofiles +import aiofiles.os as aios +import aiohttp +from loguru import logger + + +@dataclass(frozen=True) +class PeerFileInfo: + """Status of a single file on a peer node.""" + + path: str + size: int + complete: bool + safe_bytes: int + + +async def get_peer_file_status( + peer_host: str, + peer_port: int, + model_id_normalized: str, + timeout: float = 5.0, +) -> list[PeerFileInfo] | None: + """Query a peer's file server for available files for a model. + + Returns None if the peer is unreachable. 
+ """ + url = f"http://{peer_host}:{peer_port}/status/{model_id_normalized}" + try: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout) + ) as session: + async with session.get(url) as r: + if r.status != 200: + return None + data = await r.json() + return [PeerFileInfo(**f) for f in data.get("files", [])] + except Exception as e: + logger.debug(f"Could not reach peer {peer_host}:{peer_port}: {e}") + return None + + +async def download_file_from_peer( + peer_host: str, + peer_port: int, + model_id_normalized: str, + file_path: str, + target_dir: Path, + expected_size: int, + on_progress: Callable[[int, int, bool], None] = lambda _a, _b, _c: None, + max_poll_attempts: int = 60, + poll_interval: float = 3.0, +) -> Path | None: + """Download a single file from a peer's file server. + + Supports streaming relay: if the peer is still downloading the file, + we fetch available bytes, wait, and poll for more until the file is + complete. + + Returns the final file path on success, or None on failure (caller + should fall back to HuggingFace). 
+    """
+    target_path = target_dir / file_path
+    partial_path = target_dir / f"{file_path}.partial"
+
+    # Check if already complete locally
+    if await aios.path.exists(target_path):
+        local_size = (await aios.stat(target_path)).st_size
+        if local_size == expected_size:
+            on_progress(expected_size, expected_size, True)
+            return target_path
+
+    await aios.makedirs((target_dir / file_path).parent, exist_ok=True)
+
+    url = f"http://{peer_host}:{peer_port}/files/{model_id_normalized}/{file_path}"
+    n_read = 0
+
+    # Resume from existing partial
+    if await aios.path.exists(partial_path):
+        n_read = (await aios.stat(partial_path)).st_size
+
+    poll_count = 0
+    chunk_size = 8 * 1024 * 1024  # 8MB, matching HF download
+
+    try:
+        while n_read < expected_size and poll_count < max_poll_attempts:
+            headers: dict[str, str] = {}
+            if n_read > 0:
+                headers["Range"] = f"bytes={n_read}-"
+
+            got_bytes = False
+            async with aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=300, sock_read=60)
+            ) as session:
+                async with session.get(url, headers=headers) as r:
+                    if r.status == 416:
+                        # Range not satisfiable - peer doesn't have more yet
+                        pass
+                    elif r.status in (200, 206):
+                        # A 200 despite a Range header means the server ignored the
+                        # range and is resending the whole file: restart from byte 0
+                        # ("wb" below) instead of appending and corrupting the partial.
+                        if r.status == 200 and n_read > 0:
+                            n_read = 0
+
+                        async with aiofiles.open(
+                            partial_path, "ab" if n_read > 0 else "wb"
+                        ) as f:
+                            while True:
+                                chunk = await r.content.read(chunk_size)
+                                if not chunk:
+                                    break
+                                written = await f.write(chunk)
+                                n_read += written
+                                got_bytes = True
+                                on_progress(n_read, expected_size, False)
+                    elif r.status == 404:
+                        logger.debug(
+                            f"File {file_path} not found on peer {peer_host}"
+                        )
+                        return None
+                    else:
+                        logger.warning(
+                            f"Unexpected status {r.status} from peer {peer_host}"
+                        )
+                        return None
+
+            # Check if we're done
+            if n_read >= expected_size:
+                break
+
+            # If we got no new bytes, the peer might still be downloading
+            if not got_bytes:
+                poll_count += 1
+                logger.debug(
+                    f"Waiting for peer 
{peer_host} to download more of {file_path} " + f"({n_read}/{expected_size}, poll {poll_count}/{max_poll_attempts})" + ) + await asyncio.sleep(poll_interval) + else: + # Got data, reset poll counter + poll_count = 0 + + if n_read < expected_size: + logger.warning( + f"Peer download incomplete for {file_path}: {n_read}/{expected_size}" + ) + return None + + # Rename partial to final + await aios.rename(partial_path, target_path) + on_progress(expected_size, expected_size, True) + logger.info( + f"Downloaded {file_path} from peer {peer_host} ({expected_size} bytes)" + ) + return target_path + + except Exception as e: + logger.warning(f"Peer download failed for {file_path} from {peer_host}: {e}") + return None diff --git a/src/exo/download/peer_file_server.py b/src/exo/download/peer_file_server.py new file mode 100644 index 0000000000..f36823ac27 --- /dev/null +++ b/src/exo/download/peer_file_server.py @@ -0,0 +1,174 @@ +"""Lightweight HTTP file server for peer-to-peer model downloads. + +Each exo node runs a PeerFileServer that serves model files from the local +cache directory. When one node finishes downloading a model from HuggingFace, +other nodes on the same LAN can fetch it directly over HTTP instead of +re-downloading from the internet. + +Supports serving in-progress downloads via .partial.meta files that track +how many bytes have been safely flushed to disk. 
+""" + +import json +from pathlib import Path + +import aiofiles +import aiofiles.os as aios +from aiohttp import web +from loguru import logger + + +class PeerFileServer: + """HTTP server that exposes local model files for peer download.""" + + def __init__(self, host: str, port: int, models_dir: Path) -> None: + self.host = host + self.port = port + self.models_dir = models_dir + self._app = web.Application() + self._app.router.add_get("/status/{model_id}", self._handle_status) + self._app.router.add_get( + "/files/{model_id}/{file_path:.+}", self._handle_file + ) + self._app.router.add_get("/health", self._handle_health) + self._runner: web.AppRunner | None = None + + async def run(self) -> None: + self._runner = web.AppRunner(self._app) + await self._runner.setup() + site = web.TCPSite(self._runner, self.host, self.port) + await site.start() + logger.info(f"PeerFileServer listening on {self.host}:{self.port}") + + async def shutdown(self) -> None: + if self._runner: + await self._runner.cleanup() + + async def _handle_health(self, request: web.Request) -> web.Response: + return web.json_response({"status": "ok"}) + + async def _handle_status(self, request: web.Request) -> web.Response: + """Return status of all files for a model (complete + in-progress).""" + model_id = request.match_info["model_id"] + model_dir = self.models_dir / model_id + + if not await aios.path.exists(model_dir): + return web.json_response({"files": []}) + + files = [] + for item in model_dir.iterdir(): + if item.is_dir() or item.name.endswith(".partial.meta"): + continue + + if item.name.endswith(".partial"): + # In-progress file - read meta for safe bytes + meta = await _read_partial_meta(item) + if meta: + files.append( + { + "path": item.name.removesuffix(".partial"), + "size": meta.get("total", 0), + "complete": False, + "safe_bytes": meta.get("safe_bytes", 0), + } + ) + else: + # Complete file + stat = await aios.stat(item) + files.append( + { + "path": item.name, + "size": 
stat.st_size, + "complete": True, + "safe_bytes": stat.st_size, + } + ) + + return web.json_response({"files": files}) + + async def _handle_file(self, request: web.Request) -> web.StreamResponse: + """Serve a model file with Range request support. + + For complete files: standard HTTP file serving. + For .partial files: serves only the safe byte range (flushed to disk). + """ + model_id = request.match_info["model_id"] + file_path = request.match_info["file_path"] + + model_dir = self.models_dir / model_id + complete_path = model_dir / file_path + partial_path = model_dir / f"{file_path}.partial" + + # Determine which file to serve and its safe size + if await aios.path.exists(complete_path): + serve_path = complete_path + file_size = (await aios.stat(complete_path)).st_size + safe_bytes = file_size + is_complete = True + elif await aios.path.exists(partial_path): + meta = await _read_partial_meta(partial_path) + if not meta or meta.get("safe_bytes", 0) == 0: + return web.Response(status=404, text="File not available yet") + serve_path = partial_path + file_size = meta.get("total", 0) + safe_bytes = meta["safe_bytes"] + is_complete = False + else: + return web.Response(status=404, text="File not found") + + # Parse Range header + range_header = request.headers.get("Range") + start = 0 + if range_header: + try: + range_spec = range_header.replace("bytes=", "") + start = int(range_spec.split("-")[0]) + except (ValueError, IndexError): + return web.Response(status=416, text="Invalid range") + + if start >= safe_bytes: + return web.Response(status=416, text="Range not satisfiable") + + end = safe_bytes # Serve up to safe boundary only + content_length = end - start + + response = web.StreamResponse( + status=206 if start > 0 else 200, + headers={ + "Content-Type": "application/octet-stream", + "Content-Length": str(content_length), + "Accept-Ranges": "bytes", + "Content-Range": f"bytes {start}-{end - 1}/{file_size}", + "X-Exo-Safe-Bytes": str(safe_bytes), + 
"X-Exo-Total-Size": str(file_size), + "X-Exo-Complete": "true" if is_complete else "false", + }, + ) + await response.prepare(request) + + chunk_size = 8 * 1024 * 1024 # 8MB chunks matching HF download + async with aiofiles.open(serve_path, "rb") as f: + await f.seek(start) + remaining = content_length + while remaining > 0: + to_read = min(chunk_size, remaining) + chunk = await f.read(to_read) + if not chunk: + break + await response.write(chunk) + remaining -= len(chunk) + + await response.write_eof() + return response + + +async def _read_partial_meta(partial_path: Path) -> dict | None: + """Read the .partial.meta companion file for a .partial download.""" + meta_path = Path(f"{partial_path}.meta") + if not await aios.path.exists(meta_path): + return None + try: + async with aiofiles.open(meta_path, "r") as f: + return json.loads(await f.read()) + except (json.JSONDecodeError, OSError): + return None diff --git a/src/exo/download/peer_shard_downloader.py b/src/exo/download/peer_shard_downloader.py new file mode 100644 index 0000000000..4b5a71db34 --- /dev/null +++ b/src/exo/download/peer_shard_downloader.py @@ -0,0 +1,273 @@ +"""Peer-aware shard downloader that tries LAN peers before HuggingFace. + +Wraps an existing ShardDownloader and adds a peer-download step: before +hitting HuggingFace, try peers provided in the available_peers list. +Falls back to the inner downloader (HF) if peer download fails. + +The peer list is computed by the Worker at command-emit time and passed +through the StartDownload command, keeping the download coordinator +decoupled from Worker state. 
+""" + +import asyncio +import time +from collections.abc import Awaitable +from datetime import timedelta +from pathlib import Path +from typing import AsyncIterator, Callable + +from loguru import logger + +from exo.download.download_utils import ( + RepoDownloadProgress, + calculate_repo_progress, + ensure_models_dir, + fetch_file_list_with_cache, + is_image_model, + resolve_allow_patterns, +) +from exo.download.huggingface_utils import filter_repo_objects +from exo.download.peer_download import ( + download_file_from_peer, + get_peer_file_status, +) +from exo.download.shard_downloader import ShardDownloader +from exo.shared.types.commands import PeerEndpoint +from exo.shared.types.memory import Memory +from exo.shared.types.worker.downloads import RepoFileDownloadProgress +from exo.shared.types.worker.shards import ShardMetadata + + +class PeerAwareShardDownloader(ShardDownloader): + """ShardDownloader that tries peer download before HuggingFace. + + Decorates an inner ShardDownloader (typically ResumableShardDownloader). + On ensure_shard(), if available_peers were provided, tries downloading + from them over the LAN first. Falls back to the inner downloader if + no peer has it or the transfer fails. + """ + + def __init__(self, inner: ShardDownloader) -> None: + self._inner = inner + self._progress_callbacks: list[ + Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]] + ] = [] + # Peers are set per-download by the coordinator before calling ensure_shard + self._current_peers: list[PeerEndpoint] = [] + + def set_available_peers(self, peers: list[PeerEndpoint]) -> None: + """Set the peers to try for the next ensure_shard call. + + Called by DownloadCoordinator before triggering a download, based + on the peers embedded in the StartDownload command. 
+ """ + self._current_peers = peers + + def on_progress( + self, + callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]], + ) -> None: + self._inner.on_progress(callback) + self._progress_callbacks.append(callback) + + async def ensure_shard( + self, shard: ShardMetadata, config_only: bool = False + ) -> Path: + if config_only: + return await self._inner.ensure_shard(shard, config_only=True) + + model_id = shard.model_card.model_id + normalized = model_id.normalize() + peers = self._current_peers + self._current_peers = [] # Reset after consumption + + if not peers: + logger.debug(f"No peers available for {model_id}, downloading from HuggingFace") + return await self._inner.ensure_shard(shard, config_only=False) + + # Try each peer (already sorted by priority: RDMA first, completed first) + for peer in peers: + logger.info( + f"Attempting peer download of {model_id} from " + f"{peer.ip}:{peer.port} (status: {peer.status}, link: {peer.connection_type})" + ) + result = await self._try_peer_download( + shard, peer.ip, peer.port, normalized + ) + if result is not None: + logger.info( + f"Successfully downloaded {model_id} from peer {peer.ip}" + ) + return result + logger.info( + f"Peer download from {peer.ip} failed, trying next peer or HuggingFace" + ) + + # All peers failed, fall back to HuggingFace + logger.info(f"All peer downloads failed for {model_id}, falling back to HuggingFace") + return await self._inner.ensure_shard(shard, config_only=False) + + async def _try_peer_download( + self, + shard: ShardMetadata, + peer_ip: str, + peer_port: int, + model_id_normalized: str, + ) -> Path | None: + """Attempt to download all model files from a single peer. + + Returns the model directory path on success, None on failure. 
+ """ + # First, check what the peer has + peer_files = await get_peer_file_status( + peer_ip, peer_port, model_id_normalized + ) + if not peer_files: + return None + + peer_file_map = {f.path: f for f in peer_files} + + # Get the file list we need (same logic as download_shard) + revision = "main" + target_dir = await ensure_models_dir() / model_id_normalized + + try: + file_list = await fetch_file_list_with_cache( + shard.model_card.model_id, + revision, + recursive=True, + skip_internet=False, + ) + except Exception: + return None + + allow_patterns = await resolve_allow_patterns(shard) + filtered_file_list = list( + filter_repo_objects( + file_list, allow_patterns=allow_patterns, key=lambda x: x.path + ) + ) + + if is_image_model(shard): + filtered_file_list = [ + f + for f in filtered_file_list + if "/" in f.path or not f.path.endswith(".safetensors") + ] + + # Check the peer has all (or most) files we need + files_on_peer = sum(1 for f in filtered_file_list if f.path in peer_file_map) + if files_on_peer == 0: + logger.debug(f"Peer has no files we need for {model_id_normalized}") + return None + + # Download from peer with progress tracking + all_start_time = time.time() + file_progress: dict[str, RepoFileDownloadProgress] = {} + semaphore = asyncio.Semaphore(8) + failed = False + + async def download_one(file_path: str, expected_size: int) -> bool: + def on_file_progress( + curr_bytes: int, total_bytes: int, is_renamed: bool + ) -> None: + file_progress[file_path] = RepoFileDownloadProgress( + repo_id=str(shard.model_card.model_id), + repo_revision=revision, + file_path=file_path, + downloaded=Memory.from_bytes(curr_bytes), + downloaded_this_session=Memory.from_bytes(curr_bytes), + total=Memory.from_bytes(total_bytes), + speed=curr_bytes / max(time.time() - all_start_time, 0.1), + eta=timedelta( + seconds=(total_bytes - curr_bytes) + / max( + curr_bytes / max(time.time() - all_start_time, 0.1), + 0.1, + ) + ), + status="complete" if is_renamed else 
"in_progress", + start_time=all_start_time, + ) + progress = calculate_repo_progress( + shard, + shard.model_card.model_id, + revision, + file_progress, + all_start_time, + ) + for cb in self._progress_callbacks: + asyncio.create_task(cb(shard, progress)) + + async with semaphore: + result = await download_file_from_peer( + peer_ip, + peer_port, + model_id_normalized, + file_path, + target_dir, + expected_size, + on_progress=on_file_progress, + ) + return result is not None + + # Initialize progress for all files + for f in filtered_file_list: + file_progress[f.path] = RepoFileDownloadProgress( + repo_id=str(shard.model_card.model_id), + repo_revision=revision, + file_path=f.path, + downloaded=Memory.from_bytes(0), + downloaded_this_session=Memory.from_bytes(0), + total=Memory.from_bytes(f.size or 0), + speed=0, + eta=timedelta(0), + status="not_started", + start_time=all_start_time, + ) + + # Download all files in parallel + tasks = [] + for f in filtered_file_list: + if f.size is None or f.size == 0: + continue + peer_info = peer_file_map.get(f.path) + if peer_info and peer_info.safe_bytes > 0: + tasks.append(download_one(f.path, f.size)) + else: + failed = True + break + + if failed: + return None + + results = await asyncio.gather(*tasks, return_exceptions=True) + if any(isinstance(r, Exception) or r is False for r in results): + return None + + # Emit final progress + final_progress = calculate_repo_progress( + shard, + shard.model_card.model_id, + revision, + file_progress, + all_start_time, + ) + for cb in self._progress_callbacks: + await cb(shard, final_progress) + + gguf = next( + (f for f in filtered_file_list if f.path.endswith(".gguf")), None + ) + return (target_dir / gguf.path) if gguf else target_dir + + async def get_shard_download_status( + self, + ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]: + async for path, status in self._inner.get_shard_download_status(): + yield path, status + + async def get_shard_download_status_for_shard( + 
self, shard: ShardMetadata + ) -> RepoDownloadProgress: + return await self._inner.get_shard_download_status_for_shard(shard) diff --git a/src/exo/download/peer_state.py b/src/exo/download/peer_state.py new file mode 100644 index 0000000000..6f400b92a5 --- /dev/null +++ b/src/exo/download/peer_state.py @@ -0,0 +1,126 @@ +"""Pure functions for discovering which peers have which models. + +These functions are called by the Worker (which owns the State) to compute +peer availability at command-emit time. The results are embedded in the +StartDownload command so the download coordinator stays decoupled from +Worker state. +""" + +from loguru import logger + +from exo.shared.types.commands import PeerEndpoint +from exo.shared.types.common import NodeId +from exo.shared.types.state import State +from exo.shared.types.topology import RDMAConnection, SocketConnection +from exo.shared.types.worker.downloads import ( + DownloadCompleted, + DownloadOngoing, +) + + +def discover_peers_for_model( + node_id: NodeId, + state: State, + model_id_normalized: str, + peer_download_port: int, +) -> list[PeerEndpoint]: + """Find peers that have a specific model (complete or in-progress). + + Called by the Worker when emitting a StartDownload command. Returns + peers sorted by priority: RDMA/Thunderbolt connections first, then + completed downloads before ongoing ones. + + Args: + node_id: This node's ID (excluded from results). + state: The global State object (owned by Worker). + model_id_normalized: Normalized model ID (e.g. "org--model"). + peer_download_port: Port where peers run their PeerFileServer. + + Returns: + List of PeerEndpoint sorted by connection quality and completeness. 
+ """ + peers: list[PeerEndpoint] = [] + + for peer_node_id, download_list in state.downloads.items(): + if peer_node_id == node_id: + continue + + for dl in download_list: + dl_model_id = dl.shard_metadata.model_card.model_id + if dl_model_id.normalize() != model_id_normalized: + continue + + if isinstance(dl, DownloadCompleted): + status = "complete" + elif isinstance(dl, DownloadOngoing): + status = "ongoing" + else: + continue + + # Resolve IP and connection type from topology + endpoint = _resolve_peer_endpoint( + node_id, peer_node_id, state, peer_download_port, status + ) + if endpoint: + peers.append(endpoint) + + # Sort by priority: + # 1. RDMA/Thunderbolt connections first (lower latency, higher bandwidth) + # 2. Completed downloads before ongoing ones + peers.sort( + key=lambda p: ( + 0 if p.connection_type == "rdma" else 1, + 0 if p.status == "complete" else 1, + ) + ) + return peers + + +def _resolve_peer_endpoint( + node_id: NodeId, + peer_node_id: NodeId, + state: State, + peer_download_port: int, + status: str, +) -> PeerEndpoint | None: + """Resolve a peer's IP address and connection type from the topology.""" + try: + # Check for RDMA connections first (highest priority) + for conn in state.topology.out_edges(node_id): + if conn.sink != peer_node_id: + continue + if isinstance(conn.edge, RDMAConnection): + # RDMA peer — still need IP from a socket connection + ip = _find_socket_ip(node_id, peer_node_id, state) + if ip: + return PeerEndpoint( + node_id=peer_node_id, + ip=ip, + port=peer_download_port, + status=status, + connection_type="rdma", + ) + elif isinstance(conn.edge, SocketConnection): + return PeerEndpoint( + node_id=peer_node_id, + ip=conn.edge.sink_multiaddr.ip_address, + port=peer_download_port, + status=status, + connection_type="socket", + ) + except Exception as e: + logger.debug(f"Could not resolve endpoint for peer {peer_node_id}: {e}") + return None + + +def _find_socket_ip( + node_id: NodeId, peer_node_id: NodeId, state: State 
+) -> str | None: + """Find a socket connection IP for a peer (used as fallback for RDMA peers).""" + try: + for conn in state.topology.out_edges(node_id): + if conn.sink == peer_node_id and isinstance(conn.edge, SocketConnection): + return conn.edge.sink_multiaddr.ip_address + except Exception: + pass + return None diff --git a/src/exo/download/tests/test_peer_download.py b/src/exo/download/tests/test_peer_download.py new file mode 100644 index 0000000000..04c49de786 --- /dev/null +++ b/src/exo/download/tests/test_peer_download.py @@ -0,0 +1,265 @@ +"""Tests for peer-to-peer model downloading.""" + +import asyncio +import json +from collections.abc import AsyncIterator +from pathlib import Path + +import aiofiles +import aiofiles.os as aios +import pytest + +from exo.download.peer_download import download_file_from_peer, get_peer_file_status +from exo.download.peer_file_server import PeerFileServer + + +@pytest.fixture +async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]: + """Set up a temporary models directory for testing.""" + models_dir = tmp_path / "models" + await aios.makedirs(models_dir, exist_ok=True) + yield models_dir + + +@pytest.fixture +async def peer_server(temp_models_dir: Path) -> AsyncIterator[PeerFileServer]: + """Start a PeerFileServer on a random port for testing.""" + server = PeerFileServer(host="127.0.0.1", port=0, models_dir=temp_models_dir) + # Use port 0 to let OS assign a free port + from aiohttp import web + + server._runner = web.AppRunner(server._app) + await server._runner.setup() + site = web.TCPSite(server._runner, "127.0.0.1", 0) + await site.start() + # Get the actual port assigned + server.port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] + yield server + await server.shutdown() + + +class TestPeerFileServer: + """Tests for the HTTP file server that serves model files to peers.""" + + async def test_health_check(self, peer_server: PeerFileServer) -> None: + """Health endpoint should return 
ok.""" + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/health" + ) as r: + assert r.status == 200 + data = await r.json() + assert data["status"] == "ok" + + async def test_status_empty_model(self, peer_server: PeerFileServer) -> None: + """Status for non-existent model should return empty file list.""" + files = await get_peer_file_status( + "127.0.0.1", peer_server.port, "nonexistent--model" + ) + assert files is not None + assert len(files) == 0 + + async def test_status_with_complete_file( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Status should report complete files correctly.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + # Create a complete test file + async with aiofiles.open(model_dir / "config.json", "wb") as f: + await f.write(b'{"test": true}') + + files = await get_peer_file_status( + "127.0.0.1", peer_server.port, "test--model" + ) + assert files is not None + assert len(files) == 1 + assert files[0].path == "config.json" + assert files[0].complete is True + assert files[0].safe_bytes == 14 + + async def test_status_with_partial_file( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Status should report partial files with safe byte count.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + # Create a partial file with metadata + partial_data = b"x" * 1024 + async with aiofiles.open(model_dir / "weights.safetensors.partial", "wb") as f: + await f.write(partial_data) + + meta = {"safe_bytes": 1024, "total": 4096, "etag": "abc123"} + async with aiofiles.open( + model_dir / "weights.safetensors.partial.meta", "w" + ) as f: + await f.write(json.dumps(meta)) + + files = await get_peer_file_status( + "127.0.0.1", peer_server.port, "test--model" + ) + assert files is not None + assert len(files) == 1 + assert 
files[0].path == "weights.safetensors" + assert files[0].complete is False + assert files[0].safe_bytes == 1024 + assert files[0].size == 4096 + + async def test_serve_complete_file( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Should serve a complete file with correct headers.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"hello world test content" + async with aiofiles.open(model_dir / "config.json", "wb") as f: + await f.write(content) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/config.json" + ) as r: + assert r.status == 200 + assert r.headers["X-Exo-Complete"] == "true" + body = await r.read() + assert body == content + + async def test_serve_with_range_request( + self, peer_server: PeerFileServer, temp_models_dir: Path + ) -> None: + """Should support Range requests for resume.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"0123456789abcdef" + async with aiofiles.open(model_dir / "weights.bin", "wb") as f: + await f.write(content) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/weights.bin", + headers={"Range": "bytes=8-"}, + ) as r: + assert r.status == 206 + body = await r.read() + assert body == b"89abcdef" + + async def test_file_not_found(self, peer_server: PeerFileServer) -> None: + """Should return 404 for missing files.""" + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://127.0.0.1:{peer_server.port}/files/test--model/missing.bin" + ) as r: + assert r.status == 404 + + +class TestPeerDownloadClient: + """Tests for downloading files from a peer server.""" + + async def test_download_complete_file( + self, peer_server: 
PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Should download a complete file from peer.""" + # Set up source file on the peer server + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"model weights data " * 100 + async with aiofiles.open(model_dir / "weights.bin", "wb") as f: + await f.write(content) + + # Download to a different directory + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + progress_calls: list[tuple[int, int, bool]] = [] + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "weights.bin", + download_dir, + len(content), + on_progress=lambda c, t, r: progress_calls.append((c, t, r)), + ) + + assert result is not None + assert result == download_dir / "weights.bin" + async with aiofiles.open(result, "rb") as f: + downloaded = await f.read() + assert downloaded == content + # Should have progress calls including final + assert len(progress_calls) > 0 + assert progress_calls[-1][2] is True # is_renamed + + async def test_download_returns_none_on_missing( + self, peer_server: PeerFileServer, tmp_path: Path + ) -> None: + """Should return None when file doesn't exist on peer.""" + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "nonexistent.bin", + download_dir, + 1000, + ) + assert result is None + + async def test_download_returns_none_on_unreachable_peer( + self, tmp_path: Path + ) -> None: + """Should return None when peer is unreachable.""" + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + + result = await download_file_from_peer( + "127.0.0.1", + 19999, # Nobody listening + "test--model", + "weights.bin", + download_dir, + 1000, + ) + assert result is 
None + + async def test_skip_already_complete( + self, peer_server: PeerFileServer, temp_models_dir: Path, tmp_path: Path + ) -> None: + """Should skip download if file already exists locally with correct size.""" + model_dir = temp_models_dir / "test--model" + await aios.makedirs(model_dir, exist_ok=True) + + content = b"existing content" + # File already exists in target + download_dir = tmp_path / "downloads" / "test--model" + await aios.makedirs(download_dir, exist_ok=True) + async with aiofiles.open(download_dir / "config.json", "wb") as f: + await f.write(content) + + result = await download_file_from_peer( + "127.0.0.1", + peer_server.port, + "test--model", + "config.json", + download_dir, + len(content), + ) + + assert result is not None + assert result == download_dir / "config.json" diff --git a/src/exo/main.py b/src/exo/main.py index 2ecb62c2f1..9399b32dfe 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -13,11 +13,12 @@ import exo.routing.topics as topics from exo.download.coordinator import DownloadCoordinator from exo.download.impl_shard_downloader import exo_shard_downloader +from exo.download.peer_file_server import PeerFileServer from exo.master.api import API # TODO: should API be in master? 
from exo.master.main import Master from exo.routing.event_router import EventRouter from exo.routing.router import Router, get_node_id_keypair -from exo.shared.constants import EXO_LOG +from exo.shared.constants import EXO_LOG, EXO_MODELS_DIR, EXO_PEER_DOWNLOAD_PORT from exo.shared.election import Election, ElectionResult from exo.shared.logging import logger_cleanup, logger_setup from exo.shared.types.common import NodeId, SessionId @@ -40,6 +41,7 @@ class Node: node_id: NodeId offline: bool + peer_file_server: PeerFileServer | None = None _tg: TaskGroup = field(init=False, default_factory=TaskGroup) @classmethod @@ -63,17 +65,8 @@ async def create(cls, args: "Args") -> Self: logger.info(f"Starting node {node_id}") - # Create DownloadCoordinator (unless --no-downloads) - if not args.no_downloads: - download_coordinator = DownloadCoordinator( - node_id, - exo_shard_downloader(offline=args.offline), - event_sender=event_router.sender(), - download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), - offline=args.offline, - ) - else: - download_coordinator = None + peer_file_server: PeerFileServer | None = None + peer_download_enabled = not args.no_peer_download and not args.no_downloads if args.spawn_api: api = API( @@ -98,6 +91,28 @@ async def create(cls, args: "Args") -> Self: else: worker = None + # Create peer file server and download coordinator + if peer_download_enabled: + peer_file_server = PeerFileServer( + host="0.0.0.0", + port=EXO_PEER_DOWNLOAD_PORT, + models_dir=EXO_MODELS_DIR, + ) + + if not args.no_downloads: + download_coordinator: DownloadCoordinator | None = DownloadCoordinator( + node_id, + exo_shard_downloader( + offline=args.offline, + peer_download_enabled=peer_download_enabled, + ), + event_sender=event_router.sender(), + download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), + offline=args.offline, + ) + else: + download_coordinator = None + # We start every node with a master master = Master( node_id, @@ -134,6 +149,7 
@@ async def create(cls, args: "Args") -> Self: api, node_id, args.offline, + peer_file_server, ) async def run(self): @@ -143,6 +159,8 @@ async def run(self): tg.start_soon(self.router.run) tg.start_soon(self.event_router.run) tg.start_soon(self.election.run) + if self.peer_file_server: + tg.start_soon(self.peer_file_server.run) if self.download_coordinator: tg.start_soon(self.download_coordinator.run) if self.worker: @@ -227,7 +245,10 @@ async def _elect_loop(self): self.download_coordinator.shutdown() self.download_coordinator = DownloadCoordinator( self.node_id, - exo_shard_downloader(offline=self.offline), + exo_shard_downloader( + offline=self.offline, + peer_download_enabled=self.peer_file_server is not None, + ), event_sender=self.event_router.sender(), download_command_receiver=self.router.receiver( topics.DOWNLOAD_COMMANDS @@ -303,6 +324,7 @@ class Args(CamelCaseModel): tb_only: bool = False no_worker: bool = False no_downloads: bool = False + no_peer_download: bool = False offline: bool = os.getenv("EXO_OFFLINE", "false").lower() == "true" no_batch: bool = False fast_synch: bool | None = None # None = auto, True = force on, False = force off @@ -352,6 +374,11 @@ def parse(cls) -> Self: action="store_true", help="Disable the download coordinator (node won't download models)", ) + parser.add_argument( + "--no-peer-download", + action="store_true", + help="Disable peer-to-peer model downloads (each node downloads from HuggingFace independently)", + ) parser.add_argument( "--offline", action="store_true", diff --git a/src/exo/shared/constants.py b/src/exo/shared/constants.py index 695342983a..1fa2ef3332 100644 --- a/src/exo/shared/constants.py +++ b/src/exo/shared/constants.py @@ -83,3 +83,6 @@ def _get_xdg_dir(env_var: str, fallback: str) -> Path: EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "false").lower() == "true" EXO_MAX_CONCURRENT_REQUESTS = int(os.getenv("EXO_MAX_CONCURRENT_REQUESTS", "8")) + +# Peer-to-peer model download server port (one 
above default API port) +EXO_PEER_DOWNLOAD_PORT = int(os.getenv("EXO_PEER_DOWNLOAD_PORT", "52416")) diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index e1c372889b..af55130a01 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -66,9 +66,20 @@ class RequestEventLog(BaseCommand): since_idx: int +class PeerEndpoint(CamelCaseModel): + """A peer node that has (or is downloading) a model, with its network address.""" + + node_id: NodeId + ip: str + port: int + status: str = "complete" # "complete" or "ongoing" + connection_type: str = "socket" # "rdma" or "socket" + + class StartDownload(BaseCommand): target_node_id: NodeId shard_metadata: ShardMetadata + available_peers: list[PeerEndpoint] = Field(default_factory=list) class DeleteDownload(BaseCommand): diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index f993a261c7..ab8c38ab59 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -6,6 +6,8 @@ from loguru import logger from exo.download.download_utils import resolve_model_in_path +from exo.download.peer_state import discover_peers_for_model +from exo.shared.constants import EXO_PEER_DOWNLOAD_PORT from exo.shared.apply import apply from exo.shared.models.model_cards import ModelId from exo.shared.types.api import ImageEditsTaskParams @@ -193,12 +195,20 @@ async def plan_step(self): ) ) else: + # Discover peers that already have this model + peers = discover_peers_for_model( + self.node_id, + self.state, + shard.model_card.model_id.normalize(), + EXO_PEER_DOWNLOAD_PORT, + ) await self.download_command_sender.send( ForwarderDownloadCommand( origin=self._system_id, command=StartDownload( target_node_id=self.node_id, shard_metadata=shard, + available_peers=peers, ), ) )