104 changes: 104 additions & 0 deletions slime/backends/megatron_utils/data.py
@@ -21,6 +21,79 @@

logger = logging.getLogger(__name__)

# One-shot latch to suppress repeated empty-microbatch warnings (see get_batch).
_empty_microbatch_warned: bool = False


def _fill_empty_microbatch_placeholder(batch: dict, keys, pad_token_id: int) -> list:
"""Populate ``batch`` with a self-consistent 1-token placeholder sample
and return the new ``tokens`` list. Invoked by ``get_batch`` only when
the local DP rank has no real samples for this micro-batch.

Invariants after this function returns (matching the 1-token placeholder):
batch["tokens"] == [pad_token_tensor]
batch["total_lengths"] == [1]
batch["response_lengths"] == [0]
batch["loss_masks"] == [zero-size int tensor]
batch["max_seq_lens"] == [1] (bshd-only, if requested)
every other list-valued key in ``keys`` has exactly one entry

The downstream effect of response_length=0 is that CP-chunk sizes are
0 on every CP rank, so per-sample log_probs / rollout_log_probs / etc.
contribute 0-size tensors and ``sum_of_sample_mean`` splits a 0-sized
slice with ``split_sizes=[0]``.
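
An illustrative sketch of why the 0-sized split is safe (standard PyTorch
behavior, shown for clarity only; not executed by this helper):

torch.split(torch.empty(0), [0])  # -> (tensor([]),): a zero split size
                                  #    over a 0-sized slice is valid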

If a new rollout-schema field lands in ``keys`` without a corresponding
entry in the ``placeholder_for_key`` map below or matching the
per-token fp32 default, the post-fill invariant assertion fails here
— not 400 lines downstream inside ``torch.split``.
"""
device = torch.cuda.current_device()
placeholder_for_key = {
"total_lengths": 1,
"response_lengths": 0,
"max_seq_lens": 1,
# length-0 mask; the post-padding step in get_batch aligns it
# with the size-1 placeholder token.
"loss_masks": torch.zeros(0, dtype=torch.int, device=device),
"multimodal_train_inputs": None,
}
# Per-token tensor-valued keys (log_probs, ref_log_probs, advantages,
# etc.) default to an empty fp32 tensor so torch.cat downstream works.
# New per-token fields of a different dtype must add themselves to
# placeholder_for_key instead of relying on this default.
per_token_default = torch.zeros(0, dtype=torch.float32, device=device)

placeholder = torch.tensor([pad_token_id], dtype=torch.long, device=device)
batch["tokens"] = [placeholder]
for key in keys:
if key == "tokens":
continue
v = batch.get(key)
if v is None:
continue
if isinstance(v, list) and len(v) == 0:
if key in placeholder_for_key:
batch[key] = [placeholder_for_key[key]]
else:
batch[key] = [per_token_default]

# Post-fill invariant: every list-valued key now has exactly one entry.
# If this fails, a new schema field was added to `keys` without a
# placeholder rule above.
for key in keys:
if key == "tokens":
continue
v = batch.get(key)
if isinstance(v, list):
assert len(v) == 1, (
f"empty-microbatch placeholder did not fill key={key!r}; "
f"got len(batch[{key!r}])={len(v)}, expected 1. Add a "
f"placeholder_for_key entry in _fill_empty_microbatch_placeholder."
)

return batch["tokens"]


def get_batch(
data_iterator: "DataIterator",
@@ -61,6 +134,37 @@ def get_batch(
pad_token_id = 0
pad_size = mpu.get_tensor_model_parallel_world_size() * pad_multiplier

# DP-imbalance guard: when DP ranks need different microbatch counts, the
# pipeline schedule loops for `max(num_mbs)` steps, and ranks with surplus
# schedule steps see empty micro-batches. Previously this raised a confusing
# ``torch.cat(): expected a non-empty list of Tensors`` error; we now insert
# a self-consistent single-token placeholder so that downstream operations
# (loss_mask align, CP slicing, per-sample log-prob extraction,
# sum_of_sample_mean split) all agree on 0 response tokens.
#
# Invariants of the placeholder:
# tokens = [pad] (size 1 — prompt only)
# total_lengths = [1]
# response_lengths= [0]
# loss_masks = [0-size mask] (0 response tokens → empty mask)
# max_seq_lens = [1] (bshd-only, if requested)
# With response_length=0 the CP-chunk sizes are 0 on every CP rank, so
# the fake sample contributes 0-size tensors to per-sample log_probs /
# rollout_log_probs etc. `split_sizes=[0]` then splits a 0-sized slice.
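#
# Illustrative contrast (not executed here): downstream code concatenates
# the per-key lists, and
#     torch.cat([])                # raises "expected a non-empty list of Tensors"
#     torch.cat([torch.zeros(0)])  # returns tensor([]), i.e. harmless
# so each list-valued key must receive exactly one (possibly 0-sized)
# entry rather than staying empty.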
if not tokens:
# Log once per rank — on an unbalanced multi-hour training run, every
# empty microbatch fires this path and per-instance logging would
# drown the logs.
global _empty_microbatch_warned
if not _empty_microbatch_warned:
logger.warning(
"get_batch: empty micro-batch (DP rank has fewer partitions "
"than the collective max); inserting 1-token placeholder. "
"Further occurrences on this rank will not be re-logged."
)
_empty_microbatch_warned = True
tokens = _fill_empty_microbatch_placeholder(batch, keys, pad_token_id)

# for cp, we need all tokens to calculate logprob
batch["unconcat_tokens"] = tokens

30 changes: 29 additions & 1 deletion slime/backends/megatron_utils/model.py
@@ -289,8 +289,36 @@ def forward_step(
if args.use_dynamic_batch_size:
# TODO: This is ugly... Find a better way to make the data have the same order.
# TODO: move this out of the loop.
origin_values = [None] * len(values)
origin_indices = sum(data_iterator[0].micro_batch_indices, [])
# Size the output to the real sample count. When DP ranks need
# unequal microbatch counts, the all_reduce(MAX) + cap-aware
# partitioning can yield empty partitions, and our get_batch
# placeholder guard then produces extra "fake" forward outputs.
# Those trailing fakes must be dropped here; otherwise the
# result list ends with `None`s that crash downstream
# (e.g. compute_advantages_and_returns on log_probs).
#
# Invariant (enforced by the first-fit bin-packer in
# seqlen_balancing): real samples occupy origin_indices in
# [0, len(origin_indices)), so sizing by len(origin_indices)
# is correct and any placeholder outputs past that index
# are silently dropped. If the partitioner ever changes to
# interleave empties, this guard would silently drop real
# samples — assert the invariant locally.
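#
# Illustrative shape (mirrors the unit tests): with
#     micro_batch_indices = [[0, 3], [1, 2], [], []]
# origin_indices == [0, 3, 1, 2], so origin_values gets 4 slots and the
# two placeholder outputs produced for the empty partitions fall off the
# end of the zip below.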
if origin_indices:
assert max(origin_indices) < len(origin_indices), (
f"dynamic-batch reorder expects real samples at positions "
f"[0, {len(origin_indices)}); got max(origin_indices)="
f"{max(origin_indices)}. The seqlen partitioner changed."
)
# And check we have *at least* as many forward outputs as real
# samples — fewer would silently drop real data.
assert len(values) >= len(origin_indices), (
f"forward produced {len(values)} outputs for "
f"{len(origin_indices)} real samples; real samples would be "
f"dropped by the zip below."
)
origin_values = [None] * len(origin_indices)
for value, origin_index in zip(values, origin_indices, strict=False):
origin_values[origin_index] = value
values = origin_values
129 changes: 129 additions & 0 deletions tests/test_dp_imbalance_placeholder.py
@@ -0,0 +1,129 @@
"""Test for the get_batch empty-microbatch placeholder schema.

When a DP rank has fewer real partitions than the collective MAX,
``get_batch`` calls ``_fill_empty_microbatch_placeholder`` to fabricate
a 1-token placeholder sample. The placeholder must be *self-consistent*:
every list-valued schema key must end up with exactly one entry, the
key-specific values must match the declared contract (total_lengths=[1],
response_lengths=[0], ...), and the post-fill invariant assertion must
trip when a schema field lands in ``keys`` without a placeholder rule.

These tests import the real helper from ``data.py``, so a future change to
the placeholder logic fails here; they do not merely document a stale
replica of it.
"""

from __future__ import annotations

import pytest

try:
import torch
except ImportError: # pragma: no cover
pytest.skip("torch not available", allow_module_level=True)

# The helper uses torch.cuda.current_device(); on a CPU-only environment
# CUDA is unavailable. Skip the whole module in that case (these tests
# exist to exercise the CUDA code path that fires in production).
if not torch.cuda.is_available(): # pragma: no cover
pytest.skip("CUDA required for the device=cuda placeholder tensors", allow_module_level=True)

from slime.backends.megatron_utils.data import _fill_empty_microbatch_placeholder # noqa: E402


def test_placeholder_fills_known_keys():
batch = {
"tokens": [],
"total_lengths": [],
"response_lengths": [],
"max_seq_lens": [],
"loss_masks": [],
"log_probs": [],
"ref_log_probs": [],
}
keys = list(batch.keys())
tokens = _fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)

# Returned tokens list == batch["tokens"]
assert tokens is batch["tokens"]
assert len(tokens) == 1 and tokens[0].numel() == 1
assert tokens[0].dtype == torch.long

# Known-schema fields match their contract.
assert batch["total_lengths"] == [1]
assert batch["response_lengths"] == [0]
assert batch["max_seq_lens"] == [1]
assert isinstance(batch["loss_masks"][0], torch.Tensor)
assert batch["loss_masks"][0].numel() == 0
assert batch["loss_masks"][0].dtype == torch.int
# Per-token default for unknown fields: length-0 fp32.
assert batch["log_probs"][0].numel() == 0
assert batch["log_probs"][0].dtype == torch.float32
assert batch["ref_log_probs"][0].numel() == 0


def test_placeholder_per_sample_counts_all_match_tokens():
"""Core invariant: every list-valued key has exactly one entry
after the fill, aligned with the 1-token placeholder."""
batch = {
"tokens": [],
"total_lengths": [],
"response_lengths": [],
"loss_masks": [],
"log_probs": [],
"advantages": [],
}
keys = list(batch.keys())
_fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)

for k, v in batch.items():
assert len(v) == 1, f"key {k!r}: expected 1 entry, got {len(v)}"


def test_placeholder_unknown_key_gets_per_token_default():
"""A new schema field not in placeholder_for_key gets the per-token
default. This encodes the 'everything else is a per-token fp32 tensor'
assumption — if that stops being true, a future key should add
itself to placeholder_for_key."""
batch = {"tokens": [], "something_new": []}
keys = list(batch.keys())
_fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)
assert batch["something_new"][0].dtype == torch.float32
assert batch["something_new"][0].numel() == 0


def test_placeholder_invariant_fires_on_unfilled_key():
"""The post-fill invariant assertion catches the case where a key
is present in ``keys`` but its entry in ``batch`` is not a list —
the fill loop skips it, then the invariant loop catches it.

This exercises the actual assertion in ``_fill_empty_microbatch_placeholder``,
not a copy of it in the test — so future edits to the helper that
weaken or move the assertion make this test fail.
"""
# Simulate the broken-caller shape: 'total_lengths' is present but a
# non-list sentinel (typical bug: caller forgot to initialize as []).
# The helper's fill loop sees `isinstance(v, list)` False and skips
# it; the invariant loop then sees `isinstance(v, list)` False too
# and also skips it — so actually no AssertionError fires.
#
# Instead, construct a case where the key is a list with != 1 entries
# after fill. The simplest route: pass a key that isn't in batch at
# all (batch.get returns None, fill loop skips), but *also* add a
# sentinel entry to batch so that the invariant loop encounters a
# list with 0 entries.
batch = {
"tokens": [],
# 'bad_key' is in keys but batch[bad_key] is a list with 2 entries,
# so it skips the len==0 branch in the fill loop. The invariant
# then sees len == 2 and raises.
"bad_key": ["stale1", "stale2"],
}
keys = list(batch.keys())

with pytest.raises(AssertionError, match="bad_key"):
_fill_empty_microbatch_placeholder(batch, keys, pad_token_id=0)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
99 changes: 99 additions & 0 deletions tests/test_dp_imbalance_reorder.py
@@ -0,0 +1,99 @@
"""Regression test for forward_only's handling of empty partitions.

Scenario: with `use_dynamic_batch_size=True`, an all_reduce(MAX) across DP
ranks can produce `num_microbatches` larger than the local rank has real
samples for. `_get_capped_partitions` then yields trailing empty partitions
like `[[0,3], [1,2], [], []]`, and `get_batch`'s placeholder guard emits a
1-token fake sample for each empty partition. Those fakes flow through the
forward pass and into `forward_data_store`.

Before the fix, `forward_only` sized `origin_values` to `len(values)` (which
included the fake outputs), and `zip(values, origin_indices, strict=False)`
only populated the real-sample slots — so the tail of the returned list was
`None`. That caused a downstream crash in
`compute_advantages_and_returns` at
kl = [torch.zeros_like(x, ..., device=x.device) for x in xs]
with `AttributeError: 'NoneType' object has no attribute 'device'` on the
first `None` element.

Related upstream reports: THUDM/slime#1838 (the `torch.cat([])` flavor of
this same DP-imbalance problem, which the `get_batch` placeholder guard
already fixes) and THUDM/slime#1839 (oversized-sample assertion).

This test simulates the exact reorder logic in `forward_only` to pin the
fix in `origin_values = [None] * len(origin_indices)` (as opposed to
`len(values)`).
"""

from __future__ import annotations


def _reorder(values: list, micro_batch_indices: list[list[int]]) -> list:
"""Mirror the reorder in forward_only model.py, post-fix."""
origin_indices = sum(micro_batch_indices, [])
origin_values = [None] * len(origin_indices)
for value, origin_index in zip(values, origin_indices, strict=False):
origin_values[origin_index] = value
return origin_values


def test_no_empty_partitions_preserves_order():
# 4 real samples, 2 microbatches, no empty partitions.
mb_indices = [[0, 3], [1, 2]]
values = ["s0", "s3", "s1", "s2"]
out = _reorder(values, mb_indices)
assert out == ["s0", "s1", "s2", "s3"]


def test_trailing_empty_partitions_drop_fakes():
# 4 real samples split across 2 real microbatches + 2 empty ones
# (produced by _get_capped_partitions when num_mbs exceeds what the
# local rank needs). Each empty partition contributes one 1-token
# placeholder via get_batch, which becomes a "fake" value in `values`.
mb_indices = [[0, 3], [1, 2], [], []]
values = ["s0", "s3", "s1", "s2", "FAKE0", "FAKE1"]
out = _reorder(values, mb_indices)
# Only real samples at origin positions; fakes are dropped.
assert out == ["s0", "s1", "s2", "s3"]
assert None not in out


def test_all_partitions_empty_returns_empty_list():
# Degenerate rank with 0 real samples (empty micro_batch_indices list).
# Every microbatch is a fake; output is an empty list, not a list of None.
mb_indices = [[], [], []]
values = ["FAKE0", "FAKE1", "FAKE2"]
out = _reorder(values, mb_indices)
assert out == []


def test_empty_partitions_only_at_tail():
# _get_capped_partitions uses first-fit bin-packing: samples always go to
# the lowest-index partition that fits, so any unused capacity appears in
# trailing partitions — never interleaved. This documents that
# invariant; the simpler zip-based reorder relies on it.
# (If KK ever produced mid-schedule empties, len(partition) > 0 would
# assert in get_seqlen_balanced_partitions, not reach this code path.)
mb_indices = [[0, 1, 2], [3], [], []]
values = ["s0", "s1", "s2", "s3", "FAKE0", "FAKE1"]
out = _reorder(values, mb_indices)
assert out == ["s0", "s1", "s2", "s3"]
assert None not in out


def test_fix_vs_pre_fix_behavior_divergence():
"""Sanity check that the OLD reorder (sizing by len(values)) produced
trailing Nones. Documents the specific bug shape."""
mb_indices = [[0, 3], [1, 2], [], []]
values = ["s0", "s3", "s1", "s2", "FAKE0", "FAKE1"]

# Pre-fix (bug): origin_values sized by len(values) → 2 trailing Nones
pre_fix = [None] * len(values)
origin_indices = sum(mb_indices, [])
for value, origin_index in zip(values, origin_indices, strict=False):
pre_fix[origin_index] = value
assert pre_fix[-2:] == [None, None], "pre-fix list should end with Nones"

# Post-fix: origin_values sized by len(origin_indices) → no trailing Nones
post_fix = _reorder(values, mb_indices)
assert None not in post_fix