diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index 6e251a092..ebacd33a7 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -5,8 +5,7 @@ if TYPE_CHECKING: from roll.distributed.scheduler.protocol import DataProto -import enum -import traceback + import heapq from typing import Dict, List, Optional, Tuple, Union @@ -225,8 +224,14 @@ def entropy_from_logits(logits: torch.Tensor): return entropy -def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, batch_num_tokens: int = None, - global_valid_samples: int = None, weights: Optional[torch.Tensor] = None): +def agg_loss( + loss_mat: torch.Tensor, + loss_mask: torch.Tensor, + loss_agg_mode: str, + batch_num_tokens: int = None, + global_valid_samples: int = None, + weights: Optional[torch.Tensor] = None, +): """ ref: https://github.com/volcengine/verl/blob/78532923368aeb058f62201489546d013df47710/verl/trainer/ppo/core_algos.py#L370 Aggregate the loss matrix into a scalar. @@ -288,6 +293,7 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> to else: return (tensor * mask).sum() / (mask.sum() + 1e-8) + def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor: if dim is not None: mask_sum = mask.sum(axis=dim) @@ -445,6 +451,7 @@ def _parse_aggregation_func(metric_name: str): return metrics + def reduce_metrics_list(metrics_list: list, reduce_func=np.mean) -> dict: if len(metrics_list) == 0: return {} @@ -574,7 +581,7 @@ def reward_norm( reward_mean = reshape_reward.mean(dim=-1, keepdim=True) elif norm_mean_type == "running": reward_mean = running.mean - elif norm_mean_type == None: + elif norm_mean_type is None: reward_mean = 0.0 # 标准差计算 if norm_std_type == "batch": @@ -589,7 +596,7 @@ def reward_norm( if norm_std_type is not None: normalized_rewards = (rewards - reward_mean) / (reward_std + 1e-6) else: - normalized_rewards = (rewards - reward_mean) + normalized_rewards = rewards - reward_mean # 如果是对 group mean 归一化,需要恢复原始形状 if norm_mean_type == "group": @@ -653,12 +660,12 @@ def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_c pipeline_config.norm_mean_type, pipeline_config.norm_std_type = "group", "group" response_level_rewards = reward_norm( - response_level_rewards, - n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences, - running_ctrl=running_ctrl, - norm_mean_type=pipeline_config.norm_mean_type, - norm_std_type=pipeline_config.norm_std_type - ) + response_level_rewards, + n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences, + running_ctrl=running_ctrl, + norm_mean_type=pipeline_config.norm_mean_type, + norm_std_type=pipeline_config.norm_std_type, + ) # 对reward进行clip if pipeline_config.reward_clip: @@ -798,7 +805,9 @@ def compute_advantage( kld = None if is_pure_opd or use_opd: kld = compute_approx_kl( - log_probs=data.batch["old_log_probs"] if getattr(pipeline_config, "enable_old_logprobs_recompute", False) else data.batch["infer_logprobs"], + log_probs=data.batch["old_log_probs"] + if getattr(pipeline_config, "enable_old_logprobs_recompute", False) + else data.batch["infer_logprobs"], log_probs_base=data.batch["ref_log_probs"], action_mask=response_mask, kl_penalty=getattr(pipeline_config, "kl_penalty", "kl"), @@ -817,7 +826,8 @@ def compute_advantage( data.batch["token_level_rewards"] = token_level_rewards if adv_estimator == "gae": values = data.batch["values"].float() - data.batch["values"] = values * response_mask + values = values * response_mask + data.batch["values"] = values advantages, returns = compute_gae_advantage_return( token_level_rewards=token_level_rewards, values=values, gamma=gamma, lambd=lambd ) @@ -848,6 +858,7 @@ def compute_advantage( data.batch["returns"] = returns return data + def postprocess_generate( prompts: "DataProto", output: torch.Tensor, @@ -991,6 +1002,7 @@ def separate_prompt_response( response_ids = torch.where(response_mask_valid, input_ids, torch.full_like(input_ids, pad_id)) return prompt_ids, response_ids + def filter_func_args(func, forward_args): signature = inspect.signature(func) forward_params = signature.parameters.keys() @@ -1119,9 +1131,9 @@ def adjust_sequence_length(sequence, target_length, origin_seq_len, pad_value=0) return sequence[tuple(slices)] -def get_seqlen_balanced_partitions(seqlen_list: List[float], - k_partitions: int, - equal_size: bool = False) -> List[List[int]]: +def get_seqlen_balanced_partitions( + seqlen_list: List[float], k_partitions: int, equal_size: bool = False +) -> List[List[int]]: """ Reference: https://github.com/volcengine/verl/blob/468adf22c43b744348051fccd7a5d830c6c3c36a/verl/utils/seqlen_balancing.py @@ -1193,16 +1205,14 @@ def __lt__(self, other): return self.spread > other.spread return self.sets[0] > other.sets[0] - assert len(seqlen_list) >= k_partitions, \ - f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" + assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" # Sort by sequence length sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)]) states_pq = [] if equal_size: - assert len(seqlen_list) % k_partitions == 0, \ - f"{len(seqlen_list)} % {k_partitions} != 0" + assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0" for offset in range(0, len(sorted_seqlen_list), k_partitions): items = [] for i in range(k_partitions): @@ -1262,7 +1272,7 @@ def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], pr # Iterate over each batch of sequence lengths for offset in range(0, len(seqlen_list), batch_size): - cur_sum_seqlen = sum(seqlen_list[offset: offset + batch_size]) + cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size]) if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen: min_sum_seqlen = cur_sum_seqlen if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen: @@ -1305,16 +1315,14 @@ def calculate_workload(seq_len_list): global_partition_lst = [[] for _ in range(world_size)] for i in range(minibatch_num): rearrange_minibatch_lst = get_seqlen_balanced_partitions( - workload_lst[i * minibatch_size: (i + 1) * minibatch_size], + workload_lst[i * minibatch_size : (i + 1) * minibatch_size], k_partitions=world_size, equal_size=True, ) for j, part in enumerate(rearrange_minibatch_lst): global_partition_lst[j].extend([x + minibatch_size * i for x in part]) else: - global_partition_lst = get_seqlen_balanced_partitions( - workload_lst, k_partitions=world_size, equal_size=True - ) + global_partition_lst = get_seqlen_balanced_partitions(workload_lst, k_partitions=world_size, equal_size=True) # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. for idx, partition in enumerate(global_partition_lst): partition.sort(key=lambda x: (workload_lst[x], x)) @@ -1329,4 +1337,3 @@ def calculate_workload(seq_len_list): metrics = {} metrics.update(global_balance_stats) return metrics - diff --git a/tests/utils/test_functionals.py b/tests/utils/test_functionals.py index 3b86c5f54..90dee96d6 100644 --- a/tests/utils/test_functionals.py +++ b/tests/utils/test_functionals.py @@ -3,8 +3,15 @@ import numpy as np import pytest import torch +from tensordict import TensorDict -from roll.utils.functionals import traverse_obj, divide_by_chunk_size, pad_to_length +from roll.distributed.scheduler.protocol import DataProto +from roll.utils.functionals import ( + compute_advantage, + divide_by_chunk_size, + pad_to_length, + traverse_obj, +) def visitor(obj: object, path: Tuple): @@ -22,6 +29,7 @@ def __init__(self): "nested_key1": torch.tensor([[1, 2], [3, 4]]), "nested_key2": [torch.tensor(5), np.array([6, 7])], } + class CustomObject: def __init__(self): self.attr1 = torch.tensor([1, 2, 3]) @@ -55,5 +63,35 @@ def test_pad_to_length(): print(padded_tensor) +def test_compute_advantage_masks_values_before_gae_bootstrap(): + response_mask = torch.tensor([[1.0, 1.0, 0.0]]) + token_level_rewards = torch.tensor([[0.0, 1.0, 0.0]]) + values = torch.tensor([[0.0, 0.0, 100.0]]) + data = DataProto( + batch=TensorDict( + { + "response_mask": response_mask.clone(), + "token_level_rewards": token_level_rewards.clone(), + "values": values.clone(), + }, + batch_size=[1], + ), + meta_info={}, + ) + + compute_advantage( + data=data, + gamma=torch.tensor(1.0), + lambd=torch.tensor(0.95), + adv_estimator="gae", + response_mask=response_mask, + ) + + expected = torch.tensor([[0.95, 1.0, 0.0]]) + torch.testing.assert_close(data.batch["values"], values * response_mask) + torch.testing.assert_close(data.batch["advantages"], expected) + torch.testing.assert_close(data.batch["returns"], expected) + + if __name__ == "__main__": pytest.main()