1,263 changes: 1,025 additions & 238 deletions docker/patch/latest/sglang.patch

Large diffs are not rendered by default.

1,263 changes: 1,025 additions & 238 deletions docker/patch/v0.5.9/sglang.patch

Large diffs are not rendered by default.

24 changes: 23 additions & 1 deletion slime/backends/megatron_utils/actor.py
@@ -229,7 +229,7 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:

rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"])

for key in ["rollout_log_probs", "teacher_log_probs"]:
for key in ["rollout_log_probs", "teacher_log_probs", "sampling_logprob_sum"]:
if key not in rollout_data:
continue
rollout_data[key] = [
@@ -253,6 +253,28 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
)
)
]
# sampling_token_ids: variable-length nested lists, apply zigzag CP slicing but keep as lists.
# For allgather_cp, skip zigzag slicing — contiguous slicing is done at training time
# inside get_masked_log_probs_for_token_ids to match the allgather logits layout.
if "sampling_token_ids" in rollout_data and not self.args.allgather_cp:
key = "sampling_token_ids"
rollout_data[key] = [
slice_log_prob_with_cp(
P1: Skip zigzag CP slicing for sampling ids in allgather mode

This unconditionally applies slice_log_prob_with_cp to sampling_token_ids, but in --allgather_cp mode the training path first consumes contiguous CP chunks and only later redistributes to zigzag layout. That means the per-position candidate lists are misaligned with get_responses(..., allgather_cp=True) output, which can produce out-of-bounds indexing or incorrect token masks when use_topp_mask/use_topk_mask is enabled with CP>1. Guard this branch for allgather_cp (or defer slicing until after redistribution) so mask positions match the logits chunk layout.

token_ids,
total_length,
response_length,
self.args.qkv_format,
rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None,
)
for i, (token_ids, total_length, response_length) in enumerate(
zip(
rollout_data[key],
rollout_data["total_lengths"],
rollout_data["response_lengths"],
strict=False,
)
)
]
if "rollout_routed_experts" in rollout_data:
rollout_data["rollout_routed_experts"] = [
torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"]
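For context on the zigzag CP slicing applied above, here is a minimal sketch of how a variable-length per-position list could be split, assuming the usual Megatron-style zigzag layout in which CP rank r keeps chunks r and 2*cp_size-1-r out of 2*cp_size equal chunks. The real slice_log_prob_with_cp also handles padding and the bshd qkv_format, so this is illustrative only; the allgather_cp branch skips it entirely because contiguous slicing happens later at training time.

# Illustrative sketch of zigzag CP slicing for a variable-length per-position list.
# Assumption: rank r keeps chunks r and 2*cp_size - 1 - r out of 2*cp_size equal chunks.
def zigzag_slice(per_position, cp_rank, cp_size):
    num_chunks = 2 * cp_size
    chunk = (len(per_position) + num_chunks - 1) // num_chunks  # ceil; real code pads first
    first = per_position[cp_rank * chunk:(cp_rank + 1) * chunk]
    mirror = num_chunks - 1 - cp_rank
    second = per_position[mirror * chunk:(mirror + 1) * chunk]
    return first + second

# 8 response positions, cp_size=2: rank 0 keeps chunks 0 and 3.
positions = [[1], [2, 3], [4], [5], [6, 7], [8], [9], [10, 11]]
print(zigzag_slice(positions, cp_rank=0, cp_size=2))  # [[1], [2, 3], [9], [10, 11]]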
18 changes: 17 additions & 1 deletion slime/backends/megatron_utils/data.py
@@ -396,7 +396,8 @@ def log_rollout_data(
- Tensor-valued lists are concatenated and averaged. For token-level metrics
like log-probs/returns/advantages/values, computes a CP-correct sample mean
using `loss_masks` and total/response lengths.
- Non-tensor lists are averaged elementwise.
- Non-tensor lists are averaged elementwise. ``sampling_token_ids`` is
summarized by average candidate count per position.
- Scalars are converted to Python numbers.
"""
if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage():
@@ -422,6 +423,8 @@ def log_rollout_data(
# There are the following assumptions:
# - Each dp rank has the same number of samples
if isinstance(val, (list, tuple)):
if len(val) == 0:
continue
if isinstance(val[0], torch.Tensor):
# NOTE: Here we have to do the clone().detach(), otherwise the tensor will be
# modified in place and will cause problem for the next rollout.
Expand All @@ -447,6 +450,19 @@ def log_rollout_data(
else:
val = torch.cat(val).clone().detach()
val = val.mean() * cp_size
Comment on lines 450 to 452, Copilot AI (Apr 2, 2026):

sampling_logprob_sum (introduced for sampling-mask support) will hit the generic tensor aggregation path here (torch.cat(val).mean() * cp_size) because it isn’t handled in the token-level metrics branch. Since it’s per-token and CP-sliced, aggregating it like a scalar can produce misleading values under CP. Consider treating it like other token-level tensors (use get_sum_of_sample_mean(...)) or explicitly skipping it in this logger.
elif key == "sampling_token_ids":
num_positions = sum(len(sample) for sample in val)
val = (
sum(len(token_ids) for sample in val for token_ids in sample) / num_positions
if num_positions > 0
else 0.0
)
elif key == "sampling_logprob_sum":
# Per-token tensor; skip generic aggregation — already
# summarized indirectly via sampling_token_ids metrics.
continue
elif isinstance(val[0], (list, tuple)):
continue
else:
val = sum(val) / len(val)
elif isinstance(val, torch.Tensor):
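The docstring note above ("summarized by average candidate count per position") corresponds to the reduction in the sampling_token_ids branch. A small standalone sketch of that statistic (function name hypothetical), matching the nested [num_samples][positions][variable] layout:

# Sketch of the summary statistic log_rollout_data computes for sampling_token_ids:
# the mean number of candidate token IDs per response position across all samples.
def avg_candidates_per_position(sampling_token_ids):
    num_positions = sum(len(sample) for sample in sampling_token_ids)
    if num_positions == 0:
        return 0.0
    return sum(len(ids) for sample in sampling_token_ids for ids in sample) / num_positions

# Two samples with three response positions each, e.g. top-p kept 3/1/2 and 2/2/4 candidates.
batch = [[[5, 9, 11], [3], [7, 8]], [[1, 2], [4, 6], [0, 5, 9, 12]]]
print(avg_candidates_per_position(batch))  # 14 / 6 ≈ 2.33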
149 changes: 149 additions & 0 deletions slime/backends/megatron_utils/loss.py
@@ -14,12 +14,14 @@
calculate_log_probs_and_entropy,
compute_approx_kl,
compute_gspo_kl,
compute_log_probs,
compute_opsm_mask,
compute_policy_loss,
get_advantages_and_returns_batch,
get_grpo_returns,
get_reinforce_plus_plus_baseline_advantages,
get_reinforce_plus_plus_returns,
mask_logits_for_token_ids,
)
from slime.utils.types import RolloutBatch

@@ -297,6 +299,141 @@ def get_log_probs_and_entropy(
return torch.empty((0,), device=logits.device), res


def get_masked_log_probs_for_token_ids(
logits: torch.Tensor,
*,
args: Namespace,
unconcat_tokens: list[torch.Tensor],
total_lengths: list[int],
response_lengths: list[int],
sampling_token_ids: list[list[list[int]]],
max_seq_lens: list[int] | None = None,
) -> list[torch.Tensor]:
"""Compute per-token log-probabilities restricted to sampling token subsets.

For each sample, masks logits to keep only the tokens in ``sampling_token_ids``
(setting others to ``-inf``), then computes ``log softmax`` over the
restricted set.

Args:
logits: Policy logits with shape ``[1, T, V]``.
args: Configuration (temperature applied in ``get_responses``).
unconcat_tokens: Per-sample token tensors.
total_lengths: Total sequence lengths per sample.
response_lengths: Response segment lengths per sample.
sampling_token_ids: Per-sample, per-position list of global token IDs to
keep. Shape: ``[num_samples][response_length_or_cp_chunk][variable]``.
max_seq_lens: Optional max sequence lengths per sample (for bshd).

Returns:
List of ``[R]`` tensors — masked log-probabilities per sample.
"""
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_group = mpu.get_tensor_model_parallel_group()

# For allgather_cp, sampling_token_ids is NOT pre-sliced (full response length).
# Compute contiguous CP offsets to extract positions matching each logits chunk.
_allgather_cp = getattr(args, "allgather_cp", False) and mpu.get_context_parallel_world_size() > 1
if _allgather_cp:
_logits_local_len = logits.view(-1, logits.size(-1)).size(0)
_cp_rank = mpu.get_context_parallel_rank()
_chunk_start = _cp_rank * _logits_local_len
_chunk_end = _chunk_start + _logits_local_len

masked_log_probs_list = []
_seq_start = 0
for i, (logits_chunk, tokens_chunk) in enumerate(
get_responses(
logits,
args=args,
unconcat_tokens=unconcat_tokens,
total_lengths=total_lengths,
response_lengths=response_lengths,
max_seq_lens=max_seq_lens,
)
):
# Determine per-position token IDs for masking.
if _allgather_cp:
prompt_length = total_lengths[i] - response_lengths[i]
logit_global_start = _seq_start + prompt_length - 1
logit_global_end = _seq_start + total_lengths[i] - 1
s = max(logit_global_start, _chunk_start)
e = min(logit_global_end, _chunk_end)
if e <= s:
per_pos_ids = []
else:
resp_start = s - logit_global_start
resp_end = e - logit_global_start
per_pos_ids = sampling_token_ids[i][resp_start:resp_end]
else:
per_pos_ids = sampling_token_ids[i]
_seq_start += total_lengths[i]

vocab_shard_size = logits_chunk.size(-1)
masked_logits = mask_logits_for_token_ids(logits_chunk, per_pos_ids, vocab_shard_size, tp_rank)
# Clone before compute_log_probs: fused_vocab_parallel_cross_entropy
# modifies its input in-place (subtract max, exp, div), which would
# corrupt the autograd graph of masked_logits.
masked_lp = compute_log_probs(masked_logits.clone(), tokens_chunk, tp_group)
masked_log_probs_list.append(masked_lp.squeeze(-1))

if args.allgather_cp:
res = {"log_probs": masked_log_probs_list}
_allgather_cp_redistribute(
res,
logits=logits,
args=args,
total_lengths=total_lengths,
response_lengths=response_lengths,
max_seq_lens=max_seq_lens,
)
masked_log_probs_list = res["log_probs"]

return masked_log_probs_list


def apply_sampling_mask_to_log_probs(
args: Namespace,
batch: RolloutBatch,
logits: torch.Tensor,
log_probs: list[torch.Tensor],
old_log_probs: list[torch.Tensor],
total_lengths: list[int],
response_lengths: list[int],
max_seq_lens: list[int] | None = None,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Apply rollout sampling-mask normalization to train-time log-probabilities."""
sampling_token_ids = batch.get("sampling_token_ids")
mask_logprob_sum = batch.get("sampling_logprob_sum")
if (
not getattr(args, "use_topp_mask", False) and not getattr(args, "use_topk_mask", False)
) or sampling_token_ids is None:
return log_probs, old_log_probs

Copilot AI (Apr 2, 2026):

apply_sampling_mask_to_log_probs reads sampling_logprob_sum but only guards on sampling_token_ids being present. If sampling_logprob_sum is missing (or not aligned), the later zip(old_log_probs, mask_logprob_sum) will throw or silently truncate. Add validation that sampling_logprob_sum exists and matches the per-sample/per-token structure whenever sampling masks are enabled.

Suggested change
# When sampling masks are enabled, ensure that sampling_logprob_sum exists
# and is aligned with old_log_probs so that zip() does not silently truncate
# or iterate over a None value.
if mask_logprob_sum is None:
raise ValueError(
"batch['sampling_logprob_sum'] must be provided when sampling masks are enabled "
"and 'sampling_token_ids' is present."
)
if len(mask_logprob_sum) != len(old_log_probs):
raise ValueError(
"batch['sampling_logprob_sum'] must have the same number of elements as "
"'old_log_probs' when sampling masks are enabled: "
f"{len(mask_logprob_sum)} != {len(old_log_probs)}"
)
for idx, (olp, tlse) in enumerate(zip(old_log_probs, mask_logprob_sum)):
if olp.shape != tlse.shape:
raise ValueError(
"Shape mismatch between old_log_probs and sampling_logprob_sum at index "
f"{idx}: {olp.shape} != {tlse.shape}"
)

if mask_logprob_sum is None:
raise ValueError(
"batch['sampling_logprob_sum'] must be provided when sampling masks are enabled "
"and 'sampling_token_ids' is present."
)
if len(mask_logprob_sum) != len(old_log_probs):
raise ValueError(
f"sampling_logprob_sum has {len(mask_logprob_sum)} samples but " f"old_log_probs has {len(old_log_probs)}"
)

masked_log_probs = get_masked_log_probs_for_token_ids(
logits,
args=args,
unconcat_tokens=batch["unconcat_tokens"],
total_lengths=total_lengths,
response_lengths=response_lengths,
sampling_token_ids=sampling_token_ids,
max_seq_lens=max_seq_lens,
)

masked_old_log_probs = [olp - tlse for olp, tlse in zip(old_log_probs, mask_logprob_sum, strict=True)]
return masked_log_probs, masked_old_log_probs


def get_values(
logits: torch.Tensor,
*,
@@ -659,6 +796,18 @@ def policy_loss_function(

log_probs = log_probs_and_entropy["log_probs"]

if getattr(args, "use_topp_mask", False) or getattr(args, "use_topk_mask", False):
log_probs, old_log_probs = apply_sampling_mask_to_log_probs(
args,
batch,
logits,
log_probs,
old_log_probs,
total_lengths=total_lengths,
response_lengths=response_lengths,
max_seq_lens=max_seq_lens,
)

# Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering
need_full_log_probs = args.use_opsm or args.advantage_estimator == "gspo"

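To make the sampling-mask correction in this file concrete, here is a minimal single-position sketch in plain PyTorch (no TP/CP, hypothetical helper name masked_log_prob), under the assumption that sampling_logprob_sum stores the log of the probability mass kept by the top-p/top-k mask at each position. It shows why masking non-candidate logits to -inf on the training side and subtracting sampling_logprob_sum from the rollout log-probs put both sides on the same truncated distribution; the real path goes through mask_logits_for_token_ids and compute_log_probs over TP vocab shards.

import torch
import torch.nn.functional as F

def masked_log_prob(logits_row, candidate_ids, target_id):
    """Log-prob of target_id under the softmax renormalized over candidate_ids only."""
    masked = torch.full_like(logits_row, float("-inf"))
    masked[candidate_ids] = logits_row[candidate_ids]
    return F.log_softmax(masked, dim=-1)[target_id]

# One response position over a toy 6-token vocabulary; the rollout kept candidates {1, 3, 4}.
logits = torch.tensor([0.2, 1.5, -0.3, 0.9, 0.1, -1.0])
candidates = torch.tensor([1, 3, 4])
target = 3

train_lp = masked_log_prob(logits, candidates, target)

# Rollout side: full-vocab log-prob minus log of the kept probability mass (assumed here
# to be what sampling_logprob_sum holds per position) gives the same truncated-distribution
# log-prob, i.e. masked_old_log_probs = old_log_probs - sampling_logprob_sum.
full_lp = F.log_softmax(logits, dim=-1)[target]
kept_mass = F.softmax(logits, dim=-1)[candidates].sum()
rollout_lp = full_lp - torch.log(kept_mass)

print(torch.allclose(train_lp, rollout_lp))  # True when train and rollout logits agree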
2 changes: 2 additions & 0 deletions slime/backends/megatron_utils/model.py
@@ -372,6 +372,8 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
"rollout_log_probs",
"max_seq_lens",
"teacher_log_probs",
"sampling_token_ids",
"sampling_logprob_sum",
],
args.data_pad_size_multiplier,
args.qkv_format,
9 changes: 9 additions & 0 deletions slime/ray/rollout.py
@@ -743,6 +743,12 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
if samples[0].teacher_log_probs is not None:
train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]

if samples[0].sampling_token_ids is not None:
train_data["sampling_token_ids"] = [sample.sampling_token_ids for sample in samples]

if samples[0].sampling_logprob_sum is not None:
train_data["sampling_logprob_sum"] = [sample.sampling_logprob_sum for sample in samples]

return train_data

def set_train_parallel_config(self, config: dict):
@@ -782,6 +788,8 @@ def _split_train_data_by_dp(self, data, dp_size):
"rollout_routed_experts",
"prompt",
"teacher_log_probs",
"sampling_token_ids",
"sampling_logprob_sum",
]:
if key not in data:
continue
@@ -1200,6 +1208,7 @@ def compute_metrics_from_samples(args, samples):
log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/")
log_dict |= _compute_zero_std_metrics(args, samples)
log_dict |= _compute_reward_cat_metrics(args, samples)

log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item()
log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item()
return log_dict