diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py
index 8a7f768b3b..2b3d3d930d 100644
--- a/slime/backends/megatron_utils/data.py
+++ b/slime/backends/megatron_utils/data.py
@@ -21,6 +21,79 @@
 logger = logging.getLogger(__name__)
 
 
+# One-shot latch to suppress repeated empty-microbatch warnings (see get_batch).
+_empty_microbatch_warned: bool = False
+
+
+def _fill_empty_microbatch_placeholder(batch: dict, keys, pad_token_id: int) -> list:
+    """Populate ``batch`` with a self-consistent 1-token placeholder sample
+    and return the new ``tokens`` list. Invoked by ``get_batch`` only when
+    the local DP rank has no real samples for this micro-batch.
+
+    Invariants after this function returns (matching the 1-token placeholder):
+        batch["tokens"]           == [pad_token_tensor]
+        batch["total_lengths"]    == [1]
+        batch["response_lengths"] == [0]
+        batch["loss_masks"]       == [zero-size int tensor]
+        batch["max_seq_lens"]     == [1]  (bshd-only, if requested)
+        every other list-valued key in ``keys`` has exactly one entry
+
+    The downstream effect of response_length=0 is that CP-chunk sizes are
+    0 on every CP rank, so per-sample log_probs / rollout_log_probs / etc.
+    contribute 0-size tensors and ``sum_of_sample_mean`` splits a 0-sized
+    slice with ``split_sizes=[0]``.
+
+    If a new rollout-schema field lands in ``keys`` and the fill loop below
+    cannot populate it (it has no ``placeholder_for_key`` entry and is not
+    an empty list eligible for the per-token fp32 default), the post-fill
+    invariant assertion fails here rather than 400 lines downstream inside
+    ``torch.split``.
+    """
+    device = torch.cuda.current_device()
+    placeholder_for_key = {
+        "total_lengths": 1,
+        "response_lengths": 0,
+        "max_seq_lens": 1,
+        # length-0 mask; the post-padding step in get_batch aligns it
+        # with the size-1 placeholder token.
+        "loss_masks": torch.zeros(0, dtype=torch.int, device=device),
+        "multimodal_train_inputs": None,
+    }
+    # Per-token tensor-valued keys (log_probs, ref_log_probs, advantages,
+    # etc.) default to an empty fp32 tensor so torch.cat downstream works.
+    # New per-token fields of a different dtype must add themselves to
+    # placeholder_for_key instead of relying on this default.
+    per_token_default = torch.zeros(0, dtype=torch.float32, device=device)
+
+    placeholder = torch.tensor([pad_token_id], dtype=torch.long, device=device)
+    batch["tokens"] = [placeholder]
+    for key in keys:
+        if key == "tokens":
+            continue
+        v = batch.get(key)
+        if v is None:
+            continue
+        if isinstance(v, list) and len(v) == 0:
+            if key in placeholder_for_key:
+                batch[key] = [placeholder_for_key[key]]
+            else:
+                batch[key] = [per_token_default]
+
+    # Post-fill invariant: every list-valued key now has exactly one entry.
+    # If this fails, a new schema field was added to `keys` without a
+    # placeholder rule above.
+    for key in keys:
+        if key == "tokens":
+            continue
+        v = batch.get(key)
+        if isinstance(v, list):
+            assert len(v) == 1, (
+                f"empty-microbatch placeholder did not fill key={key!r}; "
+                f"got len(batch[{key!r}])={len(v)}, expected 1. Add a "
+                f"placeholder_for_key entry in _fill_empty_microbatch_placeholder."
+            )
+
+    return batch["tokens"]
+
+
 def get_batch(
     data_iterator: "DataIterator",
@@ -61,6 +134,37 @@ def get_batch(
     pad_token_id = 0
     pad_size = mpu.get_tensor_model_parallel_world_size() * pad_multiplier
 
+    # DP-imbalance guard: when DP ranks need different microbatch counts,
+    # the pipeline schedule loops for `max(num_mbs)` steps and ranks with
+    # surplus steps see empty micro-batches. Before, this raised a confusing
+    # ``torch.cat(): expected a non-empty list of Tensors``; instead, insert
+    # a self-consistent single-token placeholder so downstream operations
+    # (loss_mask align, CP slicing, per-sample log-prob extraction,
+    # sum_of_sample_mean split) all agree on 0 response tokens.
+    #
+    # Invariants of the placeholder:
+    #   tokens           = [pad]   (size 1: prompt only)
+    #   total_lengths    = [1]
+    #   response_lengths = [0]
+    #   loss_masks       = [ [] ]  (0 response tokens -> empty mask)
+    #   max_seq_lens     = [1]     (bshd-only, if requested)
+    # With response_length=0 the CP-chunk sizes are 0 on every CP rank, so
+    # the fake sample contributes 0-size tensors to per-sample log_probs /
+    # rollout_log_probs etc. `split_sizes=[0]` then splits a 0-sized slice.
+    if not tokens:
+        # Log once per rank: on an unbalanced multi-hour training run, every
+        # empty microbatch takes this path, and per-occurrence logging would
+        # drown the logs.
+        global _empty_microbatch_warned
+        if not _empty_microbatch_warned:
+            logger.warning(
+                "get_batch: empty micro-batch (DP rank has fewer partitions "
+                "than the collective max); inserting 1-token placeholder. "
+                "Further occurrences on this rank will not be re-logged."
+            )
+            _empty_microbatch_warned = True
+        tokens = _fill_empty_microbatch_placeholder(batch, keys, pad_token_id)
+
     # for cp, we need all tokens to calculate logprob
     batch["unconcat_tokens"] = tokens
diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py
index 0bbe5bf49b..b2e183f512 100644
--- a/slime/backends/megatron_utils/model.py
+++ b/slime/backends/megatron_utils/model.py
@@ -289,8 +289,36 @@ def forward_step(
         if args.use_dynamic_batch_size:
             # TODO: This is ugly... Find a better way to make the data have the same order.
             # TODO: move this out of the loop.
-            origin_values = [None] * len(values)
             origin_indices = sum(data_iterator[0].micro_batch_indices, [])
+            # Size the output to the real sample count. When DP ranks need
+            # unequal microbatch counts, the all_reduce(MAX) + cap-aware
+            # partitioning can yield empty partitions, and our get_batch
+            # placeholder guard then produces extra "fake" forward outputs.
+            # Those trailing fakes must be dropped here; otherwise the
+            # result list ends with `None`s that crash downstream
+            # (e.g. compute_advantages_and_returns on log_probs).
+            #
+            # Invariant (enforced by the first-fit bin-packer in
+            # seqlen_balancing): real samples occupy origin_indices in
+            # [0, len(origin_indices)), so sizing by len(origin_indices)
+            # is correct and any placeholder outputs past that index
+            # are silently dropped. If the partitioner ever changed to
+            # interleave empties, this guard would silently drop real
+            # samples, so assert the invariant locally.
+            if origin_indices:
+                assert max(origin_indices) < len(origin_indices), (
+                    f"dynamic-batch reorder expects real samples at positions "
+                    f"[0, {len(origin_indices)}); got max(origin_indices)="
+                    f"{max(origin_indices)}. The seqlen partitioner changed."
+                )
+            # Also check that we have *at least* as many forward outputs as
+            # real samples; fewer would silently drop real data.
+            assert len(values) >= len(origin_indices), (
+                f"forward produced {len(values)} outputs for "
+                f"{len(origin_indices)} real samples; real samples would be "
+                f"dropped by the zip below."
+            )
+            origin_values = [None] * len(origin_indices)
             for value, origin_index in zip(values, origin_indices, strict=False):
                 origin_values[origin_index] = value
             values = origin_values
diff --git a/tests/test_dp_imbalance_placeholder.py b/tests/test_dp_imbalance_placeholder.py
new file mode 100644
index 0000000000..2b8cce7cba
--- /dev/null
+++ b/tests/test_dp_imbalance_placeholder.py
@@ -0,0 +1,129 @@
+"""Tests for the get_batch empty-microbatch placeholder schema.
+
+When a DP rank has fewer real partitions than the collective MAX,
+``get_batch`` calls ``_fill_empty_microbatch_placeholder`` to fabricate
+a 1-token placeholder sample. The placeholder must be *self-consistent*:
+every list-valued schema key must end up with exactly one entry, the
+key-specific values must match the declared contract (total_lengths=[1],
+response_lengths=[0], ...), and the post-fill invariant assertion must
+trip when a schema field lands in ``keys`` without a placeholder rule.
+
+These tests import the real helper from ``data.py``, so a future change
+to the placeholder logic fails here instead of leaving the tests
+asserting against a stale replica of it.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+try:
+    import torch
+except ImportError:  # pragma: no cover
+    pytest.skip("torch not available", allow_module_level=True)
+
+# The helper uses torch.cuda.current_device(); on a CPU-only environment
+# CUDA is unavailable. Skip the whole module in that case (these tests
+# exist to exercise the CUDA code path that fires in production).
+if not torch.cuda.is_available():  # pragma: no cover
+    pytest.skip("CUDA required for the device=cuda placeholder tensors", allow_module_level=True)
+
+from slime.backends.megatron_utils.data import _fill_empty_microbatch_placeholder  # noqa: E402
+
+
+def test_placeholder_fills_known_keys():
+    batch = {
+        "tokens": [],
+        "total_lengths": [],
+        "response_lengths": [],
+        "max_seq_lens": [],
+        "loss_masks": [],
+        "log_probs": [],
+        "ref_log_probs": [],
+    }
+    keys = list(batch.keys())
+    tokens = _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)
+
+    # The returned tokens list is batch["tokens"] itself.
+    assert tokens is batch["tokens"]
+    assert len(tokens) == 1 and tokens[0].numel() == 1
+    assert tokens[0].dtype == torch.long
+
+    # Known-schema fields match their contract.
+    assert batch["total_lengths"] == [1]
+    assert batch["response_lengths"] == [0]
+    assert batch["max_seq_lens"] == [1]
+    assert isinstance(batch["loss_masks"][0], torch.Tensor)
+    assert batch["loss_masks"][0].numel() == 0
+    assert batch["loss_masks"][0].dtype == torch.int
+    # Per-token default for unknown fields: length-0 fp32.
+    assert batch["log_probs"][0].numel() == 0
+    assert batch["log_probs"][0].dtype == torch.float32
+    assert batch["ref_log_probs"][0].numel() == 0
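+
+
+def test_placeholder_skips_none_valued_keys():
+    """Added check (a sketch of the None-skip behavior visible in the fill
+    loop): a key listed in ``keys`` but mapped to None in ``batch`` is left
+    untouched, since the fill loop skips ``v is None`` and the invariant
+    loop only inspects list values. ``optional_field`` is a hypothetical
+    name, not a real rollout-schema key."""
+    batch = {"tokens": [], "optional_field": None}
+    keys = list(batch.keys())
+    _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)
+    assert batch["optional_field"] is None
+    assert len(batch["tokens"]) == 1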
+
+
+def test_placeholder_per_sample_counts_all_match_tokens():
+    """Core invariant: every list-valued key has exactly one entry
+    after the fill, aligned with the 1-token placeholder."""
+    batch = {
+        "tokens": [],
+        "total_lengths": [],
+        "response_lengths": [],
+        "loss_masks": [],
+        "log_probs": [],
+        "advantages": [],
+    }
+    keys = list(batch.keys())
+    _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)
+
+    for k, v in batch.items():
+        assert len(v) == 1, f"key {k!r}: expected 1 entry, got {len(v)}"
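+
+
+def test_placeholder_multimodal_inputs_get_none():
+    """Added check (a sketch pinning the None-valued placeholder_for_key
+    entry): an empty multimodal_train_inputs list becomes [None], per the
+    map in _fill_empty_microbatch_placeholder."""
+    batch = {"tokens": [], "multimodal_train_inputs": []}
+    keys = list(batch.keys())
+    _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)
+    assert batch["multimodal_train_inputs"] == [None]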
+
+
+def test_placeholder_unknown_key_gets_per_token_default():
+    """A new schema field not in placeholder_for_key gets the per-token
+    default. This encodes the 'everything else is a per-token fp32 tensor'
+    assumption: if that stops being true, the new key should add itself to
+    placeholder_for_key."""
+    batch = {"tokens": [], "something_new": []}
+    keys = list(batch.keys())
+    _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)
+    assert batch["something_new"][0].dtype == torch.float32
+    assert batch["something_new"][0].numel() == 0
+
+
+def test_placeholder_invariant_fires_on_unfilled_key():
+    """The post-fill invariant assertion catches a key that is present in
+    ``keys`` but that the fill loop could not populate.
+
+    This exercises the actual assertion in
+    ``_fill_empty_microbatch_placeholder``, not a copy of it in the test,
+    so future edits to the helper that weaken or move the assertion make
+    this test fail.
+    """
+    # 'bad_key' is in keys, but batch['bad_key'] is a stale 2-entry list.
+    # (A non-list value would not work here: both the fill loop and the
+    # invariant loop skip non-lists, so no assertion would fire.) The fill
+    # loop only populates empty lists and therefore skips the key; the
+    # invariant loop then sees len == 2 instead of 1 and raises.
+    batch = {
+        "tokens": [],
+        "bad_key": ["stale1", "stale2"],
+    }
+    keys = list(batch.keys())
+
+    with pytest.raises(AssertionError, match="bad_key"):
+        _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)
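+
+
+def test_placeholder_uses_pad_token_id():
+    """Added check (a sketch of pad-token propagation): the fabricated
+    token carries the caller's pad_token_id on the current CUDA device.
+    get_batch passes pad_token_id=0 today; 7 here is an arbitrary value."""
+    batch = {"tokens": []}
+    keys = list(batch.keys())
+    tokens = _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=7)
+    assert tokens[0].item() == 7
+    assert tokens[0].device.type == "cuda"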
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tests/test_dp_imbalance_reorder.py b/tests/test_dp_imbalance_reorder.py
new file mode 100644
index 0000000000..b5fee5ea82
--- /dev/null
+++ b/tests/test_dp_imbalance_reorder.py
@@ -0,0 +1,99 @@
+"""Regression tests for forward_only's handling of empty partitions.
+
+Scenario: with `use_dynamic_batch_size=True`, an all_reduce(MAX) across DP
+ranks can produce `num_microbatches` larger than the local rank has real
+samples for. `_get_capped_partitions` then yields trailing empty partitions
+like `[[0, 3], [1, 2], [], []]`, and `get_batch`'s placeholder guard emits a
+1-token fake sample for each empty partition. Those fakes flow through the
+forward pass and into `forward_data_store`.
+
+Before the fix, `forward_only` sized `origin_values` to `len(values)` (which
+included the fake outputs), and `zip(values, origin_indices, strict=False)`
+only populated the real-sample slots, so the tail of the returned list was
+`None`. That caused a downstream crash in `compute_advantages_and_returns`
+at
+    kl = [torch.zeros_like(x, ..., device=x.device) for x in xs]
+with `AttributeError: 'NoneType' object has no attribute 'device'` on the
+first `None` element.
+
+Related upstream reports: THUDM/slime#1838 (the `torch.cat([])` flavor of
+this same DP-imbalance problem, which the `get_batch` placeholder guard
+already fixes) and THUDM/slime#1839 (oversized-sample assertion).
+
+These tests simulate the exact reorder logic in `forward_only` to pin the
+fix of `origin_values = [None] * len(origin_indices)` (as opposed to
+`len(values)`).
+"""
+
+from __future__ import annotations
+
+
+def _reorder(values: list, micro_batch_indices: list[list[int]]) -> list:
+    """Mirror the post-fix reorder in forward_only (model.py)."""
+    origin_indices = sum(micro_batch_indices, [])
+    origin_values = [None] * len(origin_indices)
+    for value, origin_index in zip(values, origin_indices, strict=False):
+        origin_values[origin_index] = value
+    return origin_values
+
+
+def test_no_empty_partitions_preserves_order():
+    # 4 real samples, 2 microbatches, no empty partitions.
+    mb_indices = [[0, 3], [1, 2]]
+    values = ["s0", "s3", "s1", "s2"]
+    out = _reorder(values, mb_indices)
+    assert out == ["s0", "s1", "s2", "s3"]
+
+
+def test_trailing_empty_partitions_drop_fakes():
+    # 4 real samples split across 2 real microbatches + 2 empty ones
+    # (produced by _get_capped_partitions when num_mbs exceeds what the
+    # local rank needs). Each empty partition contributes one 1-token
+    # placeholder via get_batch, which becomes a "fake" value in `values`.
+    mb_indices = [[0, 3], [1, 2], [], []]
+    values = ["s0", "s3", "s1", "s2", "FAKE0", "FAKE1"]
+    out = _reorder(values, mb_indices)
+    # Only real samples land at their origin positions; fakes are dropped.
+    assert out == ["s0", "s1", "s2", "s3"]
+    assert None not in out
+
+
+def test_all_partitions_empty_returns_empty_list():
+    # Degenerate rank with 0 real samples (all micro_batch_indices empty).
+    # Every microbatch is a fake; the output is an empty list, not a list
+    # of None.
+    mb_indices = [[], [], []]
+    values = ["FAKE0", "FAKE1", "FAKE2"]
+    out = _reorder(values, mb_indices)
+    assert out == []
+
+
+def test_empty_partitions_only_at_tail():
+    # _get_capped_partitions uses first-fit bin-packing: samples always go to
+    # the lowest-index partition that fits, so any unused capacity appears in
+    # trailing partitions, never interleaved. This documents that invariant;
+    # the simple zip-based reorder relies on it. (If the partitioner ever
+    # produced mid-schedule empties, len(partition) > 0 would assert in
+    # get_seqlen_balanced_partitions, and execution would not reach this
+    # code path.)
+    mb_indices = [[0, 1, 2], [3], [], []]
+    values = ["s0", "s1", "s2", "s3", "FAKE0", "FAKE1"]
+    out = _reorder(values, mb_indices)
+    assert out == ["s0", "s1", "s2", "s3"]
+    assert None not in out
+
+
+def test_fix_vs_pre_fix_behavior_divergence():
+    """Sanity check that the OLD reorder (sizing by len(values)) produced
+    trailing Nones. Documents the specific bug shape."""
+    mb_indices = [[0, 3], [1, 2], [], []]
+    values = ["s0", "s3", "s1", "s2", "FAKE0", "FAKE1"]
+
+    # Pre-fix (bug): origin_values sized by len(values) -> 2 trailing Nones.
+    pre_fix = [None] * len(values)
+    origin_indices = sum(mb_indices, [])
+    for value, origin_index in zip(values, origin_indices, strict=False):
+        pre_fix[origin_index] = value
+    assert pre_fix[-2:] == [None, None], "pre-fix list should end with Nones"
+
+    # Post-fix: origin_values sized by len(origin_indices) -> no trailing Nones.
+    post_fix = _reorder(values, mb_indices)
+    assert None not in post_fix
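+
+
+def test_strict_zip_would_reject_fake_outputs():
+    """Added check (a sketch documenting why the reorder zips with
+    strict=False): with trailing fake outputs, `values` is longer than
+    `origin_indices`, and a strict zip raises ValueError instead of
+    dropping the fakes."""
+    mb_indices = [[0, 3], [1, 2], [], []]
+    values = ["s0", "s3", "s1", "s2", "FAKE0", "FAKE1"]
+    origin_indices = sum(mb_indices, [])
+    try:
+        list(zip(values, origin_indices, strict=True))
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("strict zip unexpectedly accepted the length mismatch")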