Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:

rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"])

for key in ["rollout_log_probs", "teacher_log_probs"]:
for key in ["rollout_log_probs", "teacher_log_probs", "sampling_logprob_sum"]:
if key not in rollout_data:
continue
rollout_data[key] = [
Expand All @@ -253,6 +253,26 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
)
)
]
# sampling_token_ids: variable-length nested lists, apply CP slicing but keep as lists
if "sampling_token_ids" in rollout_data:
key = "sampling_token_ids"
rollout_data[key] = [
slice_log_prob_with_cp(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Skip zigzag CP slicing for sampling ids in allgather mode

This unconditionally applies slice_log_prob_with_cp to sampling_token_ids, but in --allgather_cp mode the training path first consumes contiguous CP chunks and only later redistributes to zigzag layout. That means the per-position candidate lists are misaligned with get_responses(..., allgather_cp=True) output, which can produce out-of-bounds indexing or incorrect token masks when use_topp_mask/use_topk_mask is enabled with CP>1. Guard this branch for allgather_cp (or defer slicing until after redistribution) so mask positions match the logits chunk layout.

Useful? React with 👍 / 👎.

token_ids,
total_length,
response_length,
self.args.qkv_format,
rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None,
)
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sampling_token_ids is CP-sliced with slice_log_prob_with_cp, which produces the zigzag CP layout. However, when --allgather-cp is enabled, get_responses() (used by sampling-mask logprob computation) operates on contiguous per-rank chunks before _allgather_cp_redistribute. This mismatch means the per-position sampling_token_ids won’t line up with the logits positions being masked. Either (a) add a dedicated slicing path for sampling_token_ids under allgather_cp that matches the contiguous layout, or (b) explicitly disallow sampling masks with allgather_cp via argument validation.

Copilot uses AI. Check for mistakes.
for i, (token_ids, total_length, response_length) in enumerate(
zip(
rollout_data[key],
rollout_data["total_lengths"],
rollout_data["response_lengths"],
strict=False,
)
)
]
if "rollout_routed_experts" in rollout_data:
rollout_data["rollout_routed_experts"] = [
torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"]
Expand Down
14 changes: 13 additions & 1 deletion slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ def log_rollout_data(
- Tensor-valued lists are concatenated and averaged. For token-level metrics
like log-probs/returns/advantages/values, computes a CP-correct sample mean
using `loss_masks` and total/response lengths.
- Non-tensor lists are averaged elementwise.
- Non-tensor lists are averaged elementwise. ``sampling_token_ids`` is
summarized by average candidate count per position.
- Scalars are converted to Python numbers.
"""
if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage():
Expand All @@ -422,6 +423,8 @@ def log_rollout_data(
# There are the following assumptions:
# - Each dp rank has the same number of samples
if isinstance(val, (list, tuple)):
if len(val) == 0:
continue
if isinstance(val[0], torch.Tensor):
# NOTE: Here we have to do the clone().detach(), otherwise the tensor will be
# modified in place and will cause problem for the next rollout.
Expand All @@ -447,6 +450,15 @@ def log_rollout_data(
else:
val = torch.cat(val).clone().detach()
val = val.mean() * cp_size
Comment on lines 450 to 452
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sampling_logprob_sum (introduced for sampling-mask support) will hit the generic tensor aggregation path here (torch.cat(val).mean() * cp_size) because it isn’t handled in the token-level metrics branch. Since it’s per-token and CP-sliced, aggregating it like a scalar can produce misleading values under CP. Consider treating it like other token-level tensors (use get_sum_of_sample_mean(...)) or explicitly skipping it in this logger.

Copilot uses AI. Check for mistakes.
elif key == "sampling_token_ids":
num_positions = sum(len(sample) for sample in val)
val = (
sum(len(token_ids) for sample in val for token_ids in sample) / num_positions
if num_positions > 0
else 0.0
)
elif isinstance(val[0], (list, tuple)):
continue
else:
val = sum(val) / len(val)
elif isinstance(val, torch.Tensor):
Expand Down
114 changes: 114 additions & 0 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
calculate_log_probs_and_entropy,
compute_approx_kl,
compute_gspo_kl,
compute_log_probs,
compute_opsm_mask,
compute_policy_loss,
get_advantages_and_returns_batch,
get_grpo_returns,
get_reinforce_plus_plus_baseline_advantages,
get_reinforce_plus_plus_returns,
mask_logits_for_token_ids,
)
from slime.utils.types import RolloutBatch

Expand Down Expand Up @@ -297,6 +299,107 @@ def get_log_probs_and_entropy(
return torch.empty((0,), device=logits.device), res


def get_masked_log_probs_for_token_ids(
    logits: torch.Tensor,
    *,
    args: Namespace,
    unconcat_tokens: list[torch.Tensor],
    total_lengths: list[int],
    response_lengths: list[int],
    sampling_token_ids: list[list[list[int]]],
    max_seq_lens: list[int] | None = None,
) -> list[torch.Tensor]:
    """Log-probabilities of sampled tokens over a restricted candidate vocabulary.

    For every sample, the logits are masked so that only the token IDs listed in
    ``sampling_token_ids`` survive (everything else becomes ``-inf``), and the
    ``log softmax`` is then taken over that reduced support.

    Args:
        logits: Policy logits of shape ``[1, T, V]``.
        args: Run configuration (temperature is applied inside ``get_responses``).
        unconcat_tokens: One token tensor per sample.
        total_lengths: Per-sample total sequence lengths.
        response_lengths: Per-sample response segment lengths.
        sampling_token_ids: Global token IDs to keep, indexed as
            ``[sample][position][candidate]``.
        max_seq_lens: Optional per-sample max sequence lengths (bshd layout).

    Returns:
        One ``[R]`` tensor of masked log-probabilities per sample.
    """
    rank_in_tp = mpu.get_tensor_model_parallel_rank()
    group_tp = mpu.get_tensor_model_parallel_group()

    response_chunks = get_responses(
        logits,
        args=args,
        unconcat_tokens=unconcat_tokens,
        total_lengths=total_lengths,
        response_lengths=response_lengths,
        max_seq_lens=max_seq_lens,
    )

    per_sample_log_probs: list[torch.Tensor] = []
    for sample_idx, (chunk_logits, chunk_tokens) in enumerate(response_chunks):
        shard_vocab = chunk_logits.size(-1)
        restricted_logits = mask_logits_for_token_ids(
            chunk_logits, sampling_token_ids[sample_idx], shard_vocab, rank_in_tp, tokens=chunk_tokens
        )
        # fused_vocab_parallel_cross_entropy mutates its input in place
        # (max-subtraction, exp, div), so hand it a clone to keep the autograd
        # graph of `restricted_logits` intact.
        masked_lp = compute_log_probs(restricted_logits.clone(), chunk_tokens, group_tp)
        per_sample_log_probs.append(masked_lp.squeeze(-1))

    if args.allgather_cp:
        # In allgather-CP mode the per-rank chunks must be redistributed to the
        # layout the rest of the loss pipeline expects.
        holder = {"log_probs": per_sample_log_probs}
        _allgather_cp_redistribute(
            holder,
            logits=logits,
            args=args,
            total_lengths=total_lengths,
            response_lengths=response_lengths,
            max_seq_lens=max_seq_lens,
        )
        per_sample_log_probs = holder["log_probs"]

    return per_sample_log_probs


def apply_sampling_mask_to_log_probs(
    args: Namespace,
    batch: RolloutBatch,
    logits: torch.Tensor,
    log_probs: list[torch.Tensor],
    old_log_probs: list[torch.Tensor],
    total_lengths: list[int],
    response_lengths: list[int],
    max_seq_lens: list[int] | None = None,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Apply rollout sampling-mask normalization to train-time log-probabilities.

    When top-p/top-k sampling masks are in use, the rollout side records the
    candidate token set per position (``sampling_token_ids``) and the log-prob
    mass of that set (``sampling_logprob_sum``). To keep the two policies
    comparable over the same support:
      - new-policy log-probs are recomputed over the candidate set only
        (via ``get_masked_log_probs_for_token_ids``);
      - rollout-time (old) log-probs are renormalized by subtracting the stored
        candidate-set log-prob mass.

    Args:
        args: Run configuration; ``use_topp_mask`` / ``use_topk_mask`` gate this.
        batch: Rollout batch; read keys: ``sampling_token_ids``,
            ``sampling_logprob_sum``, ``unconcat_tokens``.
        logits: Policy logits of shape ``[1, T, V]``.
        log_probs: Per-sample train-time log-probs (returned unchanged when
            masking is disabled).
        old_log_probs: Per-sample rollout-time log-probs.
        total_lengths: Per-sample total sequence lengths.
        response_lengths: Per-sample response segment lengths.
        max_seq_lens: Optional per-sample max sequence lengths (bshd layout).

    Returns:
        ``(log_probs, old_log_probs)`` — masked versions when sampling masks
        apply, otherwise the inputs unchanged.

    Raises:
        ValueError: Sampling masks are enabled and ``sampling_token_ids`` is
            present, but ``sampling_logprob_sum`` is missing or its sample
            count does not match ``old_log_probs``.
    """
    sampling_token_ids = batch.get("sampling_token_ids")
    mask_logprob_sum = batch.get("sampling_logprob_sum")
    masks_enabled = getattr(args, "use_topp_mask", False) or getattr(args, "use_topk_mask", False)
    if not masks_enabled or sampling_token_ids is None:
        return log_probs, old_log_probs

    # Fix: previously only sampling_token_ids was guarded, so a missing or
    # misaligned sampling_logprob_sum made the zip() below raise a confusing
    # TypeError or silently truncate samples. Fail fast with a clear message.
    if mask_logprob_sum is None:
        raise ValueError(
            "batch['sampling_logprob_sum'] must be provided when sampling masks "
            "are enabled and 'sampling_token_ids' is present."
        )
    if len(mask_logprob_sum) != len(old_log_probs):
        raise ValueError(
            "batch['sampling_logprob_sum'] and old_log_probs have mismatched "
            f"sample counts: {len(mask_logprob_sum)} != {len(old_log_probs)}"
        )

    masked_log_probs = get_masked_log_probs_for_token_ids(
        logits,
        args=args,
        unconcat_tokens=batch["unconcat_tokens"],
        total_lengths=total_lengths,
        response_lengths=response_lengths,
        sampling_token_ids=sampling_token_ids,
        max_seq_lens=max_seq_lens,
    )

    # NOTE(review): sampling_logprob_sum is presumably accumulated from the
    # rollout's top-logprob candidates only; if the sampled token can fall
    # outside that set, the old/new normalization supports differ — confirm the
    # rollout side includes the sampled token's probability in the stored sum.
    masked_old_log_probs = [
        olp - lp_sum for olp, lp_sum in zip(old_log_probs, mask_logprob_sum, strict=True)
    ]

    return masked_log_probs, masked_old_log_probs


def get_values(
logits: torch.Tensor,
*,
Expand Down Expand Up @@ -659,6 +762,17 @@ def policy_loss_function(

log_probs = log_probs_and_entropy["log_probs"]

log_probs, old_log_probs = apply_sampling_mask_to_log_probs(
args,
batch,
logits,
log_probs,
old_log_probs,
total_lengths=total_lengths,
response_lengths=response_lengths,
max_seq_lens=max_seq_lens,
)

# Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering
need_full_log_probs = args.use_opsm or args.advantage_estimator == "gspo"

Expand Down
2 changes: 2 additions & 0 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"rollout_log_probs",
"max_seq_lens",
"teacher_log_probs",
"sampling_token_ids",
"sampling_logprob_sum",
],
args.data_pad_size_multiplier,
args.qkv_format,
Expand Down
31 changes: 30 additions & 1 deletion slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,12 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
if samples[0].teacher_log_probs is not None:
train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]

if samples[0].sampling_token_ids is not None:
train_data["sampling_token_ids"] = [sample.sampling_token_ids for sample in samples]

if samples[0].sampling_logprob_sum is not None:
train_data["sampling_logprob_sum"] = [sample.sampling_logprob_sum for sample in samples]

return train_data

def set_train_parallel_config(self, config: dict):
Expand Down Expand Up @@ -782,6 +788,8 @@ def _split_train_data_by_dp(self, data, dp_size):
"rollout_routed_experts",
"prompt",
"teacher_log_probs",
"sampling_token_ids",
"sampling_logprob_sum",
]:
if key not in data:
continue
Expand Down Expand Up @@ -1192,14 +1200,14 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_
log_dict["rollout/step"] = step
logging_utils.log(args, log_dict, step_key="rollout/step")


def compute_metrics_from_samples(args, samples):
    """Aggregate per-sample rollout statistics into a flat logging dict."""
    lengths = [s.effective_response_length for s in samples]

    metrics = {}
    metrics.update(dict_add_prefix(compute_statistics(lengths), "response_len/"))
    metrics.update(_compute_zero_std_metrics(args, samples))
    metrics.update(_compute_reward_cat_metrics(args, samples))
    metrics.update(_compute_topp_mask_metrics(args, samples))
    metrics["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item()
    metrics["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item()
    return metrics
Expand Down Expand Up @@ -1284,3 +1292,24 @@ def _compute_reward_cat_metrics(args, all_samples: list[Sample]):
samples_of_reward_cat = group_by(all_samples, lambda s: s.reward[reward_cat_key])

return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()}


def _compute_topp_mask_metrics(args, all_samples: list[Sample]):
if not getattr(args, "use_topp_mask", False) and not getattr(args, "use_topk_mask", False):
return {}
if not getattr(args, "use_topp_mask", False):
return {"sampling_mask_topk_fallback_ratio": 0.0}

positions = [
(len(token_ids), np.exp(logprob_sum).item())
for sample in all_samples
if sample.sampling_token_ids is not None and sample.sampling_logprob_sum is not None
for token_ids, logprob_sum in zip(sample.sampling_token_ids, sample.sampling_logprob_sum, strict=False)
]
if not positions:
return {}
topk_fallback = sum(
candidate_size >= args.rollout_top_logprobs_num and prob + 1e-6 < args.rollout_top_p
for candidate_size, prob in positions
)
return {"sampling_mask_topk_fallback_ratio": topk_fallback / len(positions)}
Loading
Loading