-
Notifications
You must be signed in to change notification settings - Fork 778
Add rollout sampling-mask support #1795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
93353a6
615c3ee
a956538
6b7d6c5
c999011
01a9f27
5000636
6c2a577
d74fe96
8b67138
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -229,7 +229,7 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: | |
|
|
||
| rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) | ||
|
|
||
| for key in ["rollout_log_probs", "teacher_log_probs"]: | ||
| for key in ["rollout_log_probs", "teacher_log_probs", "sampling_logprob_sum"]: | ||
| if key not in rollout_data: | ||
| continue | ||
| rollout_data[key] = [ | ||
|
|
@@ -253,6 +253,26 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: | |
| ) | ||
| ) | ||
| ] | ||
| # sampling_token_ids: variable-length nested lists, apply CP slicing but keep as lists | ||
| if "sampling_token_ids" in rollout_data: | ||
| key = "sampling_token_ids" | ||
| rollout_data[key] = [ | ||
| slice_log_prob_with_cp( | ||
| token_ids, | ||
| total_length, | ||
| response_length, | ||
| self.args.qkv_format, | ||
| rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, | ||
| ) | ||
|
||
| for i, (token_ids, total_length, response_length) in enumerate( | ||
| zip( | ||
| rollout_data[key], | ||
| rollout_data["total_lengths"], | ||
| rollout_data["response_lengths"], | ||
| strict=False, | ||
| ) | ||
| ) | ||
| ] | ||
| if "rollout_routed_experts" in rollout_data: | ||
| rollout_data["rollout_routed_experts"] = [ | ||
| torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -396,7 +396,8 @@ def log_rollout_data( | |
| - Tensor-valued lists are concatenated and averaged. For token-level metrics | ||
| like log-probs/returns/advantages/values, computes a CP-correct sample mean | ||
| using `loss_masks` and total/response lengths. | ||
| - Non-tensor lists are averaged elementwise. | ||
| - Non-tensor lists are averaged elementwise. ``sampling_token_ids`` is | ||
| summarized by average candidate count per position. | ||
| - Scalars are converted to Python numbers. | ||
| """ | ||
| if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage(): | ||
|
|
@@ -422,6 +423,8 @@ def log_rollout_data( | |
| # There are the following assumptions: | ||
| # - Each dp rank has the same number of samples | ||
| if isinstance(val, (list, tuple)): | ||
| if len(val) == 0: | ||
| continue | ||
| if isinstance(val[0], torch.Tensor): | ||
| # NOTE: Here we have to do the clone().detach(), otherwise the tensor will be | ||
| # modified in place and will cause problem for the next rollout. | ||
|
|
@@ -447,6 +450,15 @@ def log_rollout_data( | |
| else: | ||
| val = torch.cat(val).clone().detach() | ||
| val = val.mean() * cp_size | ||
|
Comment on lines
450
to
452
|
||
| elif key == "sampling_token_ids": | ||
| num_positions = sum(len(sample) for sample in val) | ||
| val = ( | ||
| sum(len(token_ids) for sample in val for token_ids in sample) / num_positions | ||
| if num_positions > 0 | ||
| else 0.0 | ||
| ) | ||
| elif isinstance(val[0], (list, tuple)): | ||
| continue | ||
| else: | ||
| val = sum(val) / len(val) | ||
| elif isinstance(val, torch.Tensor): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,12 +14,14 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| calculate_log_probs_and_entropy, | ||||||||||||||||||||||||||||||||||||||||||||||
| compute_approx_kl, | ||||||||||||||||||||||||||||||||||||||||||||||
| compute_gspo_kl, | ||||||||||||||||||||||||||||||||||||||||||||||
| compute_log_probs, | ||||||||||||||||||||||||||||||||||||||||||||||
| compute_opsm_mask, | ||||||||||||||||||||||||||||||||||||||||||||||
| compute_policy_loss, | ||||||||||||||||||||||||||||||||||||||||||||||
| get_advantages_and_returns_batch, | ||||||||||||||||||||||||||||||||||||||||||||||
| get_grpo_returns, | ||||||||||||||||||||||||||||||||||||||||||||||
| get_reinforce_plus_plus_baseline_advantages, | ||||||||||||||||||||||||||||||||||||||||||||||
| get_reinforce_plus_plus_returns, | ||||||||||||||||||||||||||||||||||||||||||||||
| mask_logits_for_token_ids, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| from slime.utils.types import RolloutBatch | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -297,6 +299,107 @@ def get_log_probs_and_entropy( | |||||||||||||||||||||||||||||||||||||||||||||
| return torch.empty((0,), device=logits.device), res | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def get_masked_log_probs_for_token_ids( | ||||||||||||||||||||||||||||||||||||||||||||||
| logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||||||||
| args: Namespace, | ||||||||||||||||||||||||||||||||||||||||||||||
| unconcat_tokens: list[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||
| total_lengths: list[int], | ||||||||||||||||||||||||||||||||||||||||||||||
| response_lengths: list[int], | ||||||||||||||||||||||||||||||||||||||||||||||
| sampling_token_ids: list[list[list[int]]], | ||||||||||||||||||||||||||||||||||||||||||||||
| max_seq_lens: list[int] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> list[torch.Tensor]: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Compute per-token log-probabilities restricted to sampling token subsets. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| For each sample, masks logits to keep only the tokens in ``sampling_token_ids`` | ||||||||||||||||||||||||||||||||||||||||||||||
| (setting others to ``-inf``), then computes ``log softmax`` over the | ||||||||||||||||||||||||||||||||||||||||||||||
| restricted set. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||
| logits: Policy logits with shape ``[1, T, V]``. | ||||||||||||||||||||||||||||||||||||||||||||||
| args: Configuration (temperature applied in ``get_responses``). | ||||||||||||||||||||||||||||||||||||||||||||||
| unconcat_tokens: Per-sample token tensors. | ||||||||||||||||||||||||||||||||||||||||||||||
| total_lengths: Total sequence lengths per sample. | ||||||||||||||||||||||||||||||||||||||||||||||
| response_lengths: Response segment lengths per sample. | ||||||||||||||||||||||||||||||||||||||||||||||
| sampling_token_ids: Per-sample, per-position list of global token IDs to | ||||||||||||||||||||||||||||||||||||||||||||||
| keep. Shape: ``[num_samples][response_length_or_cp_chunk][variable]``. | ||||||||||||||||||||||||||||||||||||||||||||||
| max_seq_lens: Optional max sequence lengths per sample (for bshd). | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||
| List of ``[R]`` tensors — masked log-probabilities per sample. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| tp_rank = mpu.get_tensor_model_parallel_rank() | ||||||||||||||||||||||||||||||||||||||||||||||
| tp_group = mpu.get_tensor_model_parallel_group() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| masked_log_probs_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||
| for i, (logits_chunk, tokens_chunk) in enumerate( | ||||||||||||||||||||||||||||||||||||||||||||||
| get_responses( | ||||||||||||||||||||||||||||||||||||||||||||||
| logits, | ||||||||||||||||||||||||||||||||||||||||||||||
| args=args, | ||||||||||||||||||||||||||||||||||||||||||||||
| unconcat_tokens=unconcat_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||
| total_lengths=total_lengths, | ||||||||||||||||||||||||||||||||||||||||||||||
| response_lengths=response_lengths, | ||||||||||||||||||||||||||||||||||||||||||||||
| max_seq_lens=max_seq_lens, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| vocab_shard_size = logits_chunk.size(-1) | ||||||||||||||||||||||||||||||||||||||||||||||
| masked_logits = mask_logits_for_token_ids( | ||||||||||||||||||||||||||||||||||||||||||||||
| logits_chunk, sampling_token_ids[i], vocab_shard_size, tp_rank, tokens=tokens_chunk | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| # Clone before compute_log_probs: fused_vocab_parallel_cross_entropy | ||||||||||||||||||||||||||||||||||||||||||||||
| # modifies its input in-place (subtract max, exp, div), which would | ||||||||||||||||||||||||||||||||||||||||||||||
| # corrupt the autograd graph of masked_logits. | ||||||||||||||||||||||||||||||||||||||||||||||
| masked_lp = compute_log_probs(masked_logits.clone(), tokens_chunk, tp_group) | ||||||||||||||||||||||||||||||||||||||||||||||
| masked_log_probs_list.append(masked_lp.squeeze(-1)) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if args.allgather_cp: | ||||||||||||||||||||||||||||||||||||||||||||||
| res = {"log_probs": masked_log_probs_list} | ||||||||||||||||||||||||||||||||||||||||||||||
| _allgather_cp_redistribute( | ||||||||||||||||||||||||||||||||||||||||||||||
| res, | ||||||||||||||||||||||||||||||||||||||||||||||
| logits=logits, | ||||||||||||||||||||||||||||||||||||||||||||||
| args=args, | ||||||||||||||||||||||||||||||||||||||||||||||
| total_lengths=total_lengths, | ||||||||||||||||||||||||||||||||||||||||||||||
| response_lengths=response_lengths, | ||||||||||||||||||||||||||||||||||||||||||||||
| max_seq_lens=max_seq_lens, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| masked_log_probs_list = res["log_probs"] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return masked_log_probs_list | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def apply_sampling_mask_to_log_probs( | ||||||||||||||||||||||||||||||||||||||||||||||
| args: Namespace, | ||||||||||||||||||||||||||||||||||||||||||||||
| batch: RolloutBatch, | ||||||||||||||||||||||||||||||||||||||||||||||
| logits: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||
| log_probs: list[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||
| old_log_probs: list[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||
| total_lengths: list[int], | ||||||||||||||||||||||||||||||||||||||||||||||
| response_lengths: list[int], | ||||||||||||||||||||||||||||||||||||||||||||||
| max_seq_lens: list[int] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Apply rollout sampling-mask normalization to train-time log-probabilities.""" | ||||||||||||||||||||||||||||||||||||||||||||||
| sampling_token_ids = batch.get("sampling_token_ids") | ||||||||||||||||||||||||||||||||||||||||||||||
| mask_logprob_sum = batch.get("sampling_logprob_sum") | ||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||
| (not getattr(args, "use_topp_mask", False) and not getattr(args, "use_topk_mask", False)) | ||||||||||||||||||||||||||||||||||||||||||||||
| or sampling_token_ids is None | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| return log_probs, old_log_probs | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| # When sampling masks are enabled, ensure that sampling_logprob_sum exists | |
| # and is aligned with old_log_probs so that zip() does not silently truncate | |
| # or iterate over a None value. | |
| if mask_logprob_sum is None: | |
| raise ValueError( | |
| "batch['sampling_logprob_sum'] must be provided when sampling masks are enabled " | |
| "and 'sampling_token_ids' is present." | |
| ) | |
| if len(mask_logprob_sum) != len(old_log_probs): | |
| raise ValueError( | |
| "batch['sampling_logprob_sum'] must have the same number of elements as " | |
| "'old_log_probs' when sampling masks are enabled: " | |
| f"{len(mask_logprob_sum)} != {len(old_log_probs)}" | |
| ) | |
| for idx, (olp, tlse) in enumerate(zip(old_log_probs, mask_logprob_sum)): | |
| if olp.shape != tlse.shape: | |
| raise ValueError( | |
| "Shape mismatch between old_log_probs and sampling_logprob_sum at index " | |
| f"{idx}: {olp.shape} != {tlse.shape}" | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normalize old log-probs over the same masked support
This subtracts sampling_logprob_sum directly, but that sum is built from output_top_logprobs candidates only; when rollout_top_logprobs_num is too small, the sampled token can be missing from that set. The new-policy side explicitly keeps sampled tokens in the mask, so the two sides can end up normalized over different supports, biasing PPO ratios/KL instead of preserving rollout-training consistency. Include sampled-token probability in the stored sum (or adjust it here) before computing masked_old_log_probs.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This unconditionally applies
slice_log_prob_with_cptosampling_token_ids, but in--allgather_cpmode the training path first consumes contiguous CP chunks and only later redistributes to zigzag layout. That means the per-position candidate lists are misaligned withget_responses(..., allgather_cp=True)output, which can produce out-of-bounds indexing or incorrect token masks whenuse_topp_mask/use_topk_maskis enabled with CP>1. Guard this branch forallgather_cp(or defer slicing until after redistribution) so mask positions match the logits chunk layout.Useful? React with 👍 / 👎.