diff --git a/examples/fully_async/README.md b/examples/fully_async/README.md index 36a36cb7d9..c9c5654c90 100644 --- a/examples/fully_async/README.md +++ b/examples/fully_async/README.md @@ -3,33 +3,125 @@ This example shows a simple way to make rollout generation **fully asynchronous**: a single global worker is created once and then keeps running in the background, continuously pulling prompts and launching generation tasks. Training only needs to fetch already finished results. This removes the per‑step wait that happens in the normal synchronous style. ## Files -* `fully_async_rollout.py`: global async worker + `generate_rollout_fully_async` entry. -* `run-qwen3-4b-fully_async.sh`: example launch script with Qwen3‑4B. +* `fully_async_rollout.py`: global async worker + `generate_rollout_fully_async` entry, including off-policy buffer management. +* `run-qwen3.5-4b-off-policy-benchmark.sh`: multi-mode off-policy benchmark script supporting one-step-off baseline, fully async, staleness-backpressure, and window-evict modes. ## Prerequisite -First set up model & environment following the Qwen3-4B example. +First set up model & environment following the Qwen3.5-4B example. ## Quick Start + +**Off-policy benchmark (4 modes):** ```bash -cd slime -bash examples/fully_async/run-qwen3-4b-fully_async.sh +# One-step off-policy async baseline (default rollout, no fully async worker) +MODE=one_step_off bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh + +# Fully async, no staleness control +MODE=fully_async bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh + +# Fully async + staleness backpressure + partial rollout +MODE=staleness_partial bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh + +# Fully async + version-window eviction + partial rollout +MODE=window_partial bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh ``` + You should see log lines like: ``` Creating new global async worker... 
Continuous async rollout worker started ``` -## How It Works (Very Short) +## How It Works + * First call: create `AsyncRolloutWorker` (thread + asyncio loop). * Loop keeps up to `--rollout-batch-size` tasks in flight using `generate_and_rm_group`. -* Completed groups are pushed into a queue; caller drains until it has enough samples. +* Completed groups are pushed into a `CompletedSampleRecord` store; caller drains until it has enough samples. * Worker is stopped automatically at process exit. -## Limitations -* No evaluation mode. -* Ordering is best effort (sorted at the end by index). -* Minimal error handling. +### Why do we need both staleness control and version-window eviction? + +The two existing async modes (`one_step_off` and `fully_async`) both lack the ability to control off-policy staleness and neither supports **partial rollout**: + +- **`one_step_off`**: uses the default `sglang_rollout` path. Although `sglang_rollout.py` internally implements `abort()` and partial-rollout recycling, the original `train_async.py` did not have `before_weight_update` / `after_weight_update` lifecycle hooks, so the training loop never notified the rollout module before a weight sync. In-flight tasks were simply drained to completion, making partial rollout impossible. +- **`fully_async`**: the original async worker had no concept of policy version tracking, no staleness budget, no `abort()` call, and no weight-update hooks. The worker ran continuously without any coordination with weight updates, so partial rollout was equally unsupported. + +The first fix is `staleness_partial`: it adds policy-version tracking, a stale-sample budget, and the lifecycle hooks needed by partial rollout. The staleness/backpressure idea here is close to the fully async design used in VERL. 
+ +With the new lifecycle hooks (`before_weight_update` / `after_weight_update`) wired into `train_async.py` and `RolloutManager`, the async worker can now abort in-flight tasks before each weight update, recycle partially generated samples back to the data buffer, and mask off-policy tokens during training. + +However, staleness backpressure still has two practical limitations: + +1. If rollout throughput is lower than training consumption throughput, pausing new scheduling can introduce rollout bubbles and make the rollout side fall further behind. +2. When partial rollout is enabled, a common strategy is to prioritize recycled samples so they are resumed first. That improves reuse of partial work, but it also means a single trajectory may span many policy versions, so the `version span` can still lag by much more than 1 even if the stale backlog is bounded. + +That is why `window_evict` is introduced after staleness control. If you want to strictly cap the allowed version lag, for example keep it `<= 1`, while also avoiding rollout pauses when rollout is faster than training, `window_evict` is a better fit. Its sliding-version-window eviction behavior is mainly inspired by MiniMax Forge. + +### Off-Policy Buffer Policies + +In fully async mode, the rollout worker runs continuously and may produce samples generated under an older policy version. 
Two buffer policies control how these **stale (off-policy) samples** are managed: + +#### Buffer Policy Comparison + +| Feature | `legacy_backpressure` | `window_evict` | +|---------|----------------------|----------------| +| **Scheduling** | Pauses when stale budget reached | Never pauses, always scheduling | +| **Sample Eviction** | No eviction | Actively evicts out-of-window samples | +| **GPU Utilization** | May have idle periods | Always high utilization | +| **Version Lag Control** | Soft control (backlog ratio) | Hard control (window width W) | +| **Partial Rollout Span** | May span many versions | Bounded to ≤ W+1 versions | +| **Key Parameter** | `--staleness-threshold` | `--fully-async-version-window` | + +#### `legacy_backpressure` (default; used by `staleness_partial`) + +Pause scheduling new rollout tasks when the number of stale samples reaches a configurable budget: + +``` +budget = rollout_batch_size × update_weights_interval × (1 + staleness_threshold) +``` + +The worker resumes after the trainer consumes enough samples to bring the stale count below the budget. This is the simpler staleness-control mode, but pausing can leave rollout GPUs idle and it does not strictly bound per-trajectory version span under partial rollout. + +#### `window_evict` (used by `window_partial`) + +Keep rollout scheduling active at all times. Instead of pausing, evict completed samples whose policy version falls outside a sliding window `[current_version - W, current_version]`. This trades sample efficiency (some generated samples are discarded) for higher GPU utilization and a stricter bound on allowed version lag. + +Key parameters: +- `--fully-async-version-window W`: window width (default 1). +- `--fully-async-max-completed-samples N`: hard cap on buffered samples. +- `--fully-async-eviction-policy`: `drop_oldest_version` (default) or `drop_oldest_fifo`. 
+ +### Partial Rollout & Off-Policy Masking + +When `--partial-rollout` is enabled, in-flight rollout tasks are **aborted** before each weight update rather than drained to completion. The partially generated samples are returned to the data buffer and re-scheduled under the new policy. + +Combined with `--mask-offpolicy-in-partial-rollout`, any trajectory whose generation spans multiple policy versions will have its off-policy tokens masked during training loss computation, ensuring that only on-policy tokens contribute to gradient updates. + +### Lifecycle Hooks + +The training loop (`train_async.py`) calls `RolloutManager.before_weight_update` / `after_weight_update` around each weight sync. These hooks are forwarded to module-level functions in the rollout module (`before_weight_update`, `after_weight_update` in `fully_async_rollout.py`), enabling the async worker to: +1. Pause scheduling and drain/abort in-flight tasks before weights change. +2. Update the internal policy version, evict out-of-window samples, and resume after weights are synced. +3. Report per-interval staleness and eviction metrics to wandb. + +## New CLI Arguments + +| Argument | Type | Default | Description | +|---|---|---|---| +| `--staleness-threshold` | float | None | Max stale backlog ratio. Enables backpressure when set. | +| `--fully-async-buffer-policy` | str | `legacy_backpressure` | Buffer policy: `legacy_backpressure` or `window_evict`. | +| `--fully-async-version-window` | int | 1 | Policy-version window width for `window_evict`. | +| `--fully-async-max-completed-samples` | int | auto | Hard cap on completed samples in memory. | +| `--fully-async-eviction-policy` | str | `drop_oldest_version` | Overflow eviction strategy for `window_evict`. | +| `--fully-async-debug-version-tracking` | flag | False | Print per-batch version summaries for debugging. 
| + +## Wandb Metrics + +When enabled, the following metric groups are logged under a dedicated `fully_async/step` axis: + +- `fully_async/count/*`: stale samples processed, consumed, recycled, dropped. +- `fully_async/partial/*`: partial rollout ratio and max version span. +- `fully_async/window/*`: completed store size, eligible samples, eviction counts. ## Config Differences (2 Key Points) To enable the fully async pattern there are only two changes compared to a normal run: diff --git a/examples/fully_async/fully_async_rollout.py b/examples/fully_async/fully_async_rollout.py index 7208365c18..24a0b7e66e 100644 --- a/examples/fully_async/fully_async_rollout.py +++ b/examples/fully_async/fully_async_rollout.py @@ -1,21 +1,233 @@ import asyncio import atexit +import math import queue import threading import time +from collections import Counter +from collections.abc import Iterable +from dataclasses import dataclass -# Import core functions from sglang_rollout directly to avoid code duplication -from slime.rollout.sglang_rollout import GenerateState, generate_and_rm_group +from slime.rollout.base_types import RolloutFnTrainOutput +from slime.rollout.sglang_rollout import GenerateState, abort, eval_rollout, generate_and_rm_group from slime.utils.async_utils import run from slime.utils.types import Sample -# Global worker manager _global_worker = None _worker_lock = threading.Lock() +def _extract_sample_id(group: list[Sample]) -> int | None: + if not group: + return None + sample = group[0] + for value in ( + sample.metadata.get("fully_async_sample_id"), + sample.metadata.get("fully_async_group_id"), + sample.group_index, + ): + if value is not None: + return value + return None + + +def _derive_max_stale_samples(args) -> int | None: + staleness_threshold = getattr(args, "staleness_threshold", None) + if staleness_threshold is None: + return None + return max( + 0, + math.ceil(args.rollout_batch_size * args.update_weights_interval * (1 + staleness_threshold)), + ) + 
+ +def _derive_max_completed_samples(args, max_stale_samples: int | None) -> int: + configured_max_completed_samples = getattr(args, "fully_async_max_completed_samples", None) + if configured_max_completed_samples is not None: + return max(1, configured_max_completed_samples) + return max(1000, (max_stale_samples or 0) + args.rollout_batch_size) + + +def _normalize_policy_version(value) -> int | None: + if value is None: + return None + if isinstance(value, bool): + return int(value) + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + return int(stripped) + except ValueError: + return None + return None + + +def _extract_trajectory_versions_with_source(sample: Sample) -> tuple[str, list[int]]: + versions = [_normalize_policy_version(version) for version in sample.weight_versions] + versions = [version for version in versions if version is not None] + if versions: + return "weight_versions", versions + + scheduled_versions = [ + _normalize_policy_version(version) for version in sample.metadata.get("fully_async_schedule_versions", []) + ] + scheduled_versions = [version for version in scheduled_versions if version is not None] + if scheduled_versions: + return "fully_async_schedule_versions", scheduled_versions + + fallback_version = _normalize_policy_version(sample.metadata.get("policy_version")) + return ("policy_version", [fallback_version]) if fallback_version is not None else ("none", []) + + +def _weight_version_to_policy_version(version: int) -> int: + # Rollout engine weight_version starts from 1 after the initial sync while + # training policy_version starts from 0, so align the fallback path here. 
+ return max(version - 1, 0) + + +def _extract_staleness_versions_with_source(sample: Sample) -> tuple[str, list[int]]: + scheduled_versions = [ + _normalize_policy_version(version) for version in sample.metadata.get("fully_async_schedule_versions", []) + ] + scheduled_versions = [version for version in scheduled_versions if version is not None] + if scheduled_versions: + return "fully_async_schedule_versions", scheduled_versions + + fallback_version = _normalize_policy_version(sample.metadata.get("policy_version")) + if fallback_version is not None: + return "policy_version", [fallback_version] + + versions = [_normalize_policy_version(version) for version in sample.weight_versions] + versions = [_weight_version_to_policy_version(version) for version in versions if version is not None] + if versions: + return "weight_versions", versions + + return ("none", []) + + +def _extract_trajectory_versions(sample: Sample) -> list[int]: + _, versions = _extract_staleness_versions_with_source(sample) + return versions + + +def _summarize_processed_group(group: list[Sample], current_policy_version: int) -> dict[str, object]: + source_counts = Counter() + staleness_source_counts = Counter() + group_mins: list[int] = [] + group_maxs: list[int] = [] + stale_trajectory_count = 0 + trajectory_summaries = [] + + for idx, sample in enumerate(group): + source, raw_versions = _extract_trajectory_versions_with_source(sample) + staleness_source, versions = _extract_staleness_versions_with_source(sample) + source_counts[source] += 1 + staleness_source_counts[staleness_source] += 1 + max_version = max(versions) if versions else None + min_version = min(versions) if versions else None + is_stale = max_version is not None and current_policy_version - max_version >= 1 + if min_version is not None: + group_mins.append(min_version) + if max_version is not None: + group_maxs.append(max_version) + if is_stale: + stale_trajectory_count += 1 + trajectory_summaries.append( + { + "trajectory_index": 
idx, + "source": source, + "versions": raw_versions, + "staleness_source": staleness_source, + "staleness_versions": versions, + "weight_versions": list(sample.weight_versions), + "schedule_versions": list(sample.metadata.get("fully_async_schedule_versions", [])), + "policy_version": sample.metadata.get("policy_version"), + "is_stale": is_stale, + } + ) + + group_min_version = min(group_mins) if group_mins else None + group_max_version = max(group_maxs) if group_maxs else None + return { + "sample_id": _extract_sample_id(group), + "group_size": len(group), + "group_min_version": group_min_version, + "group_max_version": group_max_version, + "stale_sample": group_max_version is not None and current_policy_version - group_max_version >= 1, + "stale_trajectory_count": stale_trajectory_count, + "partial_span": ( + max(0, group_max_version - group_min_version) + if group_min_version is not None and group_max_version is not None + else 0 + ), + "source_counts": dict(source_counts), + "staleness_source_counts": dict(staleness_source_counts), + "trajectory_summaries": trajectory_summaries, + } + + +def _log_processed_group_debug( + args, + groups: list[list[Sample]], + *, + current_policy_version: int, + rollout_id: int, + drained_group_count: int, + leftover_group_count: int, +) -> None: + if not getattr(args, "fully_async_debug_version_tracking", False): + return + + summaries = [_summarize_processed_group(group, current_policy_version) for group in groups] + stale_sample_count = sum(1 for summary in summaries if summary["stale_sample"]) + stale_trajectory_count = sum(int(summary["stale_trajectory_count"]) for summary in summaries) + partial_group_count = sum(1 for summary in summaries if int(summary["partial_span"]) > 0) + group_max_counter = Counter( + summary["group_max_version"] for summary in summaries if summary["group_max_version"] is not None + ) + source_counter = Counter() + staleness_source_counter = Counter() + for summary in summaries: + 
source_counter.update(summary["source_counts"]) + staleness_source_counter.update(summary["staleness_source_counts"]) + + print( + "[fully_async_debug] " + f"rollout_id={rollout_id}, current_policy_version={current_policy_version}, " + f"drained_groups={drained_group_count}, selected_groups={len(groups)}, " + f"leftover_completed_groups={leftover_group_count}, " + f"stale_samples_in_selected={stale_sample_count}, " + f"stale_trajectories_in_selected={stale_trajectory_count}, " + f"partial_groups_in_selected={partial_group_count}, " + f"group_max_versions={dict(sorted(group_max_counter.items()))}, " + f"version_sources={dict(sorted(source_counter.items()))}, " + f"staleness_version_sources={dict(sorted(staleness_source_counter.items()))}", + flush=True, + ) + + for summary in summaries[:3]: + print( + "[fully_async_debug] " + f"sample_id={summary['sample_id']}, group_size={summary['group_size']}, " + f"group_min_version={summary['group_min_version']}, " + f"group_max_version={summary['group_max_version']}, " + f"stale_sample={summary['stale_sample']}, " + f"stale_trajectory_count={summary['stale_trajectory_count']}, " + f"partial_span={summary['partial_span']}, " + f"source_counts={summary['source_counts']}, " + f"staleness_source_counts={summary['staleness_source_counts']}, " + f"trajectories={summary['trajectory_summaries']}", + flush=True, + ) + + def get_global_worker(args, data_buffer): - """Get or create global worker""" global _global_worker with _worker_lock: if _global_worker is None or not _global_worker.worker_thread.is_alive(): @@ -25,8 +237,15 @@ def get_global_worker(args, data_buffer): return _global_worker +def get_existing_worker(): + global _global_worker + with _worker_lock: + if _global_worker is None or not _global_worker.worker_thread.is_alive(): + return None + return _global_worker + + def stop_global_worker(): - """Stop global worker""" global _global_worker with _worker_lock: if _global_worker is not None: @@ -34,180 +253,727 @@ def 
stop_global_worker(): _global_worker = None +@dataclass +class CompletedSampleRecord: + sample_id: int | None + group: list[Sample] + policy_version: int | None + + +class _CompletedStoreAdapter: + def __init__(self, worker: "AsyncRolloutWorker"): + self.worker = worker + + def put(self, item: tuple[int | None, list[Sample]]): + sample_id, group = item + self.worker._put_completed_sample(sample_id, group) + + def put_nowait(self, item: tuple[int | None, list[Sample]]): + self.put(item) + + def qsize(self) -> int: + return self.worker.get_queue_size() + + class AsyncRolloutWorker: """ - Simplified asynchronous rollout worker, using threads instead of processes - Supports continuous running, independent of rollout function lifecycle + Background rollout worker with weight-sync hooks. + + Compared with the original example, this version adds: + - pause / resume around parameter synchronization + - optional partial rollout recycling before weight updates + - stale backlog accounting controlled by `staleness_threshold` """ def __init__(self, args, data_buffer, concurrency=10): self.args = args - self.data_buffer = data_buffer # Directly save data_buffer reference + self.data_buffer = data_buffer self.concurrency = concurrency self.running = True - self.output_queue = queue.Queue(maxsize=1000) # Continuous output queue + self.max_stale_samples = _derive_max_stale_samples(args) + self.buffer_policy = getattr(args, "fully_async_buffer_policy", "legacy_backpressure") + self.version_window = max(0, getattr(args, "fully_async_version_window", 1)) + self.eviction_policy = getattr(args, "fully_async_eviction_policy", "drop_oldest_version") + self.max_completed_samples = _derive_max_completed_samples(args, self.max_stale_samples) + self.completed_lock = threading.Lock() + self.completed_records: list[CompletedSampleRecord] = [] + self.output_queue = _CompletedStoreAdapter(self) self.worker_thread = None self.state = GenerateState(args) + self.loop = None + self.task_lock: 
asyncio.Lock | None = None + + self.max_concurrent_tasks = self.args.rollout_batch_size + self.sample_id_counter = -1 + self.task_sample_ids: dict[asyncio.Task, int | None] = {} + self.policy_version = getattr(args, "current_policy_version", 0) + self.stale_samples_processed = 0 + self.stale_trajectory_processed = 0 + self.consumed_samples = 0 + self.recycled_samples = 0 + self.dropped_samples = 0 + self.evicted_samples = 0 + self.evicted_by_version = 0 + self._window_total_samples = 0 + self._window_partial_samples = 0 + self._window_max_partial_span = 0 + + self.control_lock = threading.Lock() + self.pause_requested = False + # Staleness budget for the current version window. It is reset from the + # outstanding old-version snapshot on weight updates, then grows as new + # samples are pulled under the current version. Trainer consumption + # refunds this budget, while recycled partial samples remain counted. + self.stale_sample_ids: set[int] = set() + self.recycled_sample_ids: set[int] = set() + + def _uses_window_evict_policy(self) -> bool: + return self.buffer_policy == "window_evict" + + def _version_window_bounds(self, current_policy_version: int | None = None) -> tuple[int | None, int | None]: + if not self._uses_window_evict_policy(): + return None, None + max_policy_version = self.policy_version if current_policy_version is None else current_policy_version + return max_policy_version - self.version_window, max_policy_version + + def _extract_group_policy_version(self, group: list[Sample]) -> int | None: + group_max_version = None + for sample in group: + versions = _extract_trajectory_versions(sample) + if not versions: + continue + sample_max_version = max(versions) + if group_max_version is None: + group_max_version = sample_max_version + else: + group_max_version = max(group_max_version, sample_max_version) + return group_max_version + + def _snapshot_completed_records(self) -> list[CompletedSampleRecord]: + with self.completed_lock: + return 
list(self.completed_records) + + def _drop_sample_tracking_locked(self, sample_id: int | None): + if sample_id is None: + return + self.stale_sample_ids.discard(sample_id) + self.recycled_sample_ids.discard(sample_id) + + def _record_evicted_records_locked( + self, + records: list[CompletedSampleRecord], + *, + version_eviction: bool, + ) -> None: + if not records: + return + with self.control_lock: + self.evicted_samples += len(records) + if version_eviction: + self.evicted_by_version += len(records) + for record in records: + self._drop_sample_tracking_locked(record.sample_id) + + def _evict_records_outside_window_locked(self, current_policy_version: int | None = None) -> int: + if not self._uses_window_evict_policy(): + return 0 + min_policy_version, max_policy_version = self._version_window_bounds(current_policy_version) + retained_records = [] + evicted_records = [] + for record in self.completed_records: + if ( + record.policy_version is not None + and min_policy_version is not None + and max_policy_version is not None + and (record.policy_version < min_policy_version or record.policy_version > max_policy_version) + ): + evicted_records.append(record) + else: + retained_records.append(record) + if evicted_records: + self.completed_records = retained_records + self._record_evicted_records_locked(evicted_records, version_eviction=True) + return len(evicted_records) + + def _select_overflow_eviction_index_locked(self) -> int: + if self.eviction_policy == "drop_oldest_fifo": + return 0 + oldest_index = 0 + oldest_version = None + for idx, record in enumerate(self.completed_records): + candidate_version = -1 if record.policy_version is None else record.policy_version + if oldest_version is None or candidate_version < oldest_version: + oldest_index = idx + oldest_version = candidate_version + return oldest_index + + def _trim_completed_records_locked(self) -> int: + if self.max_completed_samples <= 0: + return 0 + trimmed_records = [] + while 
len(self.completed_records) > self.max_completed_samples: + eviction_index = self._select_overflow_eviction_index_locked() + trimmed_records.append(self.completed_records.pop(eviction_index)) + if trimmed_records: + self._record_evicted_records_locked(trimmed_records, version_eviction=False) + return len(trimmed_records) + + def _put_completed_sample(self, sample_id: int | None, group: list[Sample]) -> None: + record = CompletedSampleRecord( + sample_id=sample_id, + group=group, + policy_version=self._extract_group_policy_version(group), + ) + with self.completed_lock: + if not self._uses_window_evict_policy() and len(self.completed_records) >= self.max_completed_samples: + with self.control_lock: + self.dropped_samples += 1 + raise queue.Full + self.completed_records.append(record) + if self._uses_window_evict_policy(): + self._evict_records_outside_window_locked(self.policy_version) + self._trim_completed_records_locked() + + def _set_pause_requested(self, value: bool): + with self.control_lock: + self.pause_requested = value + + def _is_pause_requested(self) -> bool: + with self.control_lock: + return self.pause_requested + + def _set_stale_sample_ids(self, sample_ids: Iterable[int | None]): + with self.control_lock: + self.stale_sample_ids = {sample_id for sample_id in sample_ids if sample_id is not None} + + def _add_stale_sample_id(self, sample_id: int | None): + if sample_id is None: + return + with self.control_lock: + self.stale_sample_ids.add(sample_id) + + def _add_recycled_sample_ids(self, sample_ids: Iterable[int | None]): + with self.control_lock: + self.recycled_sample_ids.update(sample_id for sample_id in sample_ids if sample_id is not None) + + def _discard_recycled_sample_id(self, sample_id: int | None): + if sample_id is None: + return + with self.control_lock: + self.recycled_sample_ids.discard(sample_id) + + def _snapshot_recycled_sample_ids(self) -> set[int]: + with self.control_lock: + return set(self.recycled_sample_ids) + + def 
_get_stale_sample_count_locked(self) -> int: + return len(self.stale_sample_ids | self.recycled_sample_ids) + + def get_stale_sample_count(self) -> int: + with self.control_lock: + return self._get_stale_sample_count_locked() + + def _should_pause_for_staleness(self) -> bool: + if self._uses_window_evict_policy(): + return False + return self.max_stale_samples is not None and self.get_stale_sample_count() >= self.max_stale_samples + + def _buffered_sample_ids(self) -> list[int | None]: + return [record.sample_id for record in self._snapshot_completed_records()] + + def _current_stale_sample_ids(self) -> set[int]: + return ( + {sample_id for sample_id in self._buffered_sample_ids() if sample_id is not None} + | {sample_id for sample_id in self.task_sample_ids.values() if sample_id is not None} + | self._snapshot_recycled_sample_ids() + ) + + def _mark_sample_consumed(self, sample_id: int | None): + if sample_id is None: + return + with self.control_lock: + self.consumed_samples += 1 + self._drop_sample_tracking_locked(sample_id) + + def _mark_sample_recycled(self, sample_id: int | None): + if sample_id is None: + return + with self.control_lock: + self.stale_sample_ids.discard(sample_id) + self.recycled_sample_ids.add(sample_id) + self.recycled_samples += 1 + + def _enqueue_sample(self, sample_id: int | None, group: list[Sample]) -> bool: + try: + self.output_queue.put_nowait((sample_id, group)) + return True + except queue.Full: + print(f"WARNING: output queue full, dropping sample {sample_id}") + return False + + def _collect_task_result(self, task: asyncio.Task) -> bool: + sample_id = self.task_sample_ids.pop(task, None) + try: + group = task.result() + except Exception as exc: + print(f"Task failed with exception: {exc}") + return False + if sample_id is None: + sample_id = _extract_sample_id(group) + self._discard_recycled_sample_id(sample_id) + self._enqueue_sample(sample_id, group) + return True + + def _annotate_sample(self, sample_id: int | None, group: 
list[Sample]): + current_rollout_id = getattr(self.args, "current_rollout_id", -1) + for sample in group: + schedule_versions = sample.metadata.setdefault("fully_async_schedule_versions", []) + if not schedule_versions or schedule_versions[-1] != self.policy_version: + schedule_versions.append(self.policy_version) + if sample_id is not None: + sample.metadata["fully_async_sample_id"] = sample_id + sample.metadata["fully_async_group_id"] = sample_id + sample.metadata["policy_version"] = self.policy_version + sample.metadata.setdefault("start_rollout_id", current_rollout_id) + + async def _push_finished_samples_to_output_queue(self): + assert self.task_lock is not None + done_tasks = {task for task in self.state.pendings if task.done()} + for task in done_tasks: + self.state.pendings.remove(task) + self._collect_task_result(task) + + async def _wait_for_all_active_tasks(self): + while self.state.pendings: + done_tasks, pending = await asyncio.wait(self.state.pendings, return_when=asyncio.FIRST_COMPLETED) + self.state.pendings = pending + for task in done_tasks: + self._collect_task_result(task) + + def _snapshot_completed_store_metrics(self) -> dict[str, int]: + records = self._snapshot_completed_records() + min_policy_version, max_policy_version = self._version_window_bounds() + policy_versions = [record.policy_version for record in records if record.policy_version is not None] + eligible_samples = 0 + for record in records: + if ( + not self._uses_window_evict_policy() + or record.policy_version is None + or ( + min_policy_version is not None + and max_policy_version is not None + and min_policy_version <= record.policy_version <= max_policy_version + ) + ): + eligible_samples += 1 + return { + "fully_async/window/completed_store_size": len(records), + "fully_async/window/eligible_samples": eligible_samples, + "fully_async/window/version_span": (max(policy_versions) - min(policy_versions) if policy_versions else 0), + } + + def _reset_interval_metrics_locked(self) 
-> None: + self.stale_samples_processed = 0 + self.stale_trajectory_processed = 0 + self.consumed_samples = 0 + self.recycled_samples = 0 + self.dropped_samples = 0 + self.evicted_samples = 0 + self.evicted_by_version = 0 + self._window_total_samples = 0 + self._window_partial_samples = 0 + self._window_max_partial_span = 0 + + def _has_interval_metrics(self) -> bool: + with self.control_lock: + return any( + [ + self.stale_samples_processed, + self.stale_trajectory_processed, + self.consumed_samples, + self.recycled_samples, + self.dropped_samples, + self.evicted_samples, + self.evicted_by_version, + self._window_total_samples, + self._window_partial_samples, + self._window_max_partial_span, + ] + ) + + def _snapshot_processed_metrics(self) -> dict[str, float | int]: + with self.control_lock: + total = self._window_total_samples + partial = self._window_partial_samples + metrics = { + "fully_async/count/stale_samples_processed": self.stale_samples_processed, + "fully_async/count/stale_trajectory_processed": self.stale_trajectory_processed, + "fully_async/count/consumed_samples": self.consumed_samples, + "fully_async/count/recycled_samples": self.recycled_samples, + "fully_async/count/dropped_samples": self.dropped_samples, + "fully_async/partial/total_partial_num": partial, + "fully_async/partial/partial_ratio": partial / total if total else 0.0, + "fully_async/partial/max_partial_span": self._window_max_partial_span, + "fully_async/window/evicted_samples": self.evicted_samples, + "fully_async/window/evicted_by_version": self.evicted_by_version, + } + metrics.update(self._snapshot_completed_store_metrics()) + return metrics + + def record_processed_samples(self, groups: list[list[Sample]]) -> None: + current_policy_version = self.policy_version + with self.control_lock: + for group in groups: + trajectory_mins: list[int] = [] + trajectory_maxs: list[int] = [] + stale_trajectory_count = 0 + + for sample in group: + versions = _extract_trajectory_versions(sample) + 
if not versions: + continue + trajectory_mins.append(min(versions)) + trajectory_maxs.append(max(versions)) + if current_policy_version - max(versions) >= 1: + stale_trajectory_count += 1 + + sample_min_version = min(trajectory_mins) if trajectory_mins else None + sample_max_version = max(trajectory_maxs) if trajectory_maxs else None + if sample_max_version is not None and current_policy_version - sample_max_version >= 1: + self.stale_samples_processed += 1 + self.stale_trajectory_processed += stale_trajectory_count + + partial_span = 0 + if sample_min_version is not None and sample_max_version is not None: + partial_span = max(0, sample_max_version - sample_min_version) + + self._window_total_samples += 1 + if partial_span > 0: + self._window_partial_samples += 1 + self._window_max_partial_span = max(self._window_max_partial_span, partial_span) + + async def _prepare_for_weight_update_async(self, policy_version: int): + assert self.task_lock is not None + async with self.task_lock: + await self._push_finished_samples_to_output_queue() + active_samples_before = len(self.state.pendings) + + if active_samples_before == 0: + stale_samples = len(self._current_stale_sample_ids()) + return { + "policy_version": policy_version, + "active_samples": 0, + "stale_samples": stale_samples, + } + + if self.args.partial_rollout: + aborted_groups = await abort(self.args, getattr(self.args, "current_rollout_id", -1)) + if aborted_groups: + self.data_buffer.add_samples(aborted_groups) + self._add_recycled_sample_ids(_extract_sample_id(group) for group in aborted_groups) + self.task_sample_ids.clear() + self.state.reset() + else: + await self._wait_for_all_active_tasks() + + stale_samples = len(self._current_stale_sample_ids()) + return { + "policy_version": policy_version, + "active_samples": active_samples_before, + "stale_samples": stale_samples, + } + + def _evict_completed_records_outside_window(self, current_policy_version: int | None = None) -> int: + with self.completed_lock: 
+ return self._evict_records_outside_window_locked(current_policy_version) + + async def _finish_weight_update_async(self, policy_version: int): + assert self.task_lock is not None + async with self.task_lock: + await self._push_finished_samples_to_output_queue() + with self.control_lock: + self.policy_version = policy_version + self.state.aborted = False + self._evict_completed_records_outside_window(policy_version) + self._set_stale_sample_ids(self._current_stale_sample_ids()) + self._set_pause_requested(False) + interval_metrics = self._snapshot_processed_metrics() + with self.control_lock: + self._reset_interval_metrics_locked() + return { + "policy_version": policy_version, + "stale_samples": self.get_stale_sample_count(), + "max_stale_samples": self.max_stale_samples, + **interval_metrics, + } + + async def run_eval(self, args, rollout_id): + async with self.task_lock: + await self._push_finished_samples_to_output_queue() + await self._wait_for_all_active_tasks() + # Swap in an eval-sized semaphore to bound in-flight generation during eval. + # eval_rollout creates per-sample tasks (prompt × n_samples_per_eval_prompt) + # that each acquire the semaphore individually, unlike training where whole + # groups are throttled, so the cap below is per sample. NOTE(review): code + # uses n_samples_per_prompt, which is not strictly "lower"; confirm sizing.
+ eval_concurrency = self.max_concurrent_tasks * self.args.n_samples_per_prompt + + original_semaphore = self.state.semaphore + self.state.semaphore = asyncio.Semaphore(eval_concurrency) + try: + output, _ = await eval_rollout(args, rollout_id) + finally: + self.state.semaphore = original_semaphore + return output + + async def _shutdown_async(self): + assert self.task_lock is not None + async with self.task_lock: + await self._push_finished_samples_to_output_queue() + if self.state.pendings: + try: + await abort(self.args, getattr(self.args, "current_rollout_id", -1)) + except Exception as exc: + print(f"Failed to abort pending rollout tasks during shutdown: {exc}") + self.task_sample_ids.clear() + self.state.reset() async def continuous_worker_loop(self): - """Continuous work loop - constantly get data from data_buffer and process""" print("Continuous async rollout worker started") - - active_tasks = set() - max_concurrent_tasks = self.args.rollout_batch_size - group_id_counter = 0 + self.loop = asyncio.get_running_loop() + self.task_lock = asyncio.Lock() while self.running: try: - # Clean up completed tasks - if active_tasks: - done_tasks = {task for task in active_tasks if task.done()} - for task in done_tasks: - try: - task.result() # Results are already handled in callbacks - except Exception as e: - print(f"Task failed with exception: {e}") - active_tasks -= done_tasks - - # If active task count hasn't reached limit, try to get new data and start tasks - while len(active_tasks) < max_concurrent_tasks and self.running: - samples = self.data_buffer.get_samples(1) + async with self.task_lock: + await self._push_finished_samples_to_output_queue() - for group in samples: - group_id = group_id_counter - group_id_counter += 1 - - # Create new async task - task = asyncio.create_task( - generate_and_rm_group( - self.args, - group, - sampling_params=self.state.sampling_params.copy(), - evaluation=False, - ) - ) + if self._is_pause_requested() or 
self._should_pause_for_staleness(): + await asyncio.sleep(0.05) + continue - # Add completion callback - def make_callback(gid): - def task_done_callback(done_task): - result = done_task.result() - self.output_queue.put((gid, result)) + while ( + len(self.state.pendings) < self.max_concurrent_tasks + and self.running + and not self._is_pause_requested() + and not self._should_pause_for_staleness() + ): + samples = self.data_buffer.get_samples(1) + if not samples: + break - return task_done_callback + group = samples[0] + sample_id = _extract_sample_id(group) + if sample_id is None: + sample_id = self.sample_id_counter + self.sample_id_counter -= 1 + self._add_stale_sample_id(sample_id) + self._annotate_sample(sample_id, group) - task.add_done_callback(make_callback(group_id)) - active_tasks.add(task) - break + task = asyncio.create_task( + generate_and_rm_group( + self.args, + group, + sampling_params=self.state.sampling_params.copy(), + evaluation=False, + ) + ) + self.state.pendings.add(task) + self.task_sample_ids[task] = sample_id + self._discard_recycled_sample_id(sample_id) - # Brief sleep to avoid busy waiting - await asyncio.sleep(1) + await asyncio.sleep(0.01) - except Exception as e: - print(f"Error in continuous worker loop: {e}") - await asyncio.sleep(1) + except Exception as exc: + print(f"Error in continuous worker loop: {exc}") + await asyncio.sleep(0.1) - if active_tasks: - print(f"Waiting for {len(active_tasks)} continuous tasks to complete...") - await asyncio.wait(active_tasks) + if self.task_lock is not None: + async with self.task_lock: + await self._wait_for_all_active_tasks() print("Continuous async rollout worker stopped") def worker_thread_func(self): - """Worker function running in independent thread""" asyncio.run(self.continuous_worker_loop()) def start(self): - """Start continuous work mode""" if self.worker_thread is None or not self.worker_thread.is_alive(): self.worker_thread = threading.Thread(target=self.worker_thread_func, 
daemon=True) self.worker_thread.start() - print("Started continuous async worker thread") + print( + "Started continuous async worker thread " + f"(max_stale_samples={self.max_stale_samples}, partial_rollout={self.args.partial_rollout})" + ) def stop(self): - """Stop worker thread""" + self._set_pause_requested(True) + if self.loop is not None and self.worker_thread and self.worker_thread.is_alive(): + try: + future = asyncio.run_coroutine_threadsafe(self._shutdown_async(), self.loop) + future.result(timeout=30) + except Exception as exc: + print(f"Failed to shutdown async worker cleanly: {exc}") self.running = False if self.worker_thread and self.worker_thread.is_alive(): self.worker_thread.join(timeout=5) print("Stopped async worker thread") - def get_completed_groups(self) -> list[tuple]: - """Get completed sample groups""" - completed = [] - while True: - try: - result = self.output_queue.get_nowait() - completed.append(result) - except queue.Empty: - break - return completed + def _needs_weight_update_sync(self) -> bool: + return self.max_stale_samples is not None or self.args.partial_rollout or self._uses_window_evict_policy() + + def before_weight_update(self, policy_version: int): + if not self._needs_weight_update_sync(): + return { + "policy_version": policy_version, + "active_samples": len(self.state.pendings) if self.state else 0, + "stale_samples": self.get_stale_sample_count(), + "skipped_sync": True, + } + self._set_pause_requested(True) + if self.loop is None: + stale_samples = len(self._current_stale_sample_ids()) + return { + "policy_version": policy_version, + "active_samples": 0, + "stale_samples": stale_samples, + } + future = asyncio.run_coroutine_threadsafe(self._prepare_for_weight_update_async(policy_version), self.loop) + return future.result() + + def after_weight_update(self, policy_version: int): + if not self._needs_weight_update_sync(): + with self.control_lock: + self.policy_version = policy_version + interval_metrics = 
self._snapshot_processed_metrics() + with self.control_lock: + self._reset_interval_metrics_locked() + return { + "policy_version": policy_version, + "stale_samples": self.get_stale_sample_count(), + "max_stale_samples": self.max_stale_samples, + "skipped_sync": True, + **interval_metrics, + } + if self.loop is None: + with self.control_lock: + self.policy_version = policy_version + self.state.aborted = False + self._evict_completed_records_outside_window(policy_version) + self._set_stale_sample_ids(self._current_stale_sample_ids()) + self._set_pause_requested(False) + interval_metrics = self._snapshot_processed_metrics() + with self.control_lock: + self._reset_interval_metrics_locked() + stale_samples = self.get_stale_sample_count() + return { + "policy_version": policy_version, + "stale_samples": stale_samples, + "max_stale_samples": self.max_stale_samples, + **interval_metrics, + } + future = asyncio.run_coroutine_threadsafe(self._finish_weight_update_async(policy_version), self.loop) + return future.result() + + def get_completed_samples( + self, + limit: int | None = None, + *, + current_policy_version: int | None = None, + ) -> list[tuple[int | None, list[Sample]]]: + if limit is not None and limit <= 0: + return [] + with self.completed_lock: + self._evict_records_outside_window_locked(current_policy_version) + if limit is None: + selected_records = list(self.completed_records) + self.completed_records.clear() + else: + selected_records = self.completed_records[:limit] + del self.completed_records[:limit] + return [(record.sample_id, record.group) for record in selected_records] + + def get_completed_groups(self) -> list[tuple[int | None, list[Sample]]]: + return self.get_completed_samples() def get_queue_size(self) -> int: - """Get current output queue size""" - return self.output_queue.qsize() + with self.completed_lock: + return len(self.completed_records) -async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[list[Sample]]: - """ - 
Simplified asynchronous rollout generation - using global continuous worker - """ +def before_weight_update(args, data_buffer, policy_version: int): + worker = get_existing_worker() + if worker is None: + return { + "policy_version": policy_version, + "active_samples": 0, + "stale_samples": 0, + } + return worker.before_weight_update(policy_version) + + +def after_weight_update(args, data_buffer, policy_version: int): + worker = get_existing_worker() + if worker is None: + max_stale_samples = _derive_max_stale_samples(args) + return { + "policy_version": policy_version, + "stale_samples": 0, + "max_stale_samples": max_stale_samples, + } + return worker.after_weight_update(policy_version) + + +async def generate_rollout_async(args, rollout_id: int, data_buffer) -> RolloutFnTrainOutput: assert args.rollout_global_dataset - # Get global worker, which will run continuously worker = get_global_worker(args, data_buffer) - - # Simplified: directly use rollout_batch_size as target - target_data_size = args.rollout_batch_size + target_sample_count = args.rollout_batch_size data = [] - completed_groups = {} + completed_samples = {} + drained_group_count = 0 do_print = True - print(f"Starting async rollout generation for {target_data_size} groups") - print(f"Global worker queue size: {worker.get_queue_size()}") + print(f"Starting async rollout generation for {target_sample_count} samples") + print("Global worker queue size: " f"{worker.get_queue_size()}, stale_samples={worker.get_stale_sample_count()}") - # Main loop: collect results from global worker's output queue start_time = time.time() last_progress_time = start_time - no_progress_timeout = 30.0 # Warn if no progress for 30 seconds + no_progress_timeout = 30.0 - while len(data) < target_data_size: - # Collect completed results - completed = worker.get_completed_groups() + while len(data) < target_sample_count: + pending_capacity = max(0, target_sample_count - len(data) - len(completed_samples)) + completed = 
worker.get_completed_samples(limit=pending_capacity, current_policy_version=worker.policy_version) + drained_group_count += len(completed) made_progress = False - for group_id, group in completed: - completed_groups[group_id] = group + for sample_id, group in completed: + completed_samples[sample_id] = group made_progress = True if made_progress: last_progress_time = time.time() - # Process completed groups in order (try to maintain order, but not strict requirement) processed_any = False - - # Process all available completed groups - available_ids = list(completed_groups.keys()) - for group_id in available_ids: - if len(data) >= target_data_size: + for sample_id in list(completed_samples.keys()): + if len(data) >= target_sample_count: break - group = completed_groups.pop(group_id) + group = completed_samples.pop(sample_id) - # If any sample in the group was aborted, return the whole group to the data buffer - # and do not forward it to the training engine. try: - any_aborted = any([sample.status == Sample.Status.ABORTED for sample in group]) + any_aborted = any(sample.status == Sample.Status.ABORTED for sample in group) except Exception: any_aborted = False if any_aborted: try: - # add back to buffer so it can be retried or handled by buffer policy data_buffer.add_samples([group]) - print(f"Returned aborted group {group_id} to data buffer", flush=True) - except Exception as e: - print(f"Failed to return aborted group {group_id} to buffer: {e}", flush=True) - # don't count as processed for training + worker._mark_sample_recycled(sample_id) + print(f"Returned aborted sample {sample_id} to data buffer", flush=True) + except Exception as exc: + print(f"Failed to return aborted sample {sample_id} to buffer: {exc}", flush=True) continue if do_print: @@ -218,26 +984,29 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis ) do_print = False - # Simplified: directly add samples, no filters used data.append(group) + 
worker._mark_sample_consumed(sample_id) processed_any = True - # Check progress current_time = time.time() if current_time - last_progress_time > no_progress_timeout: print( f"Warning: No progress for {no_progress_timeout}s. " f"Queue size: {worker.get_queue_size()}, " - f"Collected: {len(data)}/{target_data_size}" + f"Stale samples: {worker.get_stale_sample_count()}, " + f"Collected: {len(data)}/{target_sample_count}" ) last_progress_time = current_time - # If no results were processed, brief sleep to avoid busy waiting if not processed_any: await asyncio.sleep(0.01) duration = time.time() - start_time - print(f"Rollout completed in {duration:.2f}s! Global worker queue size: {worker.get_queue_size()}") + print( + f"Rollout completed in {duration:.2f}s! " + f"Global worker queue size: {worker.get_queue_size()}, " + f"stale_samples={worker.get_stale_sample_count()}" + ) if data: print( @@ -247,17 +1016,46 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis ) data = sorted(data, key=lambda group: group[0].index) - return data + _log_processed_group_debug( + args, + data, + current_policy_version=worker.policy_version, + rollout_id=rollout_id, + drained_group_count=drained_group_count, + leftover_group_count=len(completed_samples), + ) + worker.record_processed_samples(data) + return RolloutFnTrainOutput(samples=data, metrics={}) + + +def flush_metrics(args, data_buffer): + worker = get_existing_worker() + if worker is None: + return None + if not worker._has_interval_metrics(): + return None + return worker._snapshot_processed_metrics() + + +def shutdown_worker(args, data_buffer): + stop_global_worker() def generate_rollout_fully_async(args, rollout_id, data_buffer, evaluation=False): if evaluation: - raise ValueError("Evaluation mode not supported in simple async rollout") - - completed_samples = run(generate_rollout_async(args, rollout_id, data_buffer)) - return completed_samples + worker = get_existing_worker() + if worker is not 
None and worker.loop is not None: + worker._set_pause_requested(True) + try: + future = asyncio.run_coroutine_threadsafe(worker.run_eval(args, rollout_id), worker.loop) + output = future.result() + finally: + worker._set_pause_requested(False) + else: + output, _ = run(eval_rollout(args, rollout_id)) + return output + return run(generate_rollout_async(args, rollout_id, data_buffer)) -# Register exit cleanup function atexit.register(stop_global_worker) diff --git a/examples/fully_async/run-qwen3-4b-fully_async.sh b/examples/fully_async/run-qwen3-4b-fully_async.sh deleted file mode 100644 index 778f58b2ad..0000000000 --- a/examples/fully_async/run-qwen3-4b-fully_async.sh +++ /dev/null @@ -1,139 +0,0 @@ -#!/bin/bash - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -ex - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-4B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load /root/Qwen3-4B_torch_dist - --load /root/Qwen3-4B_slime/ - --save /root/Qwen3-4B_slime/ - --save-interval 20 -) - -PROMPT_SET=/path/to/dapo-math-17k.jsonl - -ROLLOUT_ARGS=( - --rollout-function-path fully_async_rollout.generate_rollout_fully_async - --prompt-data ${PROMPT_SET} - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - - --rm-type dapo - --reward-key score - - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - - --global-batch-size 256 - 
--balance-data -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 - - --use-tis -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats - -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train_async.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 4 \ - --rollout-num-gpus 4 \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} diff --git a/examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh b/examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh new file mode 100755 index 
0000000000..f890e4db1c --- /dev/null +++ b/examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh @@ -0,0 +1,218 @@ +#!/bin/bash +# +# Benchmark fully async rollout on 8x H100 with Qwen3.5-4B. +# +# Usage: +# MODE=one_step_off bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh +# MODE=fully_async bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh +# MODE=window_partial bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh +# MODE=staleness_partial bash examples/fully_async/run-qwen3.5-4b-off-policy-benchmark.sh +# +# Modes: +# one_step_off - default rollout, one-step off-policy async baseline, does not support partial rollout +# fully_async - fully async rollout, no staleness control, does not support partial rollout +# window_partial - fully async + version-window eviction + partial rollout + mask off-policy +# staleness_partial - fully async + staleness backpressure + partial rollout + mask off-policy + +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/../../scripts/models/qwen3.5-4B.sh" + +# --- Paths (adjust to your environment) --- NOTE(review): LOAD_PATH/SAVE_PATH expand ${MODE} before its default is assigned below; set MODE explicitly or move the default assignment above these lines. +HF_CHECKPOINT=${HF_CHECKPOINT:-"/root/Qwen3.5-4B"} +REF_LOAD=${REF_LOAD:-"/root/Qwen3.5-4B_torch_dist"} +LOAD_PATH=${LOAD_PATH:-"/root/Qwen3.5-4B_slime_async_${MODE}/"} +SAVE_PATH=${SAVE_PATH:-"/root/Qwen3.5-4B_slime_async_${MODE}/"} +PROMPT_SET=${PROMPT_SET:-"/root/dapo-math-17k/dapo-math-17k.jsonl"} +# EVAL_DATASET=${EVAL_DATASET:-"/root/aime-2024/aime-2024.jsonl"} +MODE=${MODE:-"one_step_off"} +FULLY_ASYNC_VERSION_WINDOW=${FULLY_ASYNC_VERSION_WINDOW:-1}
+FULLY_ASYNC_MAX_COMPLETED_SAMPLES=${FULLY_ASYNC_MAX_COMPLETED_SAMPLES:-128} +FULLY_ASYNC_EVICTION_POLICY=${FULLY_ASYNC_EVICTION_POLICY:-"drop_oldest_version"} +echo "=== Running fully async benchmark: mode=${MODE} ===" + +CKPT_ARGS=( + --hf-checkpoint ${HF_CHECKPOINT} + --ref-load ${REF_LOAD} + --load ${LOAD_PATH} + --save ${SAVE_PATH} + --save-interval 20 + --no-save-optim + --no-load-optim +) + +ROLLOUT_ARGS=( + --prompt-data ${PROMPT_SET} + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 40 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 16384 + --rollout-temperature 1 + + --global-batch-size 256 + --balance-data + + --update-weights-interval 2 +) + + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 2 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + + --use-tis + --custom-config-path examples/train_infer_mismatch_helper/mis.yaml + --custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_with_cp +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-mem-fraction-static 0.9 + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) + --sglang-max-running-requests 256 +) + +WANDB_ARGS=( +# --use-wandb +# --wandb-project slime-async-release +# --wandb-group qwen3.5-4B-async-${MODE} 
+# --wandb-key ${WANDB_KEY} +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +# --- Fully async rollout args (shared by all fully_async-based modes) --- +FULLY_ASYNC_ROLLOUT_ARGS=( + --rollout-function-path fully_async_rollout.generate_rollout_fully_async + --fully-async-debug-version-tracking +) + +# --- Mode-specific flags --- +MODE_ARGS=() +case "${MODE}" in + one_step_off) + ;; + fully_async) + MODE_ARGS+=("${FULLY_ASYNC_ROLLOUT_ARGS[@]}") + ;; + window_partial) + MODE_ARGS+=( + "${FULLY_ASYNC_ROLLOUT_ARGS[@]}" + --fully-async-buffer-policy window_evict + --fully-async-version-window "${FULLY_ASYNC_VERSION_WINDOW}" + --fully-async-max-completed-samples "${FULLY_ASYNC_MAX_COMPLETED_SAMPLES}" + --fully-async-eviction-policy "${FULLY_ASYNC_EVICTION_POLICY}" + --partial-rollout + --mask-offpolicy-in-partial-rollout + ) + ;; + staleness_partial) + MODE_ARGS+=( + "${FULLY_ASYNC_ROLLOUT_ARGS[@]}" + --fully-async-buffer-policy legacy_backpressure + --staleness-threshold 0.5 + --partial-rollout + --mask-offpolicy-in-partial-rollout + ) + ;; + *) + echo "Unknown MODE: ${MODE}. 
Use one of: one_step_off, fully_async, window_partial, staleness_partial" + exit 1 + ;; +esac + +# launch ray +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats \ + --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +# 4 GPUs for training, 4 GPUs for rollout (sglang) +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train_async.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --rollout-num-gpus 4 \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${MODE_ARGS[@]} \ No newline at end of file diff --git a/scripts/models/qwen3.5-4B.sh b/scripts/models/qwen3.5-4B.sh new file mode 100644 index 0000000000..18bea47c9c --- /dev/null +++ b/scripts/models/qwen3.5-4B.sh @@ -0,0 +1,27 @@ +MODEL_ARGS=( + --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" + + --disable-bias-linear + --qk-layernorm + --group-query-attention + --num-attention-heads 16 + --num-query-groups 4 + --kv-channels 256 + --num-layers 32 + --hidden-size 2560 + --ffn-hidden-size 9216 + --use-gated-attention + + --normalization RMSNorm + --apply-layernorm-1p + --position-embedding-type rope + --norm-epsilon 1e-6 + --rotary-percent 0.25 + --swiglu + --vocab-size 248320 + + --rotary-base 10000000 + + # qwen3.5 specific + --attention-output-gate +) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index d7a208753b..d4a654af1a 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -1,4 +1,5 @@ import dataclasses +import importlib import itertools import logging import multiprocessing @@ -362,6 +363,7 @@ def 
__init__(self, args, pg): self.generate_rollout = load_function(self.args.rollout_function_path) self.eval_generate_rollout = load_function(self.args.eval_function_path) + self.generate_rollout_module = importlib.import_module(self.generate_rollout.__module__) self.custom_reward_post_process_func = None if self.args.custom_reward_post_process_path is not None: self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) @@ -382,6 +384,8 @@ def __init__(self, args, pg): init_tracking(args, primary=False) self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() self.rollout_id = -1 + self._fully_async_log_step = 0 + self.update_runtime_state(current_rollout_id=-1, current_policy_version=0) self._health_monitors = [] if not self.args.debug_train_only and self.args.use_fault_tolerance: @@ -431,10 +435,53 @@ def _try_ci_fault_injection(self): logger.warning(f"CI Fault Injection failed: {e}") def dispose(self): + self._log_fully_async_metrics(self._call_generate_rollout_hook("flush_metrics")) + self._call_generate_rollout_hook("shutdown_worker") for monitor in self._health_monitors: monitor.stop() logging_utils.finish_tracking(self.args) + def update_runtime_state(self, **metadata): + for key, value in metadata.items(): + setattr(self.args, key, value) + if hasattr(self.data_source, "update_metadata"): + self.data_source.update_metadata(metadata) + + def _call_generate_rollout_hook(self, hook_name: str, **kwargs): + hook = getattr(self.generate_rollout_module, hook_name, None) + if hook is None: + return None + return hook(self.args, self.data_source, **kwargs) + + def _log_fully_async_metrics(self, hook_result): + if not hook_result or self.rollout_id < 0: + return + fully_async_metrics = {key: value for key, value in hook_result.items() if key.startswith("fully_async/")} + if not fully_async_metrics: + return + fully_async_step = getattr(self, "_fully_async_log_step", 0) + fully_async_metrics["fully_async/step"] = 
fully_async_step + logging_utils.log(self.args, fully_async_metrics, step_key="fully_async/step") + self._fully_async_log_step = fully_async_step + 1 + + def before_weight_update(self, policy_version: int): + """Called *before* weights are synced. + + ``policy_version`` is the **current** (soon-to-be-old) version. + """ + self.update_runtime_state(current_policy_version=policy_version) + return self._call_generate_rollout_hook("before_weight_update", policy_version=policy_version) + + def after_weight_update(self, policy_version: int): + """Called *after* weights are synced. + + ``policy_version`` is the **new** version just applied. + """ + self.update_runtime_state(current_policy_version=policy_version) + result = self._call_generate_rollout_hook("after_weight_update", policy_version=policy_version) + self._log_fully_async_metrics(result) + return result + @property def server(self) -> RolloutServer | None: """Default server (first model). For backward compatibility.""" @@ -479,6 +526,7 @@ def get_num_rollout_per_epoch(self): def generate(self, rollout_id): start_time = time.time() self.rollout_id = rollout_id + self.update_runtime_state(current_rollout_id=rollout_id) self.health_monitoring_resume() if self.args.ci_test and self.args.use_fault_tolerance and rollout_id >= 2: self._try_ci_fault_injection() diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index a634d1f003..29db9c6a45 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -434,6 +434,66 @@ def add_rollout_arguments(parser): default=1, help="Interval for updating the weights", ) + parser.add_argument( + "--staleness-threshold", + type=float, + default=None, + help=( + "Maximum stale backlog ratio for fully async rollout. " + "When set, fully async rollout workers will pause scheduling new samples once the " + "number of stale samples reaches approximately " + "`rollout_batch_size * update_weights_interval * (1 + staleness_threshold)`." 
+ ), + ) + parser.add_argument( + "--fully-async-debug-version-tracking", + action="store_true", + default=False, + help=( + "Print per-batch version summaries for fully async rollout consumption. " + "Useful for debugging why stale backlog exists while stale processed counters stay at zero." + ), + ) + parser.add_argument( + "--fully-async-buffer-policy", + type=str, + choices=["legacy_backpressure", "window_evict"], + default="legacy_backpressure", + help=( + "Completed-sample buffering policy for fully async rollout. " + "`legacy_backpressure` keeps the current stale-budget pause behavior, while " + "`window_evict` keeps rollout scheduling active and evicts completed samples that fall " + "outside the configured version window." + ), + ) + parser.add_argument( + "--fully-async-version-window", + type=int, + default=1, + help=( + "Maximum policy-version distance allowed in the completed-sample window when " + "`--fully-async-buffer-policy=window_evict`." + ), + ) + parser.add_argument( + "--fully-async-max-completed-samples", + type=int, + default=None, + help=( + "Hard cap on completed fully async samples kept in memory. Defaults to the legacy queue sizing " + "heuristic when unset." + ), + ) + parser.add_argument( + "--fully-async-eviction-policy", + type=str, + choices=["drop_oldest_version", "drop_oldest_fifo"], + default="drop_oldest_version", + help=( + "Overflow eviction policy for fully async completed samples when " + "`--fully-async-buffer-policy=window_evict`." 
+ ), + ) parser.add_argument( "--keep-old-actor", action="store_true", diff --git a/slime/utils/wandb_utils.py b/slime/utils/wandb_utils.py index 2ec859de5f..7ca564d83b 100644 --- a/slime/utils/wandb_utils.py +++ b/slime/utils/wandb_utils.py @@ -199,6 +199,11 @@ def _init_wandb_common(): wandb.define_metric("rollout/*", step_metric="rollout/step") wandb.define_metric("multi_turn/*", step_metric="rollout/step") wandb.define_metric("passrate/*", step_metric="rollout/step") + wandb.define_metric("fully_async/step") + wandb.define_metric("fully_async/*", step_metric="fully_async/step") + wandb.define_metric("fully_async/count/*", step_metric="fully_async/step") + wandb.define_metric("fully_async/partial/*", step_metric="fully_async/step") + wandb.define_metric("fully_async/window/*", step_metric="fully_async/step") wandb.define_metric("eval/step") wandb.define_metric("eval/*", step_metric="eval/step") wandb.define_metric("perf/*", step_metric="rollout/step") diff --git a/tests/test_fully_async_rollout.py b/tests/test_fully_async_rollout.py new file mode 100644 index 0000000000..316e88a832 --- /dev/null +++ b/tests/test_fully_async_rollout.py @@ -0,0 +1,799 @@ +from __future__ import annotations + +import asyncio +import importlib.util +import sys +import types +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from types import SimpleNamespace + + +def load_fully_async_module(): + sample_status = Enum("Status", ["ABORTED", "COMPLETED"]) + + @dataclass + class Sample: + group_index: int | None = None + index: int = 0 + prompt: str = "" + response: str = "" + reward: float | None = None + label: str = "" + metadata: dict = field(default_factory=dict) + weight_versions: list[str] = field(default_factory=list) + status: object = sample_status.COMPLETED + + Sample.Status = sample_status + + class FakeGenerateState: + def __init__(self, args): + self.args = args + self.sampling_params = {} + self.reset() + + def reset(self): + 
self.pendings = set() + self.aborted = False + + async def fake_abort(args, rollout_id): + return [] + + async def fake_eval_rollout(args, rollout_id): + return {"dummy": {"rewards": [1.0]}}, [] + + async def fake_generate_and_rm_group(args, group, sampling_params, evaluation=False): + return group + + def fake_run(coro): + raise RuntimeError("run() should not be called in this unit test") + + @dataclass + class RolloutFnTrainOutput: + samples: list + metrics: dict | None = None + + modules = { + "slime": types.ModuleType("slime"), + "slime.rollout": types.ModuleType("slime.rollout"), + "slime.rollout.base_types": types.ModuleType("slime.rollout.base_types"), + "slime.rollout.sglang_rollout": types.ModuleType("slime.rollout.sglang_rollout"), + "slime.utils": types.ModuleType("slime.utils"), + "slime.utils.async_utils": types.ModuleType("slime.utils.async_utils"), + "slime.utils.types": types.ModuleType("slime.utils.types"), + } + modules["slime.rollout.base_types"].RolloutFnTrainOutput = RolloutFnTrainOutput + modules["slime.rollout.sglang_rollout"].GenerateState = FakeGenerateState + modules["slime.rollout.sglang_rollout"].abort = fake_abort + modules["slime.rollout.sglang_rollout"].eval_rollout = fake_eval_rollout + modules["slime.rollout.sglang_rollout"].generate_and_rm_group = fake_generate_and_rm_group + modules["slime.utils.async_utils"].run = fake_run + modules["slime.utils.types"].Sample = Sample + modules["slime"].rollout = modules["slime.rollout"] + modules["slime"].utils = modules["slime.utils"] + + saved_modules = {name: sys.modules.get(name) for name in modules} + saved_test_module = sys.modules.get("fully_async_rollout_under_test") + + try: + sys.modules.update(modules) + spec = importlib.util.spec_from_file_location( + "fully_async_rollout_under_test", + Path(__file__).resolve().parents[1] / "examples" / "fully_async" / "fully_async_rollout.py", + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + 
sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module, Sample, sample_status + finally: + for name, saved in saved_modules.items(): + if saved is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = saved + if saved_test_module is None: + sys.modules.pop("fully_async_rollout_under_test", None) + else: + sys.modules["fully_async_rollout_under_test"] = saved_test_module + + +def make_args(**overrides): + defaults = dict( + sglang_server_concurrency=2, + rollout_batch_size=4, + rollout_global_dataset=True, + partial_rollout=True, + fully_async_buffer_policy="legacy_backpressure", + fully_async_version_window=1, + fully_async_max_completed_samples=None, + fully_async_eviction_policy="drop_oldest_version", + staleness_threshold=None, + update_weights_interval=2, + current_policy_version=0, + current_rollout_id=0, + ) + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +def test_derive_max_stale_samples_returns_none_without_threshold(): + module, _, _ = load_fully_async_module() + assert module._derive_max_stale_samples(make_args()) is None + + +def test_derive_max_stale_samples_uses_threshold_formula(): + module, _, _ = load_fully_async_module() + args = make_args(staleness_threshold=0.5, rollout_batch_size=4, update_weights_interval=2) + assert module._derive_max_stale_samples(args) == 12 + + +def test_worker_before_after_weight_update_without_running_loop_are_safe(): + module, _, _ = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.5), data_buffer=None) + + before = worker.before_weight_update(policy_version=3) + after = worker.after_weight_update(policy_version=4) + + assert before["policy_version"] == 3 + assert before["active_samples"] == 0 + assert after["policy_version"] == 4 + assert after["stale_samples"] == 0 + assert after["max_stale_samples"] == 12 + + +def test_before_weight_update_without_running_loop_uses_current_outstanding_samples(): + module, 
Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0), data_buffer=None) + + worker.output_queue.put( + ( + 10, + [Sample(group_index=10, index=10, metadata={"fully_async_sample_id": 10}, status=sample_status.COMPLETED)], + ) + ) + worker._add_recycled_sample_ids([11]) + + before = worker.before_weight_update(policy_version=3) + + assert worker.get_stale_sample_count() == 1 + assert before["stale_samples"] == 2 + + +def test_completed_samples_do_not_refund_staleness_budget(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, update_weights_interval=1), data_buffer=None) + + worker._set_stale_sample_ids([10, 11]) + worker.output_queue.put( + ( + 10, + [Sample(group_index=10, index=10, metadata={"fully_async_sample_id": 10}, status=sample_status.COMPLETED)], + ) + ) + worker.output_queue.put( + ( + 11, + [Sample(group_index=11, index=11, metadata={"fully_async_sample_id": 11}, status=sample_status.COMPLETED)], + ) + ) + + completed = worker.get_completed_samples() + + assert [sample_id for sample_id, _ in completed] == [10, 11] + assert worker.get_stale_sample_count() == 2 + + +def test_completed_samples_limit_keeps_full_staleness_budget(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, update_weights_interval=1), data_buffer=None) + + worker._set_stale_sample_ids([10, 11, 12]) + for sample_id in (10, 11, 12): + worker.output_queue.put( + ( + sample_id, + [ + Sample( + group_index=sample_id, + index=sample_id, + metadata={"fully_async_sample_id": sample_id}, + status=sample_status.COMPLETED, + ) + ], + ) + ) + + completed = worker.get_completed_samples(limit=2) + + assert [sample_id for sample_id, _ in completed] == [10, 11] + assert worker.get_queue_size() == 1 + assert worker.get_stale_sample_count() == 3 + + +def 
test_pause_for_staleness_uses_budgeted_stale_samples(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker( + make_args(staleness_threshold=0.0, update_weights_interval=1, rollout_batch_size=2), + data_buffer=None, + ) + + worker.output_queue.put( + ( + 10, + [Sample(group_index=10, index=10, metadata={"fully_async_sample_id": 10}, status=sample_status.COMPLETED)], + ) + ) + worker.output_queue.put( + ( + 11, + [Sample(group_index=11, index=11, metadata={"fully_async_sample_id": 11}, status=sample_status.COMPLETED)], + ) + ) + + assert worker.get_stale_sample_count() == 0 + assert worker._should_pause_for_staleness() is False + + worker._set_stale_sample_ids([10, 11]) + + assert worker._should_pause_for_staleness() is True + + +def test_window_evict_policy_disables_staleness_pause(): + module, _, _ = load_fully_async_module() + worker = module.AsyncRolloutWorker( + make_args( + fully_async_buffer_policy="window_evict", + staleness_threshold=0.0, + update_weights_interval=1, + rollout_batch_size=2, + ), + data_buffer=None, + ) + + worker._set_stale_sample_ids([10, 11]) + + assert worker._should_pause_for_staleness() is False + + +def test_prepare_for_weight_update_recomputes_outstanding_after_waiting_for_active_tasks(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, partial_rollout=False), data_buffer=None) + + async def run_test(): + worker.task_lock = asyncio.Lock() + group = [ + Sample(group_index=21, index=21, metadata={"fully_async_sample_id": 21}, status=sample_status.COMPLETED) + ] + task = asyncio.create_task(asyncio.sleep(0, result=group)) + worker.state.pendings.add(task) + worker.task_sample_ids[task] = 21 + + before = await worker._prepare_for_weight_update_async(policy_version=3) + + assert worker.get_stale_sample_count() == 0 + assert before["active_samples"] == 1 + assert before["stale_samples"] == 1 + + 
asyncio.run(run_test()) + + +def test_extract_sample_id_prefers_stable_sample_identity(): + module, Sample, _ = load_fully_async_module() + group = [ + Sample(group_index=0, metadata={"fully_async_sample_id": 0}), + Sample(group_index=0, metadata={"fully_async_sample_id": 0}), + ] + assert module._extract_sample_id(group) == 0 + + +def test_record_processed_samples_tracks_stale_and_partial_metrics(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, current_policy_version=3), data_buffer=None) + + stale_sample = [ + Sample( + group_index=100, + metadata={"fully_async_sample_id": 100, "fully_async_schedule_versions": [1]}, + weight_versions=["1"], + status=sample_status.COMPLETED, + ), + Sample( + group_index=100, + metadata={"fully_async_sample_id": 100, "fully_async_schedule_versions": [1]}, + weight_versions=["1"], + status=sample_status.COMPLETED, + ), + ] + partial_sample = [ + Sample( + group_index=101, + metadata={"fully_async_sample_id": 101, "fully_async_schedule_versions": [2, 3]}, + weight_versions=["2"], + status=sample_status.COMPLETED, + ), + Sample( + group_index=101, + metadata={"fully_async_sample_id": 101, "fully_async_schedule_versions": [2, 3]}, + weight_versions=["3"], + status=sample_status.COMPLETED, + ), + ] + + worker.record_processed_samples([stale_sample, partial_sample]) + metrics = worker._snapshot_processed_metrics() + + assert metrics["fully_async/count/stale_samples_processed"] == 1 + assert metrics["fully_async/count/stale_trajectory_processed"] == 2 + assert metrics["fully_async/partial/total_partial_num"] == 1 + assert metrics["fully_async/partial/partial_ratio"] == 0.5 + assert metrics["fully_async/partial/max_partial_span"] == 1 + + +def test_summarize_processed_group_reports_sources_and_staleness(): + module, Sample, sample_status = load_fully_async_module() + + group = [ + Sample( + group_index=123, + metadata={"fully_async_sample_id": 123, 
"fully_async_schedule_versions": [1, 2], "policy_version": 2}, + weight_versions=["1"], + status=sample_status.COMPLETED, + ), + Sample( + group_index=123, + metadata={"fully_async_sample_id": 123, "fully_async_schedule_versions": [2], "policy_version": 2}, + weight_versions=[], + status=sample_status.COMPLETED, + ), + ] + + summary = module._summarize_processed_group(group, current_policy_version=3) + + assert summary["sample_id"] == 123 + assert summary["group_min_version"] == 1 + assert summary["group_max_version"] == 2 + assert summary["stale_sample"] is True + assert summary["stale_trajectory_count"] == 2 + assert summary["partial_span"] == 1 + assert summary["source_counts"] == {"weight_versions": 1, "fully_async_schedule_versions": 1} + assert summary["staleness_source_counts"] == {"fully_async_schedule_versions": 2} + + +def test_record_processed_samples_aligns_weight_version_fallback_to_policy_version(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, current_policy_version=1), data_buffer=None) + + worker.record_processed_samples( + [ + [ + Sample( + group_index=150, + metadata={"fully_async_sample_id": 150}, + weight_versions=["2"], + status=sample_status.COMPLETED, + ) + ] + ] + ) + metrics = worker._snapshot_processed_metrics() + + assert metrics["fully_async/count/stale_samples_processed"] == 0 + assert metrics["fully_async/count/stale_trajectory_processed"] == 0 + + +def test_after_weight_update_resets_partial_window_metrics(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, current_policy_version=2), data_buffer=None) + + worker.record_processed_samples( + [ + [ + Sample( + group_index=200, + metadata={"fully_async_sample_id": 200, "fully_async_schedule_versions": [1, 2]}, + weight_versions=["1", "2"], + status=sample_status.COMPLETED, + ) + ] + ] + ) + + after = 
worker.after_weight_update(policy_version=3) + + assert after["fully_async/partial/total_partial_num"] == 1 + assert after["fully_async/partial/partial_ratio"] == 1.0 + assert after["fully_async/partial/max_partial_span"] == 1 + + worker.record_processed_samples( + [ + [ + Sample( + group_index=201, + metadata={"fully_async_sample_id": 201, "fully_async_schedule_versions": [3]}, + weight_versions=["3"], + status=sample_status.COMPLETED, + ) + ] + ] + ) + fresh_metrics = worker._snapshot_processed_metrics() + assert fresh_metrics["fully_async/partial/total_partial_num"] == 0 + assert fresh_metrics["fully_async/partial/partial_ratio"] == 0.0 + assert fresh_metrics["fully_async/partial/max_partial_span"] == 0 + + +def test_flush_metrics_returns_metrics_only_when_window_has_data(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, current_policy_version=2), data_buffer=None) + + worker.record_processed_samples( + [ + [ + Sample( + group_index=250, + metadata={"fully_async_sample_id": 250, "fully_async_schedule_versions": [1, 2]}, + weight_versions=["1", "2"], + status=sample_status.COMPLETED, + ) + ] + ] + ) + worker.worker_thread = SimpleNamespace(is_alive=lambda: True) + + saved_global_worker = module._global_worker + module._global_worker = worker + try: + metrics = module.flush_metrics(args=None, data_buffer=None) + assert metrics["fully_async/count/stale_samples_processed"] == 0 + assert metrics["fully_async/count/stale_trajectory_processed"] == 0 + assert metrics["fully_async/partial/total_partial_num"] == 1 + assert metrics["fully_async/partial/partial_ratio"] == 1.0 + assert metrics["fully_async/partial/max_partial_span"] == 1 + + worker.after_weight_update(policy_version=3) + assert module.flush_metrics(args=None, data_buffer=None) is None + finally: + module._global_worker = saved_global_worker + + +def test_shutdown_worker_stops_and_clears_global_worker(): + module, _, _ = 
load_fully_async_module() + stop_calls = [] + dummy_worker = SimpleNamespace(stop=lambda: stop_calls.append("stopped")) + + saved_global_worker = module._global_worker + module._global_worker = dummy_worker + try: + module.shutdown_worker(args=None, data_buffer=None) + assert stop_calls == ["stopped"] + assert module._global_worker is None + finally: + module._global_worker = saved_global_worker + + +def test_after_weight_update_counts_recycled_partial_samples_as_stale(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, current_policy_version=2), data_buffer=None) + + worker.output_queue.put( + (300, [Sample(group_index=300, metadata={"fully_async_sample_id": 300}, status=sample_status.COMPLETED)]) + ) + worker._add_recycled_sample_ids([301, 302]) + + after = worker.after_weight_update(policy_version=3) + + assert after["stale_samples"] == 3 + assert worker.get_stale_sample_count() == 3 + + +def test_after_weight_update_evicts_completed_samples_outside_version_window(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker( + make_args( + fully_async_buffer_policy="window_evict", + fully_async_version_window=1, + current_policy_version=3, + ), + data_buffer=None, + ) + + for sample_id, version in ((320, 1), (321, 2), (322, 3)): + worker.output_queue.put( + ( + sample_id, + [ + Sample( + group_index=sample_id, + metadata={"fully_async_sample_id": sample_id, "fully_async_schedule_versions": [version]}, + weight_versions=[str(version)], + status=sample_status.COMPLETED, + ) + ], + ) + ) + worker._set_stale_sample_ids([320, 321, 322]) + + after = worker.after_weight_update(policy_version=4) + + assert after["fully_async/window/evicted_samples"] == 2 + assert after["fully_async/window/evicted_by_version"] == 2 + assert worker.get_queue_size() == 1 + assert worker.get_stale_sample_count() == 1 + remaining = 
worker.get_completed_samples(current_policy_version=4) + assert [sample_id for sample_id, _ in remaining] == [322] + + +def test_continuous_worker_loop_adds_newly_scheduled_samples_to_staleness_budget(): + module, Sample, sample_status = load_fully_async_module() + + class FakeDataBuffer: + def __init__(self): + self.returned = False + + def get_samples(self, num_samples): + assert num_samples == 1 + if self.returned: + return [] + self.returned = True + return [ + [ + Sample( + group_index=350, + index=350, + metadata={"fully_async_sample_id": 350}, + status=sample_status.COMPLETED, + ) + ] + ] + + worker = module.AsyncRolloutWorker( + make_args(staleness_threshold=0.0, update_weights_interval=1, rollout_batch_size=1), + data_buffer=FakeDataBuffer(), + ) + + async def run_test(): + loop_task = asyncio.create_task(worker.continuous_worker_loop()) + await asyncio.sleep(0.05) + worker.running = False + await loop_task + + asyncio.run(run_test()) + + assert worker.get_stale_sample_count() == 1 + + +def test_collect_task_result_clears_recycled_tracking_for_resumed_samples(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0), data_buffer=None) + + async def complete_recycled_task(): + group = [ + Sample(group_index=400, metadata={"fully_async_sample_id": 400}, status=sample_status.COMPLETED), + ] + task = asyncio.create_task(asyncio.sleep(0, result=group)) + worker.task_sample_ids[task] = 400 + worker._add_recycled_sample_ids([400]) + await task + worker._collect_task_result(task) + + asyncio.run(complete_recycled_task()) + + assert 400 not in worker._snapshot_recycled_sample_ids() + completed = worker.get_completed_samples() + assert [sample_id for sample_id, _ in completed] == [400] + + +def test_window_evict_policy_evicts_oldest_version_on_completed_store_overflow(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker( + make_args( + 
fully_async_buffer_policy="window_evict", + fully_async_version_window=10, + fully_async_max_completed_samples=2, + current_policy_version=2, + ), + data_buffer=None, + ) + + for sample_id, version in ((410, 0), (411, 1), (412, 2)): + worker.output_queue.put( + ( + sample_id, + [ + Sample( + group_index=sample_id, + index=sample_id, + metadata={"fully_async_sample_id": sample_id, "fully_async_schedule_versions": [version]}, + weight_versions=[str(version)], + status=sample_status.COMPLETED, + ) + ], + ) + ) + + remaining = worker.get_completed_samples(current_policy_version=2) + + assert [sample_id for sample_id, _ in remaining] == [411, 412] + + +def test_generate_rollout_async_only_drains_needed_completed_groups(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, rollout_batch_size=2), data_buffer=None) + worker.worker_thread = SimpleNamespace(is_alive=lambda: True) + + for sample_id in (500, 501, 502): + worker.output_queue.put( + ( + sample_id, + [ + Sample( + group_index=sample_id, + index=sample_id, + metadata={"fully_async_sample_id": sample_id}, + weight_versions=["1"], + status=sample_status.COMPLETED, + ) + ], + ) + ) + + saved_global_worker = module._global_worker + module._global_worker = worker + try: + output = asyncio.run( + module.generate_rollout_async(make_args(rollout_batch_size=2), rollout_id=0, data_buffer=None) + ) + finally: + module._global_worker = saved_global_worker + + assert [group[0].group_index for group in output.samples] == [500, 501] + assert worker.get_queue_size() == 1 + + +def test_generate_rollout_async_consumed_samples_refund_staleness_budget(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker(make_args(staleness_threshold=0.0, rollout_batch_size=2), data_buffer=None) + worker.worker_thread = SimpleNamespace(is_alive=lambda: True) + + for sample_id in (520, 521, 522): + worker.output_queue.put( + ( 
+ sample_id, + [ + Sample( + group_index=sample_id, + index=sample_id, + metadata={"fully_async_sample_id": sample_id}, + weight_versions=["1"], + status=sample_status.COMPLETED, + ) + ], + ) + ) + worker._set_stale_sample_ids([520, 521, 522]) + + saved_global_worker = module._global_worker + module._global_worker = worker + try: + output = asyncio.run( + module.generate_rollout_async(make_args(rollout_batch_size=2), rollout_id=0, data_buffer=None) + ) + finally: + module._global_worker = saved_global_worker + + assert [group[0].group_index for group in output.samples] == [520, 521] + assert worker.get_queue_size() == 1 + assert worker.get_stale_sample_count() == 1 + + +def test_generate_rollout_async_window_evict_skips_out_of_window_completed_groups(): + module, Sample, sample_status = load_fully_async_module() + worker = module.AsyncRolloutWorker( + make_args( + fully_async_buffer_policy="window_evict", + fully_async_version_window=1, + current_policy_version=3, + rollout_batch_size=2, + ), + data_buffer=None, + ) + worker.worker_thread = SimpleNamespace(is_alive=lambda: True) + + for sample_id, version in ((540, 1), (541, 2), (542, 3)): + worker.output_queue.put( + ( + sample_id, + [ + Sample( + group_index=sample_id, + index=sample_id, + metadata={"fully_async_sample_id": sample_id, "fully_async_schedule_versions": [version]}, + weight_versions=[str(version)], + status=sample_status.COMPLETED, + ) + ], + ) + ) + + saved_global_worker = module._global_worker + module._global_worker = worker + try: + output = asyncio.run( + module.generate_rollout_async( + make_args( + fully_async_buffer_policy="window_evict", + fully_async_version_window=1, + current_policy_version=3, + rollout_batch_size=2, + ), + rollout_id=0, + data_buffer=None, + ) + ) + finally: + module._global_worker = saved_global_worker + + assert [group[0].group_index for group in output.samples] == [541, 542] + assert worker.get_queue_size() == 0 + + +def 
test_generate_rollout_async_recycled_samples_stay_in_staleness_budget(): + module, Sample, sample_status = load_fully_async_module() + + class FakeDataBuffer: + def __init__(self): + self.samples = [] + + def add_samples(self, samples): + self.samples.extend(samples) + + data_buffer = FakeDataBuffer() + worker = module.AsyncRolloutWorker( + make_args(staleness_threshold=0.0, rollout_batch_size=1), data_buffer=data_buffer + ) + worker.worker_thread = SimpleNamespace(is_alive=lambda: True) + + worker.output_queue.put( + ( + 530, + [ + Sample( + group_index=530, + index=530, + metadata={"fully_async_sample_id": 530}, + weight_versions=["1"], + status=sample_status.ABORTED, + ) + ], + ) + ) + worker.output_queue.put( + ( + 531, + [ + Sample( + group_index=531, + index=531, + metadata={"fully_async_sample_id": 531}, + weight_versions=["1"], + status=sample_status.COMPLETED, + ) + ], + ) + ) + worker._set_stale_sample_ids([530, 531]) + + saved_global_worker = module._global_worker + module._global_worker = worker + try: + output = asyncio.run( + module.generate_rollout_async(make_args(rollout_batch_size=1), rollout_id=0, data_buffer=data_buffer) + ) + finally: + module._global_worker = saved_global_worker + + assert [group[0].group_index for group in output.samples] == [531] + assert [group[0].group_index for group in data_buffer.samples] == [530] + assert worker.get_stale_sample_count() == 1 diff --git a/tests/test_rollout_manager_fully_async_metrics.py b/tests/test_rollout_manager_fully_async_metrics.py new file mode 100644 index 0000000000..c6fe11a31d --- /dev/null +++ b/tests/test_rollout_manager_fully_async_metrics.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from types import SimpleNamespace + + +def _package(name: str) -> types.ModuleType: + module = types.ModuleType(name) + module.__path__ = [] + return module + + +def load_rollout_manager_module(): + log_calls: 
list[tuple[object, dict, str]] = [] + + class DummyPlacementGroupSchedulingStrategy: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + class DummyLock: + @classmethod + def options(cls, **kwargs): + return cls() + + def remote(self): + return object() + + class DummySample: + class Status: + TRUNCATED = object() + + def fake_log(args, metrics, step_key): + log_calls.append((args, dict(metrics), step_key)) + + def fake_compute_rollout_step(args, rollout_id): + return rollout_id * 10 + 1 + + def fake_group_by(*args, **kwargs): + return {} + + def fake_load_function(*args, **kwargs): + return None + + modules = { + "ray": types.ModuleType("ray"), + "ray.util": _package("ray.util"), + "ray.util.scheduling_strategies": types.ModuleType("ray.util.scheduling_strategies"), + "torch": types.ModuleType("torch"), + "sglang": _package("sglang"), + "sglang.srt": _package("sglang.srt"), + "sglang.srt.constants": types.ModuleType("sglang.srt.constants"), + "slime": _package("slime"), + "slime.backends": _package("slime.backends"), + "slime.backends.sglang_utils": _package("slime.backends.sglang_utils"), + "slime.backends.sglang_utils.sglang_config": types.ModuleType("slime.backends.sglang_utils.sglang_config"), + "slime.backends.sglang_utils.sglang_engine": types.ModuleType("slime.backends.sglang_utils.sglang_engine"), + "slime.rollout": _package("slime.rollout"), + "slime.rollout.base_types": types.ModuleType("slime.rollout.base_types"), + "slime.utils": _package("slime.utils"), + "slime.utils.health_monitor": types.ModuleType("slime.utils.health_monitor"), + "slime.utils.http_utils": types.ModuleType("slime.utils.http_utils"), + "slime.utils.logging_utils": types.ModuleType("slime.utils.logging_utils"), + "slime.utils.metric_utils": types.ModuleType("slime.utils.metric_utils"), + "slime.utils.misc": types.ModuleType("slime.utils.misc"), + "slime.utils.seqlen_balancing": types.ModuleType("slime.utils.seqlen_balancing"), + "slime.utils.types": 
types.ModuleType("slime.utils.types"), + "slime.ray": _package("slime.ray"), + "slime.ray.utils": types.ModuleType("slime.ray.utils"), + } + + modules["ray"].remote = lambda obj=None, **kwargs: obj + modules["ray"].get = lambda value: value + modules["ray"].util = modules["ray.util"] + modules["ray.util"].scheduling_strategies = modules["ray.util.scheduling_strategies"] + modules["ray.util.scheduling_strategies"].PlacementGroupSchedulingStrategy = DummyPlacementGroupSchedulingStrategy + + modules["torch"].load = lambda *args, **kwargs: {"samples": []} + + modules["sglang"].srt = modules["sglang.srt"] + modules["sglang.srt"].constants = modules["sglang.srt.constants"] + modules["sglang.srt.constants"].GPU_MEMORY_TYPE_CUDA_GRAPH = 0 + modules["sglang.srt.constants"].GPU_MEMORY_TYPE_KV_CACHE = 1 + modules["sglang.srt.constants"].GPU_MEMORY_TYPE_WEIGHTS = 2 + + dummy_cls = type("Dummy", (), {}) + modules["slime.backends"].sglang_utils = modules["slime.backends.sglang_utils"] + modules["slime.backends.sglang_utils"].sglang_config = modules["slime.backends.sglang_utils.sglang_config"] + modules["slime.backends.sglang_utils"].sglang_engine = modules["slime.backends.sglang_utils.sglang_engine"] + modules["slime.backends.sglang_utils.sglang_config"].ModelConfig = dummy_cls + modules["slime.backends.sglang_utils.sglang_config"].ServerGroupConfig = dummy_cls + modules["slime.backends.sglang_utils.sglang_config"].SglangConfig = dummy_cls + modules["slime.backends.sglang_utils.sglang_engine"].SGLangEngine = dummy_cls + + modules["slime.rollout"].base_types = modules["slime.rollout.base_types"] + modules["slime.rollout.base_types"].call_rollout_fn = lambda *args, **kwargs: None + + modules["slime"].utils = modules["slime.utils"] + modules["slime.utils"].health_monitor = modules["slime.utils.health_monitor"] + modules["slime.utils"].http_utils = modules["slime.utils.http_utils"] + modules["slime.utils"].logging_utils = modules["slime.utils.logging_utils"] + 
modules["slime.utils"].metric_utils = modules["slime.utils.metric_utils"] + modules["slime.utils"].misc = modules["slime.utils.misc"] + modules["slime.utils"].seqlen_balancing = modules["slime.utils.seqlen_balancing"] + modules["slime.utils"].types = modules["slime.utils.types"] + + modules["slime.utils.health_monitor"].RolloutHealthMonitor = dummy_cls + modules["slime.utils.http_utils"]._wrap_ipv6 = lambda value: value + modules["slime.utils.http_utils"].find_available_port = lambda *args, **kwargs: 0 + modules["slime.utils.http_utils"].get_host_info = lambda *args, **kwargs: ("127.0.0.1", "localhost") + modules["slime.utils.http_utils"].init_http_client = lambda *args, **kwargs: None + modules["slime.utils.logging_utils"].configure_logger = lambda *args, **kwargs: None + modules["slime.utils.logging_utils"].finish_tracking = lambda *args, **kwargs: None + modules["slime.utils.logging_utils"].init_tracking = lambda *args, **kwargs: None + modules["slime.utils.logging_utils"].log = fake_log + modules["slime.utils.metric_utils"].compute_pass_rate = lambda *args, **kwargs: {} + modules["slime.utils.metric_utils"].compute_rollout_step = fake_compute_rollout_step + modules["slime.utils.metric_utils"].compute_statistics = lambda values: {} + modules["slime.utils.metric_utils"].dict_add_prefix = lambda d, prefix: {f"{prefix}{k}": v for k, v in d.items()} + modules["slime.utils.metric_utils"].has_repetition = lambda text: False + modules["slime.utils.misc"].Box = dict + modules["slime.utils.misc"].group_by = fake_group_by + modules["slime.utils.misc"].load_function = fake_load_function + modules["slime.utils.seqlen_balancing"].get_seqlen_balanced_partitions = lambda *args, **kwargs: [] + modules["slime.utils.types"].Sample = DummySample + + modules["slime"].ray = modules["slime.ray"] + modules["slime.ray"].utils = modules["slime.ray.utils"] + modules["slime.ray.utils"].NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [] + modules["slime.ray.utils"].Lock = DummyLock + + saved_modules 
= {name: sys.modules.get(name) for name in modules} + saved_test_module = sys.modules.get("slime.ray.rollout_under_test") + + try: + sys.modules.update(modules) + spec = importlib.util.spec_from_file_location( + "slime.ray.rollout_under_test", + Path(__file__).resolve().parents[1] / "slime" / "ray" / "rollout.py", + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module, log_calls + finally: + for name, saved in saved_modules.items(): + if saved is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = saved + if saved_test_module is None: + sys.modules.pop("slime.ray.rollout_under_test", None) + else: + sys.modules["slime.ray.rollout_under_test"] = saved_test_module + + +def make_args(**overrides): + defaults = dict( + use_wandb=False, + use_tensorboard=False, + wandb_always_use_train_step=False, + rollout_batch_size=4, + n_samples_per_prompt=2, + global_batch_size=8, + ) + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +def test_log_fully_async_metrics_uses_dedicated_step_axis(): + module, log_calls = load_rollout_manager_module() + args = make_args() + fake_manager = SimpleNamespace(args=args, rollout_id=7, _fully_async_log_step=0) + + module.RolloutManager._log_fully_async_metrics( + fake_manager, + { + "fully_async/count/stale_samples_processed": 3, + "fully_async/partial/total_partial_num": 2, + "stale_samples": 99, + }, + ) + + module.RolloutManager._log_fully_async_metrics( + fake_manager, + { + "fully_async/count/stale_samples_processed": 5, + }, + ) + + assert len(log_calls) == 2 + _, metrics, step_key = log_calls[0] + assert step_key == "fully_async/step" + assert metrics["fully_async/step"] == 0 + assert metrics["fully_async/count/stale_samples_processed"] == 3 + assert metrics["fully_async/partial/total_partial_num"] == 2 + assert "rollout/step" not in metrics + assert "stale_samples" not in metrics + _, 
second_metrics, second_step_key = log_calls[1] + assert second_step_key == "fully_async/step" + assert second_metrics["fully_async/step"] == 1 + assert second_metrics["fully_async/count/stale_samples_processed"] == 5 + assert fake_manager._fully_async_log_step == 2 + + +def test_log_fully_async_metrics_skips_initial_sync_and_empty_payloads(): + module, log_calls = load_rollout_manager_module() + args = make_args() + + module.RolloutManager._log_fully_async_metrics( + SimpleNamespace(args=args, rollout_id=-1), + {"fully_async/count/stale_samples_processed": 1}, + ) + module.RolloutManager._log_fully_async_metrics(SimpleNamespace(args=args, rollout_id=3), None) + module.RolloutManager._log_fully_async_metrics(SimpleNamespace(args=args, rollout_id=3), {"stale_samples": 1}) + + assert log_calls == [] + + +def test_after_weight_update_logs_hook_metrics(): + module, log_calls = load_rollout_manager_module() + args = make_args() + runtime_updates = [] + hook_calls = [] + hook_result = { + "fully_async/count/stale_samples_processed": 4, + "fully_async/partial/total_partial_num": 1, + } + fake_manager = SimpleNamespace( + args=args, + rollout_id=5, + _fully_async_log_step=0, + update_runtime_state=lambda **metadata: runtime_updates.append(metadata), + _call_generate_rollout_hook=lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or hook_result, + ) + fake_manager._log_fully_async_metrics = lambda result: module.RolloutManager._log_fully_async_metrics( + fake_manager, result + ) + + result = module.RolloutManager.after_weight_update(fake_manager, policy_version=3) + + assert result is hook_result + assert runtime_updates == [{"current_policy_version": 3}] + assert hook_calls == [("after_weight_update", {"policy_version": 3})] + assert len(log_calls) == 1 + _, metrics, step_key = log_calls[0] + assert step_key == "fully_async/step" + assert metrics["fully_async/step"] == 0 + assert fake_manager._fully_async_log_step == 1 + + +def 
test_dispose_flushes_tail_window_metrics(): + module, log_calls = load_rollout_manager_module() + args = make_args() + hook_calls = [] + stop_calls = [] + hook_results = { + "flush_metrics": {"fully_async/partial/total_partial_num": 3}, + "shutdown_worker": None, + } + fake_manager = SimpleNamespace( + args=args, + rollout_id=2, + _fully_async_log_step=0, + _health_monitors=[ + SimpleNamespace(stop=lambda: stop_calls.append("first")), + SimpleNamespace(stop=lambda: stop_calls.append("second")), + ], + _call_generate_rollout_hook=lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) + or hook_results[hook_name], + ) + fake_manager._log_fully_async_metrics = lambda result: module.RolloutManager._log_fully_async_metrics( + fake_manager, result + ) + + module.RolloutManager.dispose(fake_manager) + + assert hook_calls == [("flush_metrics", {}), ("shutdown_worker", {})] + assert len(log_calls) == 1 + _, metrics, step_key = log_calls[0] + assert step_key == "fully_async/step" + assert metrics["fully_async/step"] == 0 + assert metrics["fully_async/partial/total_partial_num"] == 3 + assert stop_calls == ["first", "second"] + assert fake_manager._fully_async_log_step == 1 diff --git a/train_async.py b/train_async.py index 94cc29694d..417b3b12f6 100644 --- a/train_async.py +++ b/train_async.py @@ -24,10 +24,12 @@ def train(args): # create the actor and critic models actor_model, critic_model = create_training_models(args, pgs, rollout_manager) + policy_version = 0 # always update weight first so that sglang has the loaded weights from training. 
if not args.critic_train_only: actor_model.update_weights() + ray.get(rollout_manager.after_weight_update.remote(policy_version)) if args.check_weight_update_equal: ray.get(rollout_manager.check_weights.remote(action="compare")) @@ -70,7 +72,10 @@ def train(args): rollout_data_curr_ref = ray.get(x) if (x := rollout_data_next_future) is not None else None rollout_data_next_future = None if not args.critic_train_only: + ray.get(rollout_manager.before_weight_update.remote(policy_version)) actor_model.update_weights() + policy_version += 1 + ray.get(rollout_manager.after_weight_update.remote(policy_version)) if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch): ray.get(rollout_manager.eval.remote(rollout_id))