diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index 8001ceb73d..f21b415e99 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -1,5 +1,5 @@ diff --git a/.codespellrc b/.codespellrc -index 808a344b4..a34624958 100644 +index 808a344..a346249 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ @@ -8,7 +8,7 @@ index 808a344b4..a34624958 100644 +ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, medias skip = *.json,*.jsonl,*.patch,*.txt diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index 6fbd1db82..4c681b58d 100644 +index 6fbd1db..4c681b5 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -274,6 +274,7 @@ class ModelConfig: @@ -34,7 +34,7 @@ index 6fbd1db82..4c681b58d 100644 elif not needs_tf_v5: logger.warning( diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py -index da4629e52..c03f98231 100644 +index da4629e..c03f982 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -17,6 +17,7 @@ class KVArgs: @@ -46,7 +46,7 @@ index da4629e52..c03f98231 100644 aux_data_lens: List[int] aux_item_lens: List[int] diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py -index 67fe82ad6..2ef25c49b 100644 +index 67fe82a..2ef25c4 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -24,6 +24,7 @@ from sglang.srt.disaggregation.base.conn import ( @@ -126,7 +126,7 @@ index 67fe82ad6..2ef25c49b 100644 "prefill_pp_size": self.pp_size, "prefill_page_size": self.page_size, diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index 1d8baf002..1ebb95929 100644 +index 1d8baf0..1ebb959 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server @@ -312,7 +312,7 @@ index 1d8baf002..1ebb95929 100644 if not hasattr(self, "polling_count"): diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py -index a2d08e0e3..ed0790604 100644 +index a2d08e0..ed07906 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -117,7 +117,7 @@ def _convert(data): @@ -353,7 +353,7 @@ index a2d08e0e3..ed0790604 100644 mm_item = MultimodalDataItem.from_dict( { diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index d0d4efd95..b3a207063 100644 +index d0d4efd..b3a2070 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -30,7 +30,7 @@ from sglang.srt.disaggregation.common.utils import ( @@ -541,7 +541,7 @@ index d0d4efd95..b3a207063 100644 def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index fbc801635..ade111c9f 100644 +index fbc8016..ade111c 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -20,6 +20,7 @@ Life cycle of a request in the prefill server @@ -715,7 +715,7 @@ index 
fbc801635..ade111c9f 100644 transferred_rids: List[str] = [] diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py -index 6d58f415a..84723c342 100644 +index 6d58f41..84723c3 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -21,6 +21,17 @@ if TYPE_CHECKING: @@ -907,7 +907,7 @@ index 6d58f415a..84723c342 100644 ######################### diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index 8f1069c00..e47589295 100644 +index 8f1069c..e475892 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1999,7 +1999,10 @@ def get_tensor_model_parallel_world_size(): @@ -923,7 +923,7 @@ index 8f1069c00..e47589295 100644 # ATTN_TP diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index 0ed5a1b44..67e33c650 100644 +index 0ed5a1b..67e33c6 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import ( @@ -960,7 +960,7 @@ index 0ed5a1b44..67e33c650 100644 """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 1d6816c01..402b42e05 100644 +index 1d6816c..402b42e 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -115,6 +115,7 @@ from sglang.srt.managers.io_struct import ( @@ -1032,7 +1032,7 @@ index 1d6816c01..402b42e05 100644 @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py -index 8293796a2..bff34e422 100644 +index 8293796..bff34e4 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -244,6 +244,7 @@ class Envs: @@ -1044,7 +1044,7 @@ index 8293796a2..bff34e422 100644 # Scheduler: others: SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1) # in seconds. Set if you observe high memory accumulation over a long serving period. 
diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py -index 1cdf65b91..4783cd18f 100644 +index 1cdf65b..4783cd1 100644 --- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py +++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py @@ -630,7 +630,6 @@ def _get_k_and_s_triton( @@ -1064,7 +1064,7 @@ index 1cdf65b91..4783cd18f 100644 buf_numel_per_page: tl.constexpr, index_head_dim: tl.constexpr, diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py -index ca54a931b..3540f77ba 100644 +index ca54a93..3540f77 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -1,6 +1,7 @@ @@ -1149,7 +1149,7 @@ index ca54a931b..3540f77ba 100644 if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py -index 00ef96f9b..c2c2c78fe 100644 +index 00ef96f..c2c2c78 100644 --- a/python/sglang/srt/layers/attention/nsa/utils.py +++ b/python/sglang/srt/layers/attention/nsa/utils.py @@ -91,20 +91,29 @@ def nsa_cp_round_robin_split_data(input_: Union[torch.Tensor, List]): @@ -1215,7 +1215,7 @@ index 00ef96f9b..c2c2c78fe 100644 position_id_list = list( diff --git a/python/sglang/srt/layers/communicator_nsa_cp.py b/python/sglang/srt/layers/communicator_nsa_cp.py -index 296d14568..f4606a769 100644 +index 296d145..f4606a7 100644 --- a/python/sglang/srt/layers/communicator_nsa_cp.py +++ b/python/sglang/srt/layers/communicator_nsa_cp.py @@ -34,7 +34,6 @@ from sglang.srt.layers.communicator import ( @@ -1254,7 +1254,7 @@ index 296d14568..f4606a769 100644 attn_cp_all_gather_into_tensor( hidden_states, diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py -index 5bf5aa0c8..e52f39fd8 100644 +index 5bf5aa0..e52f39f 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -90,11 +90,11 @@ class _DpGatheredBufferWrapper: @@ -1275,10 +1275,122 @@ index 5bf5aa0c8..e52f39fd8 100644 @classmethod def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device): diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py -index aff05bf42..130359232 100644 +index aff05bf..68b67c7 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py -@@ -872,11 +872,6 @@ class LogitsProcessor(nn.Module): +@@ -47,6 +47,7 @@ from sglang.srt.layers.utils.logprob import ( + get_token_ids_logprobs_prefill, + get_top_logprobs_chunk, + get_top_logprobs_prefill, ++ get_top_p_logprobs_prefill, + ) + from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding + from sglang.srt.model_executor.forward_batch_info import ( +@@ -78,6 +79,9 @@ class LogitsProcessorOutput: + # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k] + next_token_top_logprobs_val: Optional[List] = None + next_token_top_logprobs_idx: Optional[List] = None ++ # The logprobs and ids of the top-p tokens in output positions (variable-length per position) ++ next_token_top_p_logprobs_val: Optional[List] = None ++ next_token_top_p_logprobs_idx: Optional[List] = None + # The logprobs and ids of the requested token ids in output positions. 
shape: [#seq, n] (n is the number of requested token ids) + # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests) + next_token_token_ids_logprobs_val: Optional[ +@@ -91,6 +95,9 @@ class LogitsProcessorOutput: + # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] + input_top_logprobs_val: Optional[List] = None + input_top_logprobs_idx: Optional[List] = None ++ # The logprobs and ids of the top-p tokens in input positions (variable-length per position) ++ input_top_p_logprobs_val: Optional[List] = None ++ input_top_p_logprobs_idx: Optional[List] = None + # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids) + # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization) + input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = ( +@@ -115,12 +122,14 @@ class LogitsMetadata: + + extend_return_logprob: bool = False + extend_return_top_logprob: bool = False ++ extend_return_top_p_logprob: bool = False + extend_token_ids_logprob: bool = False + extend_seq_lens: Optional[torch.Tensor] = None + extend_seq_lens_cpu: Optional[List[int]] = None + extend_logprob_start_lens_cpu: Optional[List[int]] = None + extend_logprob_pruned_lens_cpu: Optional[List[int]] = None + top_logprobs_nums: Optional[List[int]] = None ++ top_logprobs_ps: Optional[List[float]] = None + extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None + token_ids_logprobs: Optional[List[List[int]]] = None + +@@ -160,6 +169,10 @@ class LogitsMetadata: + extend_return_top_logprob = any( + x > 0 for x in forward_batch.top_logprobs_nums + ) ++ extend_return_top_p_logprob = ( ++ forward_batch.top_logprobs_ps is not None ++ and any(x > 0.0 for x in forward_batch.top_logprobs_ps) ++ ) + extend_token_ids_logprob = any( + x is not None for x in forward_batch.token_ids_logprobs + ) +@@ -174,6 +187,8 @@ class LogitsMetadata: + extend_logprob_pruned_lens_cpu.append(extend_len - start_len) + else: + extend_return_logprob = extend_return_top_logprob = ( ++ extend_return_top_p_logprob ++ ) = ( + extend_token_ids_logprob + ) = extend_logprob_pruned_lens_cpu = False + +@@ -183,12 +198,14 @@ class LogitsMetadata: + next_token_logits_buffer=forward_batch.next_token_logits_buffer, + extend_return_logprob=extend_return_logprob, + extend_return_top_logprob=extend_return_top_logprob, ++ extend_return_top_p_logprob=extend_return_top_p_logprob, + extend_token_ids_logprob=extend_token_ids_logprob, + extend_seq_lens=forward_batch.extend_seq_lens, + extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, + extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, + top_logprobs_nums=forward_batch.top_logprobs_nums, ++ top_logprobs_ps=forward_batch.top_logprobs_ps, + token_ids_logprobs=forward_batch.token_ids_logprobs, + extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu, + padded_static_len=forward_batch.padded_static_len, +@@ -391,6 +408,8 @@ class LogitsProcessor(nn.Module): + input_token_logprobs=logprobs_result.input_token_logprobs, + input_top_logprobs_val=logprobs_result.input_top_logprobs_val, + input_top_logprobs_idx=logprobs_result.input_top_logprobs_idx, ++ input_top_p_logprobs_val=logprobs_result.input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=logprobs_result.input_top_p_logprobs_idx, + 
input_token_ids_logprobs_val=logprobs_result.input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=logprobs_result.input_token_ids_logprobs_idx, + mm_input_embeds=logits_metadata.mm_input_embeds, +@@ -619,6 +638,15 @@ class LogitsProcessor(nn.Module): + else: + input_top_logprobs_val = input_top_logprobs_idx = None + ++ # Get the logprob of top-p tokens ++ if logits_metadata.extend_return_top_p_logprob: ++ ( ++ input_top_p_logprobs_val, ++ input_top_p_logprobs_idx, ++ ) = get_top_p_logprobs_prefill(input_logprobs, logits_metadata) ++ else: ++ input_top_p_logprobs_val = input_top_p_logprobs_idx = None ++ + # Get the logprob of given token id + if logits_metadata.extend_token_ids_logprob: + ( +@@ -637,6 +665,8 @@ class LogitsProcessor(nn.Module): + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, ++ input_top_p_logprobs_val=input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=input_top_p_logprobs_idx, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, + ) +@@ -872,11 +902,6 @@ class LogitsProcessor(nn.Module): None, # bias True, # is_vnni ) @@ -1290,160 +1402,17 @@ index aff05bf42..130359232 100644 else: logits = torch.matmul( hidden_states.to(lm_head.weight.dtype), lm_head.weight.T -diff --git a/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py -new file mode 100644 -index 000000000..8d3d0f92e ---- /dev/null -+++ b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py -@@ -0,0 +1,146 @@ -+"""Fused Triton kernels for DeepEP BF16 low-latency MoE decode. -+ -+Replaces the naive activation + masking pipeline (5+ CUDA kernels for silu+mul -+and arange+comparison+masked_fill+copy) with a single Triton elementwise kernel, -+while keeping cuBLAS batched GEMM for the matrix multiplies. -+ -+Pipeline: bmm → fused_act_mul_masked (in-place) → bmm(out=hidden) -+ (3 ops total: 2 cuBLAS + 1 Triton, vs original 7-8 separate CUDA kernels) -+""" -+ -+import torch -+import triton -+import triton.language as tl -+ -+ -+@triton.jit -+def _silu_mul_masked_kernel( -+ gate_up_ptr, -+ masked_m_ptr, -+ M, -+ N, -+ stride_ge, -+ stride_gm, -+ stride_gn, -+ BLOCK: tl.constexpr, -+): -+ """Fused SiLU(gate) * up with per-expert masking, written in-place. -+ -+ gate_up: [E, M, 2*N] — first N cols are gate, last N cols are up. -+ Writes SiLU(gate)*up to gate_up[:,:,:N] in-place. -+ Rows m >= masked_m[e] are zeroed. 
-+ """ -+ expert_id = tl.program_id(1) -+ pid = tl.program_id(0) -+ -+ expert_valid_m = tl.load(masked_m_ptr + expert_id) -+ -+ offs = pid * BLOCK + tl.arange(0, BLOCK) -+ total = M * N -+ mask = offs < total -+ -+ m = offs // N -+ n = offs % N -+ -+ gate_base = gate_up_ptr + expert_id * stride_ge -+ -+ gate_val = tl.load(gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0) -+ up_val = tl.load( -+ gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0 -+ ) -+ -+ gate_f32 = gate_val.to(tl.float32) -+ result = (gate_f32 * tl.sigmoid(gate_f32)) * up_val.to(tl.float32) -+ -+ # Zero invalid rows -+ valid = m < expert_valid_m -+ result = tl.where(valid, result, 0.0) -+ -+ tl.store( -+ gate_base + m * stride_gm + n * stride_gn, -+ result.to(gate_up_ptr.dtype.element_ty), -+ mask=mask, -+ ) -+ -+ -+@triton.jit -+def _gelu_mul_masked_kernel( -+ gate_up_ptr, -+ masked_m_ptr, -+ M, -+ N, -+ stride_ge, -+ stride_gm, -+ stride_gn, -+ BLOCK: tl.constexpr, -+): -+ """Fused GELU(gate) * up with per-expert masking, written in-place.""" -+ expert_id = tl.program_id(1) -+ pid = tl.program_id(0) -+ -+ expert_valid_m = tl.load(masked_m_ptr + expert_id) -+ -+ offs = pid * BLOCK + tl.arange(0, BLOCK) -+ total = M * N -+ mask = offs < total -+ -+ m = offs // N -+ n = offs % N -+ -+ gate_base = gate_up_ptr + expert_id * stride_ge -+ -+ gate_val = tl.load(gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0) -+ up_val = tl.load( -+ gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0 -+ ) -+ -+ g = gate_val.to(tl.float32) -+ kAlpha = 0.7978845608028654 -+ gate_act = 0.5 * g * (1.0 + tl.math.tanh(kAlpha * (g + 0.044715 * g * g * g))) -+ result = gate_act * up_val.to(tl.float32) -+ -+ valid = m < expert_valid_m -+ result = tl.where(valid, result, 0.0) -+ -+ tl.store( -+ gate_base + m * stride_gm + n * stride_gn, -+ result.to(gate_up_ptr.dtype.element_ty), -+ mask=mask, -+ ) -+ -+ -+def fused_act_mul_masked_inplace( -+ gate_up: torch.Tensor, -+ intermediate_size: int, -+ masked_m: torch.Tensor, -+ use_gelu: bool = False, -+) -> None: -+ """Fused activation + multiply + masking, written in-place to gate_up[:,:,:I]. -+ -+ After this call, gate_up[:, :, :intermediate_size] contains the masked -+ activated intermediate, suitable for the down projection GEMM. 
-+ -+ Args: -+ gate_up: [E, M, 2*I] output of bmm(tokens, w13.T), modified in-place -+ intermediate_size: I -+ masked_m: [E] per-expert valid token count -+ use_gelu: use GELU instead of SiLU -+ """ -+ E, M, _ = gate_up.shape -+ N = intermediate_size -+ -+ total = M * N -+ BLOCK = 1024 -+ grid = (triton.cdiv(total, BLOCK), E) -+ -+ kernel = _gelu_mul_masked_kernel if use_gelu else _silu_mul_masked_kernel -+ kernel[grid]( -+ gate_up, -+ masked_m, -+ M, -+ N, -+ gate_up.stride(0), -+ gate_up.stride(1), -+ gate_up.stride(2), -+ BLOCK=BLOCK, -+ ) +@@ -1073,6 +1098,8 @@ class LogitsProcessor(nn.Module): + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, ++ input_top_p_logprobs_val=input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=input_top_p_logprobs_idx, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, + # FIXME: These fields are not logits-related but are passed through here as a diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py -index ebcc696ec..3b527021a 100644 +index ebcc696..3b52702 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -132,11 +132,12 @@ class DeepEPMoE(FusedMoE): @@ -1563,7 +1532,7 @@ index ebcc696ec..3b527021a 100644 self, dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -index de8a07ab3..952f8a67b 100644 +index de8a07a..952f8a6 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -697,6 +697,7 @@ class FusedMoE(torch.nn.Module): @@ -1607,7 +1576,7 @@ index de8a07ab3..952f8a67b 100644 ) diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py -index 00bd68755..12d5577af 100644 +index 00bd687..12d5577 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -8,10 +8,15 @@ import torch @@ -1668,7 +1637,7 @@ index 00bd68755..12d5577af 100644 def get_routed_experts( diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -index 8539639d5..d44496c2f 100644 +index 8539639..d44496c 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -388,6 +388,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): @@ -1742,7 +1711,7 @@ index 8539639d5..d44496c2f 100644 buffer = self._get_buffer() diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index 4cbfed6f9..88b452744 100644 +index 4cbfed6..88b4527 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -499,7 +499,7 @@ class CompressedTensorsConfig(QuantizationConfig): @@ -1765,7 +1734,7 @@ index 4cbfed6f9..88b452744 100644 self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py 
b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py -index 6264f36d0..f0310e305 100644 +index 6264f36..f0310e3 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -17,7 +17,10 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( @@ -1884,7 +1853,7 @@ index 6264f36d0..f0310e305 100644 is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py -index ae0614635..3b6a8d254 100644 +index ae06146..3b6a8d2 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -305,9 +305,6 @@ class RotaryEmbedding(MultiPlatformOp): @@ -1907,11 +1876,206 @@ index ae0614635..3b6a8d254 100644 # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 assert ( fused_set_kv_buffer_arg is None +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index f78d83d..706b0eb 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -12,7 +12,11 @@ from sglang.srt.layers.dp_attention import ( + ) + from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.layers.utils.hash import murmur_hash32 +-from sglang.srt.layers.utils.logprob import get_token_ids_logprobs, get_top_logprobs ++from sglang.srt.layers.utils.logprob import ( ++ get_token_ids_logprobs, ++ get_top_logprobs, ++ get_top_p_logprobs, ++) + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + from sglang.srt.sampling.sampling_params import TOP_K_ALL + from sglang.srt.server_args import get_global_server_args +@@ -79,6 +83,7 @@ class Sampler(nn.Module): + sampling_info: SamplingBatchInfo, + return_logprob: bool, + top_logprobs_nums: List[int], ++ top_logprobs_ps: Optional[List[float]], + token_ids_logprobs: List[List[int]], + positions: torch.Tensor, + ): +@@ -176,6 +181,7 @@ class Sampler(nn.Module): + logits_output, + logprobs, + top_logprobs_nums, ++ top_logprobs_ps, + token_ids_logprobs, + sampling_info, + batch_next_token_ids, +@@ -314,6 +320,7 @@ class Sampler(nn.Module): + logits_output: LogitsProcessorOutput, + logprobs: torch.Tensor, + top_logprobs_nums: List[int], ++ top_logprobs_ps: Optional[List[float]], + token_ids_logprobs: List[List[int]], + sampling_info: SamplingBatchInfo, + batch_next_token_ids: torch.Tensor, +@@ -328,6 +335,13 @@ class Sampler(nn.Module): + logits_output.next_token_top_logprobs_idx, + ) = get_top_logprobs(logprobs, top_logprobs_nums) + ++ # Attach top-p logprobs ++ if top_logprobs_ps is not None and any(x > 0.0 for x in top_logprobs_ps): ++ ( ++ logits_output.next_token_top_p_logprobs_val, ++ logits_output.next_token_top_p_logprobs_idx, ++ ) = get_top_p_logprobs(logprobs, top_logprobs_ps) ++ + if any(x is not None for x in token_ids_logprobs): + ( + logits_output.next_token_token_ids_logprobs_val, +@@ -362,6 +376,7 @@ class Sampler(nn.Module): + sampling_info: SamplingBatchInfo, + return_logprob: bool, + top_logprobs_nums: List[int], ++ top_logprobs_ps: Optional[List[float]], + token_ids_logprobs: List[List[int]], + ) -> None: + """ +diff --git a/python/sglang/srt/layers/utils/logprob.py b/python/sglang/srt/layers/utils/logprob.py +index 6f84c15..3e526d3 100644 +--- 
a/python/sglang/srt/layers/utils/logprob.py ++++ b/python/sglang/srt/layers/utils/logprob.py +@@ -25,6 +25,8 @@ class InputLogprobsResult: + input_token_logprobs: torch.Tensor + input_top_logprobs_val: Optional[List] = None + input_top_logprobs_idx: Optional[List] = None ++ input_top_p_logprobs_val: Optional[List] = None ++ input_top_p_logprobs_idx: Optional[List] = None + input_token_ids_logprobs_val: Optional[List] = None + input_token_ids_logprobs_idx: Optional[List] = None + +@@ -96,6 +98,75 @@ def get_top_logprobs_raw( + return top_logprobs_val, top_logprobs_idx + + ++ ++def get_top_p_logprobs_raw( ++ logprobs: torch.Tensor, ++ top_logprobs_ps: List[float], ++ stage: LogprobStage, ++ extend_logprob_pruned_lens_cpu: Optional[List[int]] = None, ++ no_copy_to_cpu: bool = False, ++): ++ """Get top-p logprobs: return tokens whose cumulative probability >= top_p threshold.""" ++ sorted_logprobs, sorted_indices = logprobs.sort(dim=-1, descending=True) ++ sorted_probs = sorted_logprobs.exp() ++ cumsum_probs = torch.cumsum(sorted_probs, dim=-1) ++ ++ top_logprobs_val = [] ++ top_logprobs_idx = [] ++ ++ if stage == LogprobStage.DECODE: ++ cumsum_cpu = cumsum_probs.cpu() ++ sorted_logprobs_cpu = sorted_logprobs.cpu() ++ sorted_indices_cpu = sorted_indices.cpu() ++ ++ for i, p in enumerate(top_logprobs_ps): ++ if p <= 0.0: ++ top_logprobs_val.append([]) ++ top_logprobs_idx.append([]) ++ continue ++ mask = cumsum_cpu[i] >= p ++ if mask.any(): ++ cutoff = mask.nonzero(as_tuple=True)[0][0].item() + 1 ++ else: ++ cutoff = sorted_logprobs_cpu.shape[1] ++ cutoff = max(cutoff, 1) ++ top_logprobs_val.append(sorted_logprobs_cpu[i, :cutoff].tolist()) ++ top_logprobs_idx.append(sorted_indices_cpu[i, :cutoff].tolist()) ++ else: ++ cumsum_cpu = cumsum_probs.cpu() ++ sorted_logprobs_cpu = sorted_logprobs.cpu() ++ sorted_indices_cpu = sorted_indices.cpu() ++ ++ pt = 0 ++ for p, pruned_len in zip(top_logprobs_ps, extend_logprob_pruned_lens_cpu): ++ if pruned_len <= 0: ++ top_logprobs_val.append([]) ++ top_logprobs_idx.append([]) ++ continue ++ ++ pos_vals = [] ++ pos_idxs = [] ++ for j in range(pruned_len): ++ row = pt + j ++ if p <= 0.0: ++ pos_vals.append([]) ++ pos_idxs.append([]) ++ continue ++ mask = cumsum_cpu[row] >= p ++ if mask.any(): ++ cutoff = mask.nonzero(as_tuple=True)[0][0].item() + 1 ++ else: ++ cutoff = sorted_logprobs_cpu.shape[1] ++ cutoff = max(cutoff, 1) ++ pos_vals.append(sorted_logprobs_cpu[row, :cutoff].tolist()) ++ pos_idxs.append(sorted_indices_cpu[row, :cutoff].tolist()) ++ top_logprobs_val.append(pos_vals) ++ top_logprobs_idx.append(pos_idxs) ++ pt += pruned_len ++ ++ return top_logprobs_val, top_logprobs_idx ++ ++ + def get_top_logprobs_prefill( + all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata + ): +@@ -114,6 +185,31 @@ def get_top_logprobs( + return get_top_logprobs_raw(logprobs, top_logprobs_nums, stage=LogprobStage.DECODE) + + ++def get_top_p_logprobs_prefill( ++ all_logprobs: torch.Tensor, logits_metadata: "LogitsMetadata" ++): ++ return get_top_p_logprobs_raw( ++ all_logprobs, ++ logits_metadata.top_logprobs_ps, ++ stage=LogprobStage.PREFILL, ++ extend_logprob_pruned_lens_cpu=logits_metadata.extend_logprob_pruned_lens_cpu, ++ ) ++ ++ ++def get_top_p_logprobs( ++ logprobs: torch.Tensor, ++ top_logprobs_ps: List[float], ++ no_copy_to_cpu: bool = False, ++): ++ result = get_top_p_logprobs_raw( ++ logprobs, ++ top_logprobs_ps, ++ stage=LogprobStage.DECODE, ++ no_copy_to_cpu=no_copy_to_cpu, ++ ) ++ return result ++ ++ + def get_token_ids_logprobs_raw( + logprobs: 
torch.Tensor, + token_ids_logprobs: List[Optional[List[int]]], diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py -index 652227860..7d3a5d0c4 100644 +index 6522278..0db2cbc 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py -@@ -405,6 +405,17 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): +@@ -387,6 +387,10 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, ++ input_top_p_logprobs_val=recv_obj.input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=recv_obj.input_top_p_logprobs_idx, ++ output_top_p_logprobs_val=recv_obj.output_top_p_logprobs_val, ++ output_top_p_logprobs_idx=recv_obj.output_top_p_logprobs_idx, + input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx, + output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val, +@@ -405,6 +409,17 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): prefill_launch_delay=recv_obj.prefill_launch_delay, prefill_launch_latency=recv_obj.prefill_launch_latency, prefill_finished_ts=recv_obj.prefill_finished_ts, @@ -1930,7 +2094,7 @@ index 652227860..7d3a5d0c4 100644 def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index ff1774567..f947e71d7 100644 +index ff17745..df07928 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -101,6 +101,42 @@ class RequestTimingMetricsMixin: @@ -1976,7 +2140,80 @@ index ff1774567..f947e71d7 100644 @dataclass class SpeculativeDecodingMetricsMixin: -@@ -1403,6 +1439,20 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): +@@ -198,8 +234,12 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + top_logprobs_num: Optional[Union[List[int], int]] = None + # If return logprobs, the token ids to return logprob for. + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None ++ # If return logprobs, the top-p threshold for returning variable-length top logprobs per position. ++ top_logprobs_p: Optional[Union[List[float], float]] = None + # Whether to detokenize tokens in text in the returned logprobs. + return_text_in_logprobs: bool = False ++ # Whether to return logprobs encoded in base64 format. ++ return_logprobs_in_base64: bool = False + # Whether to stream output. + stream: bool = False + # Whether to log metrics for this request (e.g. 
health_generate calls do not log metrics) +@@ -389,6 +429,8 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + self.top_logprobs_num = 0 + if not self.token_ids_logprob: # covers both None and [] + self.token_ids_logprob = None ++ if self.top_logprobs_p is None: ++ self.top_logprobs_p = 0.0 + + def _normalize_batch_inputs(self): + """Normalize inputs for a batch of examples, including parallel sampling expansion.""" +@@ -552,6 +594,9 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + self.top_logprobs_num = normalize_param( + self.top_logprobs_num, 0, "top_logprobs_num" + ) ++ self.top_logprobs_p = normalize_param( ++ self.top_logprobs_p, 0.0, "top_logprobs_p" ++ ) + + # Handle token_ids_logprob specially due to its nested structure + if not self.token_ids_logprob: # covers both None and [] +@@ -636,7 +681,9 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + logprob_start_len=self.logprob_start_len[i], + top_logprobs_num=self.top_logprobs_num[i], + token_ids_logprob=self.token_ids_logprob[i], ++ top_logprobs_p=self.top_logprobs_p[i], + return_text_in_logprobs=self.return_text_in_logprobs, ++ return_logprobs_in_base64=self.return_logprobs_in_base64, + stream=self.stream, + log_metrics=self.log_metrics, + return_hidden_states=( +@@ -709,6 +756,8 @@ class TokenizedGenerateReqInput(BaseReq): + top_logprobs_num: int + # If return logprobs, the token id to return logprob for + token_ids_logprob: List[int] ++ # If return logprobs, the top-p threshold for variable-length top logprobs ++ top_logprobs_p: float + # Whether to stream output + stream: bool + +@@ -1008,6 +1057,10 @@ class BatchTokenIDOutput( + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] ++ input_top_p_logprobs_val: List[List] ++ input_top_p_logprobs_idx: List[List] ++ output_top_p_logprobs_val: List[List] ++ output_top_p_logprobs_idx: List[List] + input_token_ids_logprobs_val: List[List] + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] +@@ -1098,6 +1151,10 @@ class BatchStrOutput( + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] ++ input_top_p_logprobs_val: List[List] ++ input_top_p_logprobs_idx: List[List] ++ output_top_p_logprobs_val: List[List] ++ output_top_p_logprobs_idx: List[List] + input_token_ids_logprobs_val: List[List] + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] +@@ -1403,6 +1460,20 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): message: str @@ -1997,7 +2234,7 @@ index ff1774567..f947e71d7 100644 @dataclass class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): success: bool -@@ -1802,6 +1852,10 @@ class GetLoadReqOutput(BaseReq): +@@ -1802,6 +1873,10 @@ class GetLoadReqOutput(BaseReq): num_waiting_reqs: int num_tokens: int ts_tic: float @@ -2009,7 +2246,7 @@ index ff1774567..f947e71d7 100644 @dataclass diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py -index e1236aa0f..daa598a1f 100644 +index e1236aa..daa598a 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -142,6 +142,39 @@ def _handle_output_by_index(output, i): @@ -2205,10 +2442,89 @@ index e1236aa0f..daa598a1f 100644 class SenderWrapper: def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket): diff --git a/python/sglang/srt/managers/schedule_batch.py 
b/python/sglang/srt/managers/schedule_batch.py -index c07995798..dd8ca7167 100644 +index c079957..71912f3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py -@@ -1869,7 +1869,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -520,6 +520,7 @@ class Req(ReqDllmMixin): + sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, ++ top_logprobs_p: float = 0.0, + dllm_config: Optional[DllmConfig] = None, + token_ids_logprob: List[int] = None, + stream: bool = False, +@@ -691,6 +692,7 @@ class Req(ReqDllmMixin): + # Start index to compute logprob from. + self.logprob_start_len = 0 + self.top_logprobs_num = top_logprobs_num ++ self.top_logprobs_p = top_logprobs_p + self.token_ids_logprob = token_ids_logprob + self.temp_scaled_logprobs = False + self.top_p_normalized_logprobs = False +@@ -702,12 +704,16 @@ class Req(ReqDllmMixin): + self.input_token_logprobs_idx: Optional[List[int]] = None + self.input_top_logprobs_val: Optional[List[float]] = None + self.input_top_logprobs_idx: Optional[List[int]] = None ++ self.input_top_p_logprobs_val: Optional[List] = None ++ self.input_top_p_logprobs_idx: Optional[List] = None + self.input_token_ids_logprobs_val: Optional[List[float]] = None + self.input_token_ids_logprobs_idx: Optional[List[int]] = None + # Temporary holder to store input_token_logprobs. + self.input_token_logprobs: Optional[List[Tuple[int]]] = None + self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None + self.temp_input_top_logprobs_idx: Optional[List[int]] = None ++ self.temp_input_top_p_logprobs_val: Optional[List] = None ++ self.temp_input_top_p_logprobs_idx: Optional[List] = None + self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None + self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None + +@@ -718,6 +724,9 @@ class Req(ReqDllmMixin): + # shape: (bs, k) + self.output_top_logprobs_val = [] + self.output_top_logprobs_idx = [] ++ # shape: (bs, variable) for top-p logprobs ++ self.output_top_p_logprobs_val = [] ++ self.output_top_p_logprobs_idx = [] + # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring) + self.output_token_ids_logprobs_val: List[ + Union[List[float], torch.Tensor] +@@ -726,7 +735,9 @@ class Req(ReqDllmMixin): + else: + self.output_token_logprobs_val = self.output_token_logprobs_idx = ( + self.output_top_logprobs_val +- ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = ( ++ ) = self.output_top_logprobs_idx = self.output_top_p_logprobs_val = ( ++ self.output_top_p_logprobs_idx ++ ) = self.output_token_ids_logprobs_val = ( + self.output_token_ids_logprobs_idx + ) = None + self.hidden_states: List[List[float]] = [] +@@ -1116,6 +1127,8 @@ class Req(ReqDllmMixin): + self.input_token_logprobs = None + self.temp_input_top_logprobs_val = None + self.temp_input_top_logprobs_idx = None ++ self.temp_input_top_p_logprobs_val = None ++ self.temp_input_top_p_logprobs_idx = None + self.extend_logprob_start_len = 0 + self.is_chunked = 0 + self.mamba_pool_idx = None +@@ -1260,6 +1273,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + # For processing logprobs + return_logprob: bool = False + top_logprobs_nums: Optional[List[int]] = None ++ top_logprobs_ps: Optional[List[float]] = None + token_ids_logprobs: Optional[List[List[int]]] = None + + # For logits and logprob post processing +@@ -1651,6 +1665,7 @@ class 
ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + if self.return_logprob: + self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] ++ self.top_logprobs_ps = [r.top_logprobs_p for r in reqs] + self.token_ids_logprobs = [r.token_ids_logprob for r in reqs] + + self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] +@@ -1869,7 +1884,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): while first_iter or ( not self.check_decode_mem(selected_indices=sorted_indices) ): @@ -2220,8 +2536,52 @@ index c07995798..dd8ca7167 100644 # Always keep at least one request break +@@ -2094,9 +2112,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.return_logprob = any(req.return_logprob for req in self.reqs) + if self.return_logprob: + self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices] ++ self.top_logprobs_ps = [self.top_logprobs_ps[i] for i in keep_indices] + self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices] + else: + self.top_logprobs_nums = None ++ self.top_logprobs_ps = None + self.token_ids_logprobs = None + + self.has_stream = any(req.stream for req in self.reqs) +@@ -2143,12 +2163,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.mamba_track_seqlens = None + if self.return_logprob and other.return_logprob: + self.top_logprobs_nums.extend(other.top_logprobs_nums) ++ self.top_logprobs_ps.extend(other.top_logprobs_ps) + self.token_ids_logprobs.extend(other.token_ids_logprobs) + elif self.return_logprob: + self.top_logprobs_nums.extend([0] * len(other.reqs)) ++ self.top_logprobs_ps.extend([0.0] * len(other.reqs)) + self.token_ids_logprobs.extend([None] * len(other.reqs)) + elif other.return_logprob: + self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums ++ self.top_logprobs_ps = [0.0] * len(self.reqs) + other.top_logprobs_ps + self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs + self.reqs.extend(other.reqs) + if self.multimodal_inputs is not None: +@@ -2193,6 +2216,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + seq_lens_sum=self.seq_lens_sum, + return_logprob=self.return_logprob, + top_logprobs_nums=self.top_logprobs_nums, ++ top_logprobs_ps=self.top_logprobs_ps, + token_ids_logprobs=self.token_ids_logprobs, + global_num_tokens=self.global_num_tokens, + global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, +@@ -2352,6 +2376,7 @@ class ModelWorkerBatch: + # For logprob + return_logprob: bool + top_logprobs_nums: Optional[List[int]] ++ top_logprobs_ps: Optional[List[float]] + token_ids_logprobs: Optional[List[List[int]]] + + # For DP attention diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index a9ff0ac94..c124f43bc 100644 +index a9ff0ac..264b177 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -114,6 +114,7 @@ from sglang.srt.managers.io_struct import ( @@ -2261,8 +2621,16 @@ index a9ff0ac94..c124f43bc 100644 (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), +@@ -1505,6 +1512,7 @@ class Scheduler( + recv_req.sampling_params, + return_logprob=recv_req.return_logprob, + top_logprobs_num=recv_req.top_logprobs_num, ++ top_logprobs_p=recv_req.top_logprobs_p, + token_ids_logprob=recv_req.token_ids_logprob, + stream=recv_req.stream, + lora_id=recv_req.lora_id, 
diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py -index 30b2732b9..68090b161 100644 +index 30b2732..68090b1 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -609,12 +609,54 @@ class SchedulerMetricsMixin: @@ -2321,10 +2689,109 @@ index 30b2732b9..68090b161 100644 def get_loads(self: Scheduler, req: GetLoadsReqInput = None) -> GetLoadsReqOutput: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index 482bc6ca6..fbc486417 100644 +index 482bc6c..9b8b5cb 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -@@ -922,6 +922,18 @@ class SchedulerOutputProcessorMixin: +@@ -494,6 +494,13 @@ class SchedulerOutputProcessorMixin: + req.output_top_logprobs_idx.append( + logits_output.next_token_top_logprobs_idx[i] + ) ++ if req.top_logprobs_p > 0.0 and logits_output.next_token_top_p_logprobs_val is not None: ++ req.output_top_p_logprobs_val.append( ++ logits_output.next_token_top_p_logprobs_val[i] ++ ) ++ req.output_top_p_logprobs_idx.append( ++ logits_output.next_token_top_p_logprobs_idx[i] ++ ) + if req.token_ids_logprob is not None: + req.output_token_ids_logprobs_val.append( + logits_output.next_token_token_ids_logprobs_val[i] +@@ -623,6 +630,32 @@ class SchedulerOutputProcessorMixin: + # Clean up temp storage + req.temp_input_top_logprobs_idx = None + req.temp_input_top_logprobs_val = None ++ def _process_input_top_p_logprobs(self: Scheduler, req: Req) -> None: ++ """Process input top-p logprobs.""" ++ if req.top_logprobs_p <= 0.0: ++ return ++ ++ is_multi_item_scoring = self._is_multi_item_scoring(req) ++ ++ req.input_top_p_logprobs_val = [] if is_multi_item_scoring else [None] ++ req.input_top_p_logprobs_idx = [] if is_multi_item_scoring else [None] ++ ++ for val, idx in zip( ++ req.temp_input_top_p_logprobs_val, ++ req.temp_input_top_p_logprobs_idx, ++ strict=True, ++ ): ++ req.input_top_p_logprobs_val.extend(val) ++ req.input_top_p_logprobs_idx.extend(idx) ++ ++ if not is_multi_item_scoring: ++ req.input_top_p_logprobs_val.pop() ++ req.input_top_p_logprobs_idx.pop() ++ ++ req.temp_input_top_p_logprobs_idx = None ++ req.temp_input_top_p_logprobs_val = None ++ ++ + + def _process_input_token_ids_logprobs(self, req: Req) -> None: + """Process input token IDs logprobs.""" +@@ -737,6 +770,10 @@ class SchedulerOutputProcessorMixin: + req.temp_input_top_logprobs_val = [] + if req.temp_input_top_logprobs_idx is None: + req.temp_input_top_logprobs_idx = [] ++ if req.temp_input_top_p_logprobs_val is None: ++ req.temp_input_top_p_logprobs_val = [] ++ if req.temp_input_top_p_logprobs_idx is None: ++ req.temp_input_top_p_logprobs_idx = [] + if req.temp_input_token_ids_logprobs_val is None: + req.temp_input_token_ids_logprobs_val = [] + if req.temp_input_token_ids_logprobs_idx is None: +@@ -761,6 +798,10 @@ class SchedulerOutputProcessorMixin: + req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i]) + req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i]) + ++ if req.top_logprobs_p > 0.0 and output.input_top_p_logprobs_val is not None: ++ req.temp_input_top_p_logprobs_val.append(output.input_top_p_logprobs_val[i]) ++ req.temp_input_top_p_logprobs_idx.append(output.input_top_p_logprobs_idx[i]) ++ + if req.token_ids_logprob is not None: + 
req.temp_input_token_ids_logprobs_val.append( + output.input_token_ids_logprobs_val[i] +@@ -780,6 +821,7 @@ class SchedulerOutputProcessorMixin: + # Process all input logprob types using helper functions + self._process_input_token_logprobs(req, input_token_logprobs) + self._process_input_top_logprobs(req) ++ self._process_input_top_p_logprobs(req) + + self._process_input_token_ids_logprobs(req) + +@@ -822,6 +864,10 @@ class SchedulerOutputProcessorMixin: + req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) + ++ if req.top_logprobs_p > 0.0 and output.next_token_top_p_logprobs_val is not None: ++ req.output_top_p_logprobs_val.append(output.next_token_top_p_logprobs_val[i]) ++ req.output_top_p_logprobs_idx.append(output.next_token_top_p_logprobs_idx[i]) ++ + if ( + req.token_ids_logprob is not None + and output.next_token_token_ids_logprobs_val is not None +@@ -852,6 +898,10 @@ class SchedulerOutputProcessorMixin: + req.input_top_logprobs_val = [] + if req.input_top_logprobs_idx is None: + req.input_top_logprobs_idx = [] ++ if req.input_top_p_logprobs_val is None: ++ req.input_top_p_logprobs_val = [] ++ if req.input_top_p_logprobs_idx is None: ++ req.input_top_p_logprobs_idx = [] + if req.input_token_ids_logprobs_val is None: + req.input_token_ids_logprobs_val = [] + if req.input_token_ids_logprobs_idx is None: +@@ -922,6 +972,18 @@ class SchedulerOutputProcessorMixin: prefill_launch_delays = [] prefill_launch_latencies = [] prefill_finished_timestamps = [] @@ -2343,7 +2810,33 @@ index 482bc6ca6..fbc486417 100644 if return_logprob: input_token_logprobs_val = [] -@@ -1037,6 +1049,40 @@ class SchedulerOutputProcessorMixin: +@@ -932,6 +994,10 @@ class SchedulerOutputProcessorMixin: + input_top_logprobs_idx = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] ++ input_top_p_logprobs_val = [] ++ input_top_p_logprobs_idx = [] ++ output_top_p_logprobs_val = [] ++ output_top_p_logprobs_idx = [] + input_token_ids_logprobs_val = [] + input_token_ids_logprobs_idx = [] + output_token_ids_logprobs_val = [] +@@ -942,8 +1008,12 @@ class SchedulerOutputProcessorMixin: + ) = output_token_logprobs_idx = input_top_logprobs_val = ( + input_top_logprobs_idx + ) = output_top_logprobs_val = output_top_logprobs_idx = ( +- input_token_ids_logprobs_val +- ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = ( ++ input_top_p_logprobs_val ++ ) = input_top_p_logprobs_idx = output_top_p_logprobs_val = ( ++ output_top_p_logprobs_idx ++ ) = input_token_ids_logprobs_val = ( ++ input_token_ids_logprobs_idx ++ ) = output_token_ids_logprobs_val = ( + output_token_ids_logprobs_idx + ) = None + +@@ -1037,6 +1107,40 @@ class SchedulerOutputProcessorMixin: prefill_finished_timestamps.append( req.time_stats.get_prefill_finished_ts() ) @@ -2384,7 +2877,51 @@ index 482bc6ca6..fbc486417 100644 if not self.spec_algorithm.is_none(): spec_verify_ct.append(req.spec_verify_ct) -@@ -1134,7 +1180,7 @@ class SchedulerOutputProcessorMixin: +@@ -1054,6 +1158,8 @@ class SchedulerOutputProcessorMixin: + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) ++ input_top_p_logprobs_val.append(req.input_top_p_logprobs_val) ++ input_top_p_logprobs_idx.append(req.input_top_p_logprobs_idx) + input_token_ids_logprobs_val.append( + req.input_token_ids_logprobs_val + ) +@@ -1066,6 +1172,8 @@ class 
SchedulerOutputProcessorMixin: + input_token_logprobs_idx.append([]) + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) ++ input_top_p_logprobs_val.append([]) ++ input_top_p_logprobs_idx.append([]) + input_token_ids_logprobs_val.append([]) + input_token_ids_logprobs_idx.append([]) + +@@ -1090,6 +1198,16 @@ class SchedulerOutputProcessorMixin: + send_output_token_logprobs_offset: + ] + ) ++ output_top_p_logprobs_val.append( ++ req.output_top_p_logprobs_val[ ++ send_output_token_logprobs_offset: ++ ] ++ ) ++ output_top_p_logprobs_idx.append( ++ req.output_top_p_logprobs_idx[ ++ send_output_token_logprobs_offset: ++ ] ++ ) + output_token_ids_logprobs_val.append( + req.output_token_ids_logprobs_val[ + send_output_token_logprobs_offset: +@@ -1108,6 +1226,8 @@ class SchedulerOutputProcessorMixin: + output_token_logprobs_idx.append([]) + output_top_logprobs_val.append([]) + output_top_logprobs_idx.append([]) ++ output_top_p_logprobs_val.append([]) ++ output_top_p_logprobs_idx.append([]) + output_token_ids_logprobs_val.append([]) + output_token_ids_logprobs_idx.append([]) + +@@ -1134,7 +1254,7 @@ class SchedulerOutputProcessorMixin: req.log_time_stats() # Send to detokenizer @@ -2393,7 +2930,7 @@ index 482bc6ca6..fbc486417 100644 if self.model_config.is_multimodal_gen: return self.send_to_detokenizer.send_output( -@@ -1149,6 +1195,17 @@ class SchedulerOutputProcessorMixin: +@@ -1149,6 +1269,17 @@ class SchedulerOutputProcessorMixin: prefill_launch_delay=prefill_launch_delays, prefill_launch_latency=prefill_launch_latencies, prefill_finished_ts=prefill_finished_timestamps, @@ -2411,7 +2948,18 @@ index 482bc6ca6..fbc486417 100644 finished_reasons=finished_reasons, decoded_texts=decoded_texts, decode_ids=decode_ids_list, -@@ -1198,6 +1255,18 @@ class SchedulerOutputProcessorMixin: +@@ -1169,6 +1300,10 @@ class SchedulerOutputProcessorMixin: + input_top_logprobs_idx=input_top_logprobs_idx, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ++ input_top_p_logprobs_val=input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=input_top_p_logprobs_idx, ++ output_top_p_logprobs_val=output_top_p_logprobs_val, ++ output_top_p_logprobs_idx=output_top_p_logprobs_idx, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, + output_token_ids_logprobs_val=output_token_ids_logprobs_val, +@@ -1198,6 +1333,18 @@ class SchedulerOutputProcessorMixin: prefill_launch_delays = [] prefill_launch_latencies = [] prefill_finished_timestamps = [] @@ -2430,7 +2978,7 @@ index 482bc6ca6..fbc486417 100644 retraction_counts = [] for req in reqs: if req.finished(): -@@ -1221,6 +1290,40 @@ class SchedulerOutputProcessorMixin: +@@ -1221,6 +1368,40 @@ class SchedulerOutputProcessorMixin: prefill_finished_timestamps.append( req.time_stats.get_prefill_finished_ts() ) @@ -2471,7 +3019,7 @@ index 482bc6ca6..fbc486417 100644 retraction_counts.append(req.retraction_count) self.send_to_detokenizer.send_output( BatchEmbeddingOutput( -@@ -1231,6 +1334,17 @@ class SchedulerOutputProcessorMixin: +@@ -1231,6 +1412,17 @@ class SchedulerOutputProcessorMixin: prefill_launch_delay=prefill_launch_delays, prefill_launch_latency=prefill_launch_latencies, prefill_finished_ts=prefill_finished_timestamps, @@ -2490,7 +3038,7 @@ index 482bc6ca6..fbc486417 100644 embeddings=embeddings, prompt_tokens=prompt_tokens, diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py 
-index 1a65a3c3d..f76606469 100644 +index 1a65a3c..f766064 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -20,6 +20,7 @@ from sglang.srt.layers.dp_attention import ( @@ -2687,7 +3235,7 @@ index 1a65a3c3d..f76606469 100644 rids: List = [] for poll_statuses in poll_statuses_group: diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py -index 7d08f12b3..afc045da2 100644 +index 7d08f12..afc045d 100644 --- a/python/sglang/srt/managers/scheduler_profiler_mixin.py +++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py @@ -347,7 +347,7 @@ class SchedulerProfilerMixin: @@ -2700,7 +3248,7 @@ index 7d08f12b3..afc045da2 100644 if self.profile_in_progress: # force trace flush diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -index 293a84350..244ea4eb1 100644 +index 293a843..244ea4e 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -12,6 +12,7 @@ from sglang.srt.constants import ( @@ -2765,7 +3313,7 @@ index 293a84350..244ea4eb1 100644 def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py -index f2ffa9909..6e4d1d460 100644 +index f2ffa99..6e4d1d4 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import ( @@ -2817,10 +3365,30 @@ index f2ffa9909..6e4d1d460 100644 self, obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 0914a5230..33bb3844a 100644 +index 0914a52..bae5652 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -324,8 +324,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -163,6 +163,10 @@ class ReqState: + input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) ++ input_top_p_logprobs_val: List = dataclasses.field(default_factory=list) ++ input_top_p_logprobs_idx: List = dataclasses.field(default_factory=list) ++ output_top_p_logprobs_val: List = dataclasses.field(default_factory=list) ++ output_top_p_logprobs_idx: List = dataclasses.field(default_factory=list) + input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) + input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) + output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) +@@ -173,6 +177,8 @@ class ReqState: + output_token_logprobs: List[Any] = dataclasses.field(default_factory=list) + input_top_logprobs: List[Any] = dataclasses.field(default_factory=list) + output_top_logprobs: List[Any] = dataclasses.field(default_factory=list) ++ input_top_p_logprobs: List[Any] = dataclasses.field(default_factory=list) ++ output_top_p_logprobs: List[Any] = dataclasses.field(default_factory=list) + input_token_ids_logprobs: List[Any] = 
dataclasses.field(default_factory=list) + output_token_ids_logprobs: List[Any] = dataclasses.field(default_factory=list) + +@@ -324,8 +330,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi context, zmq.PULL, port_args.tokenizer_ipc_name, True ) if self.server_args.tokenizer_worker_num == 1: @@ -2834,7 +3402,15 @@ index 0914a5230..33bb3844a 100644 ) else: from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper -@@ -1327,7 +1331,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -927,6 +937,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + obj.logprob_start_len, + obj.top_logprobs_num, + obj.token_ids_logprob, ++ obj.top_logprobs_p, + obj.stream, + rid=obj.rid, + http_worker_ipc=obj.http_worker_ipc, +@@ -1327,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi async with self.is_pause_cond: self.is_pause = True if obj.mode != "abort": @@ -2843,7 +3419,7 @@ index 0914a5230..33bb3844a 100644 else: # we are using the model_update_lock to check if there is still on-going requests. while True: -@@ -1341,7 +1345,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1341,7 +1352,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi async def continue_generation(self, obj: ContinueGenerationReqInput): async with self.is_pause_cond: self.is_pause = False @@ -2852,7 +3428,7 @@ index 0914a5230..33bb3844a 100644 self.is_pause_cond.notify_all() async def update_weights_from_disk( -@@ -1510,6 +1514,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1510,6 +1521,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi self._add_metric_if_present( recv_obj, "prefill_finished_ts", meta_info, i ) @@ -2893,7 +3469,186 @@ index 0914a5230..33bb3844a 100644 if getattr(state.obj, "return_logprob", False): self.convert_logprob_style( -@@ -1955,19 +1993,17 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1684,7 +1729,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + meta_info["input_top_logprobs"] = state.input_top_logprobs + meta_info["output_top_logprobs"] = state.output_top_logprobs + +- # 3. Handle token_ids_logprob ++ # 3. Handle top-p logprobs ++ top_logprobs_p = getattr(state.obj, "top_logprobs_p", 0.0) or 0.0 ++ if top_logprobs_p > 0.0: ++ if len(state.input_top_p_logprobs_val) > len(state.input_top_p_logprobs): ++ state.input_top_p_logprobs.extend( ++ self.detokenize_top_logprobs_tokens( ++ state.input_top_p_logprobs_val[ ++ len(state.input_top_p_logprobs) : ++ ], ++ state.input_top_p_logprobs_idx[ ++ len(state.input_top_p_logprobs) : ++ ], ++ return_text_in_logprobs, ++ ) ++ ) ++ if len(state.output_top_p_logprobs_val) > len( ++ state.output_top_p_logprobs ++ ): ++ state.output_top_p_logprobs.extend( ++ self.detokenize_top_logprobs_tokens( ++ state.output_top_p_logprobs_val[ ++ len(state.output_top_p_logprobs) : ++ ], ++ state.output_top_p_logprobs_idx[ ++ len(state.output_top_p_logprobs) : ++ ], ++ return_text_in_logprobs, ++ ) ++ ) ++ ++ meta_info["input_top_p_logprobs"] = state.input_top_p_logprobs ++ meta_info["output_top_p_logprobs"] = state.output_top_p_logprobs ++ ++ # 4. 
Handle token_ids_logprob + if token_ids_logprob is not None: + if len(state.input_token_ids_logprobs_val) > len( + state.input_token_ids_logprobs +@@ -1718,6 +1796,105 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + meta_info["input_token_ids_logprobs"] = state.input_token_ids_logprobs + meta_info["output_token_ids_logprobs"] = state.output_token_ids_logprobs + ++ # 5. Handle base64 encoding for top-k and top-p logprobs ++ return_logprobs_in_base64 = getattr( ++ state.obj, "return_logprobs_in_base64", False ++ ) ++ if return_logprobs_in_base64: ++ self._encode_logprobs_as_base64(meta_info, state, top_logprobs_num) ++ ++ ++ @staticmethod ++ def _encode_logprobs_as_base64( ++ meta_info: dict, state: "ReqState", top_logprobs_num: int ++ ): ++ """Encode top-k and top-p logprobs as base64-encoded numpy arrays.""" ++ import numpy as np ++ import pybase64 ++ ++ def encode_fixed_length(vals_raw, idxs_raw, key_prefix): ++ if not vals_raw: ++ return ++ filtered_vals = [] ++ filtered_idxs = [] ++ none_positions = [] ++ for pos_i, (v, x) in enumerate(zip(vals_raw, idxs_raw)): ++ if v is None: ++ none_positions.append(pos_i) ++ else: ++ filtered_vals.append(v) ++ filtered_idxs.append(x) ++ if not filtered_vals: ++ return ++ val_arr = np.array(filtered_vals, dtype=np.float32) ++ idx_arr = np.array(filtered_idxs, dtype=np.int32) ++ meta_info[f"{key_prefix}_val_base64"] = pybase64.b64encode( ++ val_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_idx_base64"] = pybase64.b64encode( ++ idx_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_shape"] = list(val_arr.shape) ++ if none_positions: ++ meta_info[f"{key_prefix}_none_positions"] = none_positions ++ if key_prefix in meta_info: ++ del meta_info[key_prefix] ++ ++ def encode_variable_length(vals_raw, idxs_raw, key_prefix): ++ if not vals_raw: ++ return ++ flat_vals = [] ++ flat_idxs = [] ++ lengths = [] ++ for v, x in zip(vals_raw, idxs_raw): ++ if v is None: ++ lengths.append(-1) ++ elif isinstance(v, list): ++ lengths.append(len(v)) ++ flat_vals.extend(v) ++ flat_idxs.extend(x) ++ else: ++ lengths.append(0) ++ if not flat_vals: ++ meta_info[f"{key_prefix}_lengths"] = lengths ++ return ++ val_arr = np.array(flat_vals, dtype=np.float32) ++ idx_arr = np.array(flat_idxs, dtype=np.int32) ++ meta_info[f"{key_prefix}_val_base64"] = pybase64.b64encode( ++ val_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_idx_base64"] = pybase64.b64encode( ++ idx_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_lengths"] = lengths ++ if key_prefix in meta_info: ++ del meta_info[key_prefix] ++ ++ if top_logprobs_num > 0: ++ encode_fixed_length( ++ state.input_top_logprobs_val, ++ state.input_top_logprobs_idx, ++ "input_top_logprobs", ++ ) ++ encode_fixed_length( ++ state.output_top_logprobs_val, ++ state.output_top_logprobs_idx, ++ "output_top_logprobs", ++ ) ++ ++ top_logprobs_p = getattr(state.obj, "top_logprobs_p", 0.0) or 0.0 ++ if top_logprobs_p > 0.0: ++ encode_variable_length( ++ state.input_top_p_logprobs_val, ++ state.input_top_p_logprobs_idx, ++ "input_top_p_logprobs", ++ ) ++ encode_variable_length( ++ state.output_top_p_logprobs_val, ++ state.output_top_p_logprobs_idx, ++ "output_top_p_logprobs", ++ ) ++ + def convert_logprob_style( + self, + meta_info: dict, +@@ -1763,6 +1940,30 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + recv_obj.output_top_logprobs_idx[recv_obj_index] + ) + ++ top_logprobs_p = getattr(state.obj, "top_logprobs_p", 
0.0) or 0.0 ++ if top_logprobs_p > 0.0: ++ if ( ++ recv_obj.input_top_p_logprobs_val is not None ++ and len(recv_obj.input_top_p_logprobs_val) > 0 ++ and recv_obj.input_top_p_logprobs_val[recv_obj_index] ++ ): ++ state.input_top_p_logprobs_val.extend( ++ recv_obj.input_top_p_logprobs_val[recv_obj_index] ++ ) ++ state.input_top_p_logprobs_idx.extend( ++ recv_obj.input_top_p_logprobs_idx[recv_obj_index] ++ ) ++ if ( ++ recv_obj.output_top_p_logprobs_val is not None ++ and len(recv_obj.output_top_p_logprobs_val) > 0 ++ ): ++ state.output_top_p_logprobs_val.extend( ++ recv_obj.output_top_p_logprobs_val[recv_obj_index] ++ ) ++ state.output_top_p_logprobs_idx.extend( ++ recv_obj.output_top_p_logprobs_idx[recv_obj_index] ++ ) ++ + if token_ids_logprob is not None: + if len(recv_obj.input_token_ids_logprobs_val) > 0: + state.input_token_ids_logprobs_val.extend( +@@ -1955,19 +2156,17 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi if custom_labels else self.metrics_collector.labels ) @@ -2919,7 +3674,7 @@ index 0914a5230..33bb3844a 100644 new_time = time.time() interval = new_time - state.last_time self.metrics_collector.observe_inter_token_latency( -@@ -1976,7 +2012,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1976,7 +2175,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi num_new_tokens, ) state.last_time = new_time @@ -2929,7 +3684,7 @@ index 0914a5230..33bb3844a 100644 if state.finished: retraction_count = ( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py -index 86b009df4..16ebd52ae 100644 +index 86b009d..16ebd52 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ from sglang.srt.managers.io_struct import ( @@ -2953,7 +3708,7 @@ index 86b009df4..16ebd52ae 100644 parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py -index fa08bb66a..fa539315c 100644 +index fa08bb6..fa53931 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -347,6 +347,84 @@ def alloc_decode_kernel( @@ -3060,7 +3815,7 @@ index fa08bb66a..fa539315c 100644 prefix_lens, seq_lens, diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py -index d7cd472a9..81fae740f 100644 +index d7cd472..81fae74 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -76,6 +76,7 @@ class HiRadixCache(RadixCache): @@ -3117,7 +3872,7 @@ index d7cd472a9..81fae740f 100644 self._inc_hit_count(new_node, chunked) total_prefix_length += prefix_len diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py -index 1d917137c..669e5c518 100644 +index 1d91713..669e5c5 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1777,9 +1777,12 @@ class NSATokenToKVPool(MLATokenToKVPool): @@ -3200,7 +3955,7 @@ index 1d917137c..669e5c518 100644 kv_size_bytes = super().get_kv_size_bytes() for index_k_cache in self.index_k_with_scale_buffer: diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py -index 42b169728..8e799196a 100644 +index 42b1697..8e79919 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ 
b/python/sglang/srt/mem_cache/radix_cache.py @@ -495,7 +495,17 @@ class RadixCache(BasePrefixCache): @@ -3235,7 +3990,7 @@ index 42b169728..8e799196a 100644 return delta diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py -index 255d41ccc..f93bedb4d 100644 +index 255d41c..f93bedb 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -20,7 +20,10 @@ import time @@ -3418,10 +4173,26 @@ index 255d41ccc..f93bedb4d 100644 if self.disagg_mode == DisaggregationMode.NULL: queue_duration = self.forward_entry_time - self.wait_queue_entry_time diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py -index 234523532..f5d479945 100644 +index 2345235..307b656 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py -@@ -909,6 +909,28 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): +@@ -266,6 +266,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): + # For logprob + return_logprob: bool = False + top_logprobs_nums: Optional[List[int]] = None ++ top_logprobs_ps: Optional[List[float]] = None + token_ids_logprobs: Optional[List[List[int]]] = None + + # For logits and logprobs post processing +@@ -399,6 +400,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): + orig_seq_lens=batch.orig_seq_lens, + return_logprob=batch.return_logprob, + top_logprobs_nums=batch.top_logprobs_nums, ++ top_logprobs_ps=batch.top_logprobs_ps, + token_ids_logprobs=batch.token_ids_logprobs, + is_extend_in_batch=batch.is_extend_in_batch, + can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, +@@ -909,6 +911,28 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): tokens_padded = (tokens + rank_size - 1) // rank_size * rank_size self._pad_inputs_to_size(model_runner, tokens_padded, self.batch_size) @@ -3451,7 +4222,7 @@ index 234523532..f5d479945 100644 self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 275775a73..f0bd3ebf8 100644 +index 275775a..fdeb693 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -395,7 +395,12 @@ class ModelRunner(ModelRunnerKVCacheMixin): @@ -3513,7 +4284,23 @@ index 275775a73..f0bd3ebf8 100644 # Normalize num_token_non_padded to be local to this attention TP rank if needed. if ( -@@ -2664,6 +2681,42 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -2553,6 +2570,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): + forward_batch.sampling_info, + forward_batch.return_logprob, + forward_batch.top_logprobs_nums, ++ forward_batch.top_logprobs_ps, + forward_batch.token_ids_logprobs, + # For prefill, we only use the position of the last token. 
+ ( +@@ -2592,6 +2610,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): + forward_batch.sampling_info, + forward_batch.return_logprob, + forward_batch.top_logprobs_nums, ++ forward_batch.top_logprobs_ps, + forward_batch.token_ids_logprobs, + ) + +@@ -2664,6 +2683,42 @@ class ModelRunner(ModelRunnerKVCacheMixin): device=self.device, ) @@ -3557,7 +4344,7 @@ index 275775a73..f0bd3ebf8 100644 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): params_dict = dict(model.named_parameters()) diff --git a/python/sglang/srt/models/deepseek_common/attention_backend_handler.py b/python/sglang/srt/models/deepseek_common/attention_backend_handler.py -index cc673a9ca..06c430d2c 100644 +index cc673a9..06c430d 100644 --- a/python/sglang/srt/models/deepseek_common/attention_backend_handler.py +++ b/python/sglang/srt/models/deepseek_common/attention_backend_handler.py @@ -1,4 +1,5 @@ @@ -3576,7 +4363,7 @@ index cc673a9ca..06c430d2c 100644 return AttnForwardMethod.MHA_ONE_SHOT return AttnForwardMethod.MLA diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py -index cb13a7c67..d9669ce08 100644 +index cb13a7c..d9669ce 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -29,6 +29,7 @@ from sglang.srt.layers.attention.nsa.utils import ( @@ -3607,7 +4394,7 @@ index cb13a7c67..d9669ce08 100644 if not forward_batch.forward_mode.is_idle(): if residual is not None: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py -index 1583dd788..a35c00f96 100644 +index 1583dd7..a35c00f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1085,6 +1085,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin): @@ -3828,7 +4615,7 @@ index 1583dd788..a35c00f96 100644 if normal_end_layer != self.end_layer: hidden_states, residual = model_forward_maybe_tbo( diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py -index db8c1c7ce..53ffadf6d 100644 +index db8c1c7..53ffadf 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -678,8 +678,13 @@ class Glm4MoeDecoderLayer(nn.Module): @@ -3856,7 +4643,7 @@ index db8c1c7ce..53ffadf6d 100644 hidden_states, residual = self.layer_communicator.prepare_attn( diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py -index 1f6e75364..546cce4ab 100644 +index 1f6e753..546cce4 100644 --- a/python/sglang/srt/models/glm4_moe_nextn.py +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -103,7 +103,7 @@ class Glm4MoeModelNextN(nn.Module): @@ -3869,7 +4656,7 @@ index 1f6e75364..546cce4ab 100644 ) diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py -index 324de18b4..fc72faa03 100644 +index 324de18..fc72faa 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -52,11 +52,31 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): @@ -3974,7 +4761,7 @@ index 324de18b4..fc72faa03 100644 continue diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py -index 2cf813bce..1250c49e4 100644 +index 2cf813b..1250c49 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -17,6 +17,7 @@ @@ -4088,7 +4875,7 @@ index 2cf813bce..1250c49e4 100644 weights_out_dict = dict(weights_in) diff --git 
a/python/sglang/srt/models/kimi_k25.py b/python/sglang/srt/models/kimi_k25.py -index d8399a691..0277bc671 100644 +index d8399a6..0277bc6 100644 --- a/python/sglang/srt/models/kimi_k25.py +++ b/python/sglang/srt/models/kimi_k25.py @@ -666,25 +666,30 @@ class KimiK25ForConditionalGeneration(nn.Module): @@ -4208,7 +4995,7 @@ index d8399a691..0277bc671 100644 diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py -index 49f938a1c..8eea383bb 100644 +index 49f938a..8eea383 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -85,6 +85,11 @@ class LlamaDecoderLayer(LlamaDecoderLayer): @@ -4236,7 +5023,7 @@ index 49f938a1c..8eea383bb 100644 # idle batch diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py -index f01225487..1dad8bb8e 100644 +index f012254..1dad8bb 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -372,6 +372,7 @@ class Qwen3_5LinearDecoderLayer(nn.Module): @@ -4312,7 +5099,7 @@ index f01225487..1dad8bb8e 100644 return hidden_states, residual diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index d641826e3..3abc39ef3 100644 +index d641826..3abc39e 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -711,14 +711,19 @@ class Qwen3LLMModel(Qwen3Model): @@ -4340,7 +5127,7 @@ index d641826e3..3abc39ef3 100644 positions, hidden_states, diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py -index 33cce6fe2..0970c4550 100644 +index 33cce6f..0970c45 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py +++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -1,6 +1,9 @@ @@ -4400,7 +5187,7 @@ index 33cce6fe2..0970c4550 100644 self, image_data: List[Union[str, bytes]], diff --git a/python/sglang/srt/multimodal/processors/kimi_k25.py b/python/sglang/srt/multimodal/processors/kimi_k25.py -index d8bb9ceb3..9311a431b 100644 +index d8bb9ce..9311a43 100644 --- a/python/sglang/srt/multimodal/processors/kimi_k25.py +++ b/python/sglang/srt/multimodal/processors/kimi_k25.py @@ -25,6 +25,18 @@ class KimiK2_5VLImageProcessor(SGLangBaseProcessor): @@ -4423,7 +5210,7 @@ index d8bb9ceb3..9311a431b 100644 async def process_mm_data_async( self, diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py -index 4395654e4..f9b5ea4ab 100644 +index 4395654..f9b5ea4 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -317,7 +317,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor): @@ -4436,7 +5223,7 @@ index 4395654e4..f9b5ea4ab 100644 image_data=image_data, video_data=request_obj.video_data, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index b080aeb16..b0322fef4 100644 +index b080aeb..b0322fe 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -580,6 +580,7 @@ class ServerArgs: @@ -4554,7 +5341,7 @@ index b080aeb16..b0322fef4 100644 return PortArgs( tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -index 5fe45086c..b283d2e9b 100644 +index 5fe4508..b283d2e 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py 
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: @@ -4585,7 +5372,7 @@ index 5fe45086c..b283d2e9b 100644 self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py -index ac629c7ee..904f54b4a 100644 +index ac629c7..904f54b 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -337,7 +337,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): @@ -4637,7 +5424,7 @@ index ac629c7ee..904f54b4a 100644 @dataclass diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py -index 32b3a520a..d7f940147 100644 +index 32b3a52..d7f9401 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -234,7 +234,10 @@ class EAGLEWorker(TpModelWorker): @@ -4653,7 +5440,7 @@ index 32b3a520a..d7f940147 100644 Device2DraftCudaGraphRunner = { diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py -index 4636128fa..a9b61df39 100644 +index 4636128..a9b61df 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2359,6 +2359,8 @@ class SafeUnpickler(pickle.Unpickler): @@ -4666,7 +5453,7 @@ index 4636128fa..a9b61df39 100644 DENY_CLASSES = { diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py -index 3be16446e..1b2371c83 100644 +index 3be1644..1b2371c 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -69,6 +69,9 @@ def _check_tensors( diff --git a/docker/patch/v0.5.9/sglang.patch b/docker/patch/v0.5.9/sglang.patch index 8001ceb73d..f21b415e99 100644 --- a/docker/patch/v0.5.9/sglang.patch +++ b/docker/patch/v0.5.9/sglang.patch @@ -1,5 +1,5 @@ diff --git a/.codespellrc b/.codespellrc -index 808a344b4..a34624958 100644 +index 808a344..a346249 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ @@ -8,7 +8,7 @@ index 808a344b4..a34624958 100644 +ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, medias skip = *.json,*.jsonl,*.patch,*.txt diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index 6fbd1db82..4c681b58d 100644 +index 6fbd1db..4c681b5 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -274,6 +274,7 @@ class ModelConfig: @@ -34,7 +34,7 @@ index 6fbd1db82..4c681b58d 100644 elif not needs_tf_v5: logger.warning( diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py -index da4629e52..c03f98231 100644 +index da4629e..c03f982 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -17,6 +17,7 @@ class KVArgs: @@ -46,7 +46,7 @@ index da4629e52..c03f98231 100644 aux_data_lens: List[int] aux_item_lens: List[int] diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py -index 67fe82ad6..2ef25c49b 100644 +index 67fe82a..2ef25c4 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -24,6 +24,7 @@ from sglang.srt.disaggregation.base.conn import ( @@ -126,7 +126,7 @@ index 67fe82ad6..2ef25c49b 100644 "prefill_pp_size": 
self.pp_size, "prefill_page_size": self.page_size, diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py -index 1d8baf002..1ebb95929 100644 +index 1d8baf0..1ebb959 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,6 +21,7 @@ Life cycle of a request in the decode server @@ -312,7 +312,7 @@ index 1d8baf002..1ebb95929 100644 if not hasattr(self, "polling_count"): diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py -index a2d08e0e3..ed0790604 100644 +index a2d08e0..ed07906 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -117,7 +117,7 @@ def _convert(data): @@ -353,7 +353,7 @@ index a2d08e0e3..ed0790604 100644 mm_item = MultimodalDataItem.from_dict( { diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py -index d0d4efd95..b3a207063 100644 +index d0d4efd..b3a2070 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -30,7 +30,7 @@ from sglang.srt.disaggregation.common.utils import ( @@ -541,7 +541,7 @@ index d0d4efd95..b3a207063 100644 def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py -index fbc801635..ade111c9f 100644 +index fbc8016..ade111c 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -20,6 +20,7 @@ Life cycle of a request in the prefill server @@ -715,7 +715,7 @@ index fbc801635..ade111c9f 100644 transferred_rids: List[str] = [] diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py -index 6d58f415a..84723c342 100644 +index 6d58f41..84723c3 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -21,6 +21,17 @@ if TYPE_CHECKING: @@ -907,7 +907,7 @@ index 6d58f415a..84723c342 100644 ######################### diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index 8f1069c00..e47589295 100644 +index 8f1069c..e475892 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1999,7 +1999,10 @@ def get_tensor_model_parallel_world_size(): @@ -923,7 +923,7 @@ index 8f1069c00..e47589295 100644 # ATTN_TP diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py -index 0ed5a1b44..67e33c650 100644 +index 0ed5a1b..67e33c6 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import ( @@ -960,7 +960,7 @@ index 0ed5a1b44..67e33c650 100644 """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 1d6816c01..402b42e05 100644 +index 1d6816c..402b42e 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -115,6 +115,7 @@ from sglang.srt.managers.io_struct import ( @@ -1032,7 +1032,7 @@ index 1d6816c01..402b42e05 100644 @auth_level(AuthLevel.ADMIN_OPTIONAL) async def update_weight_version(obj: 
UpdateWeightVersionReqInput, request: Request): diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py -index 8293796a2..bff34e422 100644 +index 8293796..bff34e4 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -244,6 +244,7 @@ class Envs: @@ -1044,7 +1044,7 @@ index 8293796a2..bff34e422 100644 # Scheduler: others: SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1) # in seconds. Set if you observe high memory accumulation over a long serving period. diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py -index 1cdf65b91..4783cd18f 100644 +index 1cdf65b..4783cd1 100644 --- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py +++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py @@ -630,7 +630,6 @@ def _get_k_and_s_triton( @@ -1064,7 +1064,7 @@ index 1cdf65b91..4783cd18f 100644 buf_numel_per_page: tl.constexpr, index_head_dim: tl.constexpr, diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py -index ca54a931b..3540f77ba 100644 +index ca54a93..3540f77 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -1,6 +1,7 @@ @@ -1149,7 +1149,7 @@ index ca54a931b..3540f77ba 100644 if enable_dual_stream: current_stream = torch.cuda.current_stream() diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py -index 00ef96f9b..c2c2c78fe 100644 +index 00ef96f..c2c2c78 100644 --- a/python/sglang/srt/layers/attention/nsa/utils.py +++ b/python/sglang/srt/layers/attention/nsa/utils.py @@ -91,20 +91,29 @@ def nsa_cp_round_robin_split_data(input_: Union[torch.Tensor, List]): @@ -1215,7 +1215,7 @@ index 00ef96f9b..c2c2c78fe 100644 position_id_list = list( diff --git a/python/sglang/srt/layers/communicator_nsa_cp.py b/python/sglang/srt/layers/communicator_nsa_cp.py -index 296d14568..f4606a769 100644 +index 296d145..f4606a7 100644 --- a/python/sglang/srt/layers/communicator_nsa_cp.py +++ b/python/sglang/srt/layers/communicator_nsa_cp.py @@ -34,7 +34,6 @@ from sglang.srt.layers.communicator import ( @@ -1254,7 +1254,7 @@ index 296d14568..f4606a769 100644 attn_cp_all_gather_into_tensor( hidden_states, diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py -index 5bf5aa0c8..e52f39fd8 100644 +index 5bf5aa0..e52f39f 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -90,11 +90,11 @@ class _DpGatheredBufferWrapper: @@ -1275,10 +1275,122 @@ index 5bf5aa0c8..e52f39fd8 100644 @classmethod def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device): diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py -index aff05bf42..130359232 100644 +index aff05bf..68b67c7 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py -@@ -872,11 +872,6 @@ class LogitsProcessor(nn.Module): +@@ -47,6 +47,7 @@ from sglang.srt.layers.utils.logprob import ( + get_token_ids_logprobs_prefill, + get_top_logprobs_chunk, + get_top_logprobs_prefill, ++ get_top_p_logprobs_prefill, + ) + from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding + from sglang.srt.model_executor.forward_batch_info import ( +@@ -78,6 +79,9 @@ class LogitsProcessorOutput: + # The logprobs and ids of the 
top-k tokens in output positions. shape: [#seq, k] + next_token_top_logprobs_val: Optional[List] = None + next_token_top_logprobs_idx: Optional[List] = None ++ # The logprobs and ids of the top-p tokens in output positions (variable-length per position) ++ next_token_top_p_logprobs_val: Optional[List] = None ++ next_token_top_p_logprobs_idx: Optional[List] = None + # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids) + # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests) + next_token_token_ids_logprobs_val: Optional[ +@@ -91,6 +95,9 @@ class LogitsProcessorOutput: + # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] + input_top_logprobs_val: Optional[List] = None + input_top_logprobs_idx: Optional[List] = None ++ # The logprobs and ids of the top-p tokens in input positions (variable-length per position) ++ input_top_p_logprobs_val: Optional[List] = None ++ input_top_p_logprobs_idx: Optional[List] = None + # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids) + # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization) + input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = ( +@@ -115,12 +122,14 @@ class LogitsMetadata: + + extend_return_logprob: bool = False + extend_return_top_logprob: bool = False ++ extend_return_top_p_logprob: bool = False + extend_token_ids_logprob: bool = False + extend_seq_lens: Optional[torch.Tensor] = None + extend_seq_lens_cpu: Optional[List[int]] = None + extend_logprob_start_lens_cpu: Optional[List[int]] = None + extend_logprob_pruned_lens_cpu: Optional[List[int]] = None + top_logprobs_nums: Optional[List[int]] = None ++ top_logprobs_ps: Optional[List[float]] = None + extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None + token_ids_logprobs: Optional[List[List[int]]] = None + +@@ -160,6 +169,10 @@ class LogitsMetadata: + extend_return_top_logprob = any( + x > 0 for x in forward_batch.top_logprobs_nums + ) ++ extend_return_top_p_logprob = ( ++ forward_batch.top_logprobs_ps is not None ++ and any(x > 0.0 for x in forward_batch.top_logprobs_ps) ++ ) + extend_token_ids_logprob = any( + x is not None for x in forward_batch.token_ids_logprobs + ) +@@ -174,6 +187,8 @@ class LogitsMetadata: + extend_logprob_pruned_lens_cpu.append(extend_len - start_len) + else: + extend_return_logprob = extend_return_top_logprob = ( ++ extend_return_top_p_logprob ++ ) = ( + extend_token_ids_logprob + ) = extend_logprob_pruned_lens_cpu = False + +@@ -183,12 +198,14 @@ class LogitsMetadata: + next_token_logits_buffer=forward_batch.next_token_logits_buffer, + extend_return_logprob=extend_return_logprob, + extend_return_top_logprob=extend_return_top_logprob, ++ extend_return_top_p_logprob=extend_return_top_p_logprob, + extend_token_ids_logprob=extend_token_ids_logprob, + extend_seq_lens=forward_batch.extend_seq_lens, + extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, + extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, + top_logprobs_nums=forward_batch.top_logprobs_nums, ++ top_logprobs_ps=forward_batch.top_logprobs_ps, + token_ids_logprobs=forward_batch.token_ids_logprobs, + extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu, + 
padded_static_len=forward_batch.padded_static_len, +@@ -391,6 +408,8 @@ class LogitsProcessor(nn.Module): + input_token_logprobs=logprobs_result.input_token_logprobs, + input_top_logprobs_val=logprobs_result.input_top_logprobs_val, + input_top_logprobs_idx=logprobs_result.input_top_logprobs_idx, ++ input_top_p_logprobs_val=logprobs_result.input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=logprobs_result.input_top_p_logprobs_idx, + input_token_ids_logprobs_val=logprobs_result.input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=logprobs_result.input_token_ids_logprobs_idx, + mm_input_embeds=logits_metadata.mm_input_embeds, +@@ -619,6 +638,15 @@ class LogitsProcessor(nn.Module): + else: + input_top_logprobs_val = input_top_logprobs_idx = None + ++ # Get the logprob of top-p tokens ++ if logits_metadata.extend_return_top_p_logprob: ++ ( ++ input_top_p_logprobs_val, ++ input_top_p_logprobs_idx, ++ ) = get_top_p_logprobs_prefill(input_logprobs, logits_metadata) ++ else: ++ input_top_p_logprobs_val = input_top_p_logprobs_idx = None ++ + # Get the logprob of given token id + if logits_metadata.extend_token_ids_logprob: + ( +@@ -637,6 +665,8 @@ class LogitsProcessor(nn.Module): + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, ++ input_top_p_logprobs_val=input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=input_top_p_logprobs_idx, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, + ) +@@ -872,11 +902,6 @@ class LogitsProcessor(nn.Module): None, # bias True, # is_vnni ) @@ -1290,160 +1402,17 @@ index aff05bf42..130359232 100644 else: logits = torch.matmul( hidden_states.to(lm_head.weight.dtype), lm_head.weight.T -diff --git a/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py -new file mode 100644 -index 000000000..8d3d0f92e ---- /dev/null -+++ b/python/sglang/srt/layers/moe/ep_moe/deepep_bf16_kernels.py -@@ -0,0 +1,146 @@ -+"""Fused Triton kernels for DeepEP BF16 low-latency MoE decode. -+ -+Replaces the naive activation + masking pipeline (5+ CUDA kernels for silu+mul -+and arange+comparison+masked_fill+copy) with a single Triton elementwise kernel, -+while keeping cuBLAS batched GEMM for the matrix multiplies. -+ -+Pipeline: bmm → fused_act_mul_masked (in-place) → bmm(out=hidden) -+ (3 ops total: 2 cuBLAS + 1 Triton, vs original 7-8 separate CUDA kernels) -+""" -+ -+import torch -+import triton -+import triton.language as tl -+ -+ -+@triton.jit -+def _silu_mul_masked_kernel( -+ gate_up_ptr, -+ masked_m_ptr, -+ M, -+ N, -+ stride_ge, -+ stride_gm, -+ stride_gn, -+ BLOCK: tl.constexpr, -+): -+ """Fused SiLU(gate) * up with per-expert masking, written in-place. -+ -+ gate_up: [E, M, 2*N] — first N cols are gate, last N cols are up. -+ Writes SiLU(gate)*up to gate_up[:,:,:N] in-place. -+ Rows m >= masked_m[e] are zeroed. 
-+ """ -+ expert_id = tl.program_id(1) -+ pid = tl.program_id(0) -+ -+ expert_valid_m = tl.load(masked_m_ptr + expert_id) -+ -+ offs = pid * BLOCK + tl.arange(0, BLOCK) -+ total = M * N -+ mask = offs < total -+ -+ m = offs // N -+ n = offs % N -+ -+ gate_base = gate_up_ptr + expert_id * stride_ge -+ -+ gate_val = tl.load(gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0) -+ up_val = tl.load( -+ gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0 -+ ) -+ -+ gate_f32 = gate_val.to(tl.float32) -+ result = (gate_f32 * tl.sigmoid(gate_f32)) * up_val.to(tl.float32) -+ -+ # Zero invalid rows -+ valid = m < expert_valid_m -+ result = tl.where(valid, result, 0.0) -+ -+ tl.store( -+ gate_base + m * stride_gm + n * stride_gn, -+ result.to(gate_up_ptr.dtype.element_ty), -+ mask=mask, -+ ) -+ -+ -+@triton.jit -+def _gelu_mul_masked_kernel( -+ gate_up_ptr, -+ masked_m_ptr, -+ M, -+ N, -+ stride_ge, -+ stride_gm, -+ stride_gn, -+ BLOCK: tl.constexpr, -+): -+ """Fused GELU(gate) * up with per-expert masking, written in-place.""" -+ expert_id = tl.program_id(1) -+ pid = tl.program_id(0) -+ -+ expert_valid_m = tl.load(masked_m_ptr + expert_id) -+ -+ offs = pid * BLOCK + tl.arange(0, BLOCK) -+ total = M * N -+ mask = offs < total -+ -+ m = offs // N -+ n = offs % N -+ -+ gate_base = gate_up_ptr + expert_id * stride_ge -+ -+ gate_val = tl.load(gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0) -+ up_val = tl.load( -+ gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0 -+ ) -+ -+ g = gate_val.to(tl.float32) -+ kAlpha = 0.7978845608028654 -+ gate_act = 0.5 * g * (1.0 + tl.math.tanh(kAlpha * (g + 0.044715 * g * g * g))) -+ result = gate_act * up_val.to(tl.float32) -+ -+ valid = m < expert_valid_m -+ result = tl.where(valid, result, 0.0) -+ -+ tl.store( -+ gate_base + m * stride_gm + n * stride_gn, -+ result.to(gate_up_ptr.dtype.element_ty), -+ mask=mask, -+ ) -+ -+ -+def fused_act_mul_masked_inplace( -+ gate_up: torch.Tensor, -+ intermediate_size: int, -+ masked_m: torch.Tensor, -+ use_gelu: bool = False, -+) -> None: -+ """Fused activation + multiply + masking, written in-place to gate_up[:,:,:I]. -+ -+ After this call, gate_up[:, :, :intermediate_size] contains the masked -+ activated intermediate, suitable for the down projection GEMM. 
-+ -+ Args: -+ gate_up: [E, M, 2*I] output of bmm(tokens, w13.T), modified in-place -+ intermediate_size: I -+ masked_m: [E] per-expert valid token count -+ use_gelu: use GELU instead of SiLU -+ """ -+ E, M, _ = gate_up.shape -+ N = intermediate_size -+ -+ total = M * N -+ BLOCK = 1024 -+ grid = (triton.cdiv(total, BLOCK), E) -+ -+ kernel = _gelu_mul_masked_kernel if use_gelu else _silu_mul_masked_kernel -+ kernel[grid]( -+ gate_up, -+ masked_m, -+ M, -+ N, -+ gate_up.stride(0), -+ gate_up.stride(1), -+ gate_up.stride(2), -+ BLOCK=BLOCK, -+ ) +@@ -1073,6 +1098,8 @@ class LogitsProcessor(nn.Module): + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, ++ input_top_p_logprobs_val=input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=input_top_p_logprobs_idx, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, + # FIXME: These fields are not logits-related but are passed through here as a diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py -index ebcc696ec..3b527021a 100644 +index ebcc696..3b52702 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -132,11 +132,12 @@ class DeepEPMoE(FusedMoE): @@ -1563,7 +1532,7 @@ index ebcc696ec..3b527021a 100644 self, dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py -index de8a07ab3..952f8a67b 100644 +index de8a07a..952f8a6 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -697,6 +697,7 @@ class FusedMoE(torch.nn.Module): @@ -1607,7 +1576,7 @@ index de8a07ab3..952f8a67b 100644 ) diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py -index 00bd68755..12d5577af 100644 +index 00bd687..12d5577 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -8,10 +8,15 @@ import torch @@ -1668,7 +1637,7 @@ index 00bd68755..12d5577af 100644 def get_routed_experts( diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -index 8539639d5..d44496c2f 100644 +index 8539639..d44496c 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -388,6 +388,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): @@ -1742,7 +1711,7 @@ index 8539639d5..d44496c2f 100644 buffer = self._get_buffer() diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py -index 4cbfed6f9..88b452744 100644 +index 4cbfed6..88b4527 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -499,7 +499,7 @@ class CompressedTensorsConfig(QuantizationConfig): @@ -1765,7 +1734,7 @@ index 4cbfed6f9..88b452744 100644 self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py 
b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py -index 6264f36d0..f0310e305 100644 +index 6264f36..f0310e3 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -17,7 +17,10 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( @@ -1884,7 +1853,7 @@ index 6264f36d0..f0310e305 100644 is_k_full=self.is_k_full, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py -index ae0614635..3b6a8d254 100644 +index ae06146..3b6a8d2 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -305,9 +305,6 @@ class RotaryEmbedding(MultiPlatformOp): @@ -1907,11 +1876,206 @@ index ae0614635..3b6a8d254 100644 # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 assert ( fused_set_kv_buffer_arg is None +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index f78d83d..706b0eb 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -12,7 +12,11 @@ from sglang.srt.layers.dp_attention import ( + ) + from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.layers.utils.hash import murmur_hash32 +-from sglang.srt.layers.utils.logprob import get_token_ids_logprobs, get_top_logprobs ++from sglang.srt.layers.utils.logprob import ( ++ get_token_ids_logprobs, ++ get_top_logprobs, ++ get_top_p_logprobs, ++) + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + from sglang.srt.sampling.sampling_params import TOP_K_ALL + from sglang.srt.server_args import get_global_server_args +@@ -79,6 +83,7 @@ class Sampler(nn.Module): + sampling_info: SamplingBatchInfo, + return_logprob: bool, + top_logprobs_nums: List[int], ++ top_logprobs_ps: Optional[List[float]], + token_ids_logprobs: List[List[int]], + positions: torch.Tensor, + ): +@@ -176,6 +181,7 @@ class Sampler(nn.Module): + logits_output, + logprobs, + top_logprobs_nums, ++ top_logprobs_ps, + token_ids_logprobs, + sampling_info, + batch_next_token_ids, +@@ -314,6 +320,7 @@ class Sampler(nn.Module): + logits_output: LogitsProcessorOutput, + logprobs: torch.Tensor, + top_logprobs_nums: List[int], ++ top_logprobs_ps: Optional[List[float]], + token_ids_logprobs: List[List[int]], + sampling_info: SamplingBatchInfo, + batch_next_token_ids: torch.Tensor, +@@ -328,6 +335,13 @@ class Sampler(nn.Module): + logits_output.next_token_top_logprobs_idx, + ) = get_top_logprobs(logprobs, top_logprobs_nums) + ++ # Attach top-p logprobs ++ if top_logprobs_ps is not None and any(x > 0.0 for x in top_logprobs_ps): ++ ( ++ logits_output.next_token_top_p_logprobs_val, ++ logits_output.next_token_top_p_logprobs_idx, ++ ) = get_top_p_logprobs(logprobs, top_logprobs_ps) ++ + if any(x is not None for x in token_ids_logprobs): + ( + logits_output.next_token_token_ids_logprobs_val, +@@ -362,6 +376,7 @@ class Sampler(nn.Module): + sampling_info: SamplingBatchInfo, + return_logprob: bool, + top_logprobs_nums: List[int], ++ top_logprobs_ps: Optional[List[float]], + token_ids_logprobs: List[List[int]], + ) -> None: + """ +diff --git a/python/sglang/srt/layers/utils/logprob.py b/python/sglang/srt/layers/utils/logprob.py +index 6f84c15..3e526d3 100644 +--- 
a/python/sglang/srt/layers/utils/logprob.py ++++ b/python/sglang/srt/layers/utils/logprob.py +@@ -25,6 +25,8 @@ class InputLogprobsResult: + input_token_logprobs: torch.Tensor + input_top_logprobs_val: Optional[List] = None + input_top_logprobs_idx: Optional[List] = None ++ input_top_p_logprobs_val: Optional[List] = None ++ input_top_p_logprobs_idx: Optional[List] = None + input_token_ids_logprobs_val: Optional[List] = None + input_token_ids_logprobs_idx: Optional[List] = None + +@@ -96,6 +98,75 @@ def get_top_logprobs_raw( + return top_logprobs_val, top_logprobs_idx + + ++ ++def get_top_p_logprobs_raw( ++ logprobs: torch.Tensor, ++ top_logprobs_ps: List[float], ++ stage: LogprobStage, ++ extend_logprob_pruned_lens_cpu: Optional[List[int]] = None, ++ no_copy_to_cpu: bool = False, ++): ++ """Get top-p logprobs: return tokens whose cumulative probability >= top_p threshold.""" ++ sorted_logprobs, sorted_indices = logprobs.sort(dim=-1, descending=True) ++ sorted_probs = sorted_logprobs.exp() ++ cumsum_probs = torch.cumsum(sorted_probs, dim=-1) ++ ++ top_logprobs_val = [] ++ top_logprobs_idx = [] ++ ++ if stage == LogprobStage.DECODE: ++ cumsum_cpu = cumsum_probs.cpu() ++ sorted_logprobs_cpu = sorted_logprobs.cpu() ++ sorted_indices_cpu = sorted_indices.cpu() ++ ++ for i, p in enumerate(top_logprobs_ps): ++ if p <= 0.0: ++ top_logprobs_val.append([]) ++ top_logprobs_idx.append([]) ++ continue ++ mask = cumsum_cpu[i] >= p ++ if mask.any(): ++ cutoff = mask.nonzero(as_tuple=True)[0][0].item() + 1 ++ else: ++ cutoff = sorted_logprobs_cpu.shape[1] ++ cutoff = max(cutoff, 1) ++ top_logprobs_val.append(sorted_logprobs_cpu[i, :cutoff].tolist()) ++ top_logprobs_idx.append(sorted_indices_cpu[i, :cutoff].tolist()) ++ else: ++ cumsum_cpu = cumsum_probs.cpu() ++ sorted_logprobs_cpu = sorted_logprobs.cpu() ++ sorted_indices_cpu = sorted_indices.cpu() ++ ++ pt = 0 ++ for p, pruned_len in zip(top_logprobs_ps, extend_logprob_pruned_lens_cpu): ++ if pruned_len <= 0: ++ top_logprobs_val.append([]) ++ top_logprobs_idx.append([]) ++ continue ++ ++ pos_vals = [] ++ pos_idxs = [] ++ for j in range(pruned_len): ++ row = pt + j ++ if p <= 0.0: ++ pos_vals.append([]) ++ pos_idxs.append([]) ++ continue ++ mask = cumsum_cpu[row] >= p ++ if mask.any(): ++ cutoff = mask.nonzero(as_tuple=True)[0][0].item() + 1 ++ else: ++ cutoff = sorted_logprobs_cpu.shape[1] ++ cutoff = max(cutoff, 1) ++ pos_vals.append(sorted_logprobs_cpu[row, :cutoff].tolist()) ++ pos_idxs.append(sorted_indices_cpu[row, :cutoff].tolist()) ++ top_logprobs_val.append(pos_vals) ++ top_logprobs_idx.append(pos_idxs) ++ pt += pruned_len ++ ++ return top_logprobs_val, top_logprobs_idx ++ ++ + def get_top_logprobs_prefill( + all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata + ): +@@ -114,6 +185,31 @@ def get_top_logprobs( + return get_top_logprobs_raw(logprobs, top_logprobs_nums, stage=LogprobStage.DECODE) + + ++def get_top_p_logprobs_prefill( ++ all_logprobs: torch.Tensor, logits_metadata: "LogitsMetadata" ++): ++ return get_top_p_logprobs_raw( ++ all_logprobs, ++ logits_metadata.top_logprobs_ps, ++ stage=LogprobStage.PREFILL, ++ extend_logprob_pruned_lens_cpu=logits_metadata.extend_logprob_pruned_lens_cpu, ++ ) ++ ++ ++def get_top_p_logprobs( ++ logprobs: torch.Tensor, ++ top_logprobs_ps: List[float], ++ no_copy_to_cpu: bool = False, ++): ++ result = get_top_p_logprobs_raw( ++ logprobs, ++ top_logprobs_ps, ++ stage=LogprobStage.DECODE, ++ no_copy_to_cpu=no_copy_to_cpu, ++ ) ++ return result ++ ++ + def get_token_ids_logprobs_raw( + logprobs: 
torch.Tensor, + token_ids_logprobs: List[Optional[List[int]]], diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py -index 652227860..7d3a5d0c4 100644 +index 6522278..0db2cbc 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py -@@ -405,6 +405,17 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): +@@ -387,6 +387,10 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, ++ input_top_p_logprobs_val=recv_obj.input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=recv_obj.input_top_p_logprobs_idx, ++ output_top_p_logprobs_val=recv_obj.output_top_p_logprobs_val, ++ output_top_p_logprobs_idx=recv_obj.output_top_p_logprobs_idx, + input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx, + output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val, +@@ -405,6 +409,17 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): prefill_launch_delay=recv_obj.prefill_launch_delay, prefill_launch_latency=recv_obj.prefill_launch_latency, prefill_finished_ts=recv_obj.prefill_finished_ts, @@ -1930,7 +2094,7 @@ index 652227860..7d3a5d0c4 100644 def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py -index ff1774567..f947e71d7 100644 +index ff17745..df07928 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -101,6 +101,42 @@ class RequestTimingMetricsMixin: @@ -1976,7 +2140,80 @@ index ff1774567..f947e71d7 100644 @dataclass class SpeculativeDecodingMetricsMixin: -@@ -1403,6 +1439,20 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): +@@ -198,8 +234,12 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + top_logprobs_num: Optional[Union[List[int], int]] = None + # If return logprobs, the token ids to return logprob for. + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None ++ # If return logprobs, the top-p threshold for returning variable-length top logprobs per position. ++ top_logprobs_p: Optional[Union[List[float], float]] = None + # Whether to detokenize tokens in text in the returned logprobs. + return_text_in_logprobs: bool = False ++ # Whether to return logprobs encoded in base64 format. ++ return_logprobs_in_base64: bool = False + # Whether to stream output. + stream: bool = False + # Whether to log metrics for this request (e.g. 
health_generate calls do not log metrics) +@@ -389,6 +429,8 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + self.top_logprobs_num = 0 + if not self.token_ids_logprob: # covers both None and [] + self.token_ids_logprob = None ++ if self.top_logprobs_p is None: ++ self.top_logprobs_p = 0.0 + + def _normalize_batch_inputs(self): + """Normalize inputs for a batch of examples, including parallel sampling expansion.""" +@@ -552,6 +594,9 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + self.top_logprobs_num = normalize_param( + self.top_logprobs_num, 0, "top_logprobs_num" + ) ++ self.top_logprobs_p = normalize_param( ++ self.top_logprobs_p, 0.0, "top_logprobs_p" ++ ) + + # Handle token_ids_logprob specially due to its nested structure + if not self.token_ids_logprob: # covers both None and [] +@@ -636,7 +681,9 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): + logprob_start_len=self.logprob_start_len[i], + top_logprobs_num=self.top_logprobs_num[i], + token_ids_logprob=self.token_ids_logprob[i], ++ top_logprobs_p=self.top_logprobs_p[i], + return_text_in_logprobs=self.return_text_in_logprobs, ++ return_logprobs_in_base64=self.return_logprobs_in_base64, + stream=self.stream, + log_metrics=self.log_metrics, + return_hidden_states=( +@@ -709,6 +756,8 @@ class TokenizedGenerateReqInput(BaseReq): + top_logprobs_num: int + # If return logprobs, the token id to return logprob for + token_ids_logprob: List[int] ++ # If return logprobs, the top-p threshold for variable-length top logprobs ++ top_logprobs_p: float + # Whether to stream output + stream: bool + +@@ -1008,6 +1057,10 @@ class BatchTokenIDOutput( + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] ++ input_top_p_logprobs_val: List[List] ++ input_top_p_logprobs_idx: List[List] ++ output_top_p_logprobs_val: List[List] ++ output_top_p_logprobs_idx: List[List] + input_token_ids_logprobs_val: List[List] + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] +@@ -1098,6 +1151,10 @@ class BatchStrOutput( + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] ++ input_top_p_logprobs_val: List[List] ++ input_top_p_logprobs_idx: List[List] ++ output_top_p_logprobs_val: List[List] ++ output_top_p_logprobs_idx: List[List] + input_token_ids_logprobs_val: List[List] + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] +@@ -1403,6 +1460,20 @@ class UpdateWeightsFromIPCReqOutput(BaseReq): message: str @@ -1997,7 +2234,7 @@ index ff1774567..f947e71d7 100644 @dataclass class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): success: bool -@@ -1802,6 +1852,10 @@ class GetLoadReqOutput(BaseReq): +@@ -1802,6 +1873,10 @@ class GetLoadReqOutput(BaseReq): num_waiting_reqs: int num_tokens: int ts_tic: float @@ -2009,7 +2246,7 @@ index ff1774567..f947e71d7 100644 @dataclass diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py -index e1236aa0f..daa598a1f 100644 +index e1236aa..daa598a 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -142,6 +142,39 @@ def _handle_output_by_index(output, i): @@ -2205,10 +2442,89 @@ index e1236aa0f..daa598a1f 100644 class SenderWrapper: def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket): diff --git a/python/sglang/srt/managers/schedule_batch.py 
b/python/sglang/srt/managers/schedule_batch.py -index c07995798..dd8ca7167 100644 +index c079957..71912f3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py -@@ -1869,7 +1869,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): +@@ -520,6 +520,7 @@ class Req(ReqDllmMixin): + sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, ++ top_logprobs_p: float = 0.0, + dllm_config: Optional[DllmConfig] = None, + token_ids_logprob: List[int] = None, + stream: bool = False, +@@ -691,6 +692,7 @@ class Req(ReqDllmMixin): + # Start index to compute logprob from. + self.logprob_start_len = 0 + self.top_logprobs_num = top_logprobs_num ++ self.top_logprobs_p = top_logprobs_p + self.token_ids_logprob = token_ids_logprob + self.temp_scaled_logprobs = False + self.top_p_normalized_logprobs = False +@@ -702,12 +704,16 @@ class Req(ReqDllmMixin): + self.input_token_logprobs_idx: Optional[List[int]] = None + self.input_top_logprobs_val: Optional[List[float]] = None + self.input_top_logprobs_idx: Optional[List[int]] = None ++ self.input_top_p_logprobs_val: Optional[List] = None ++ self.input_top_p_logprobs_idx: Optional[List] = None + self.input_token_ids_logprobs_val: Optional[List[float]] = None + self.input_token_ids_logprobs_idx: Optional[List[int]] = None + # Temporary holder to store input_token_logprobs. + self.input_token_logprobs: Optional[List[Tuple[int]]] = None + self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None + self.temp_input_top_logprobs_idx: Optional[List[int]] = None ++ self.temp_input_top_p_logprobs_val: Optional[List] = None ++ self.temp_input_top_p_logprobs_idx: Optional[List] = None + self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None + self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None + +@@ -718,6 +724,9 @@ class Req(ReqDllmMixin): + # shape: (bs, k) + self.output_top_logprobs_val = [] + self.output_top_logprobs_idx = [] ++ # shape: (bs, variable) for top-p logprobs ++ self.output_top_p_logprobs_val = [] ++ self.output_top_p_logprobs_idx = [] + # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring) + self.output_token_ids_logprobs_val: List[ + Union[List[float], torch.Tensor] +@@ -726,7 +735,9 @@ class Req(ReqDllmMixin): + else: + self.output_token_logprobs_val = self.output_token_logprobs_idx = ( + self.output_top_logprobs_val +- ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = ( ++ ) = self.output_top_logprobs_idx = self.output_top_p_logprobs_val = ( ++ self.output_top_p_logprobs_idx ++ ) = self.output_token_ids_logprobs_val = ( + self.output_token_ids_logprobs_idx + ) = None + self.hidden_states: List[List[float]] = [] +@@ -1116,6 +1127,8 @@ class Req(ReqDllmMixin): + self.input_token_logprobs = None + self.temp_input_top_logprobs_val = None + self.temp_input_top_logprobs_idx = None ++ self.temp_input_top_p_logprobs_val = None ++ self.temp_input_top_p_logprobs_idx = None + self.extend_logprob_start_len = 0 + self.is_chunked = 0 + self.mamba_pool_idx = None +@@ -1260,6 +1273,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + # For processing logprobs + return_logprob: bool = False + top_logprobs_nums: Optional[List[int]] = None ++ top_logprobs_ps: Optional[List[float]] = None + token_ids_logprobs: Optional[List[List[int]]] = None + + # For logits and logprob post processing +@@ -1651,6 +1665,7 @@ class 
ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + if self.return_logprob: + self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] ++ self.top_logprobs_ps = [r.top_logprobs_p for r in reqs] + self.token_ids_logprobs = [r.token_ids_logprob for r in reqs] + + self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] +@@ -1869,7 +1884,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): while first_iter or ( not self.check_decode_mem(selected_indices=sorted_indices) ): @@ -2220,8 +2536,52 @@ index c07995798..dd8ca7167 100644 # Always keep at least one request break +@@ -2094,9 +2112,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.return_logprob = any(req.return_logprob for req in self.reqs) + if self.return_logprob: + self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices] ++ self.top_logprobs_ps = [self.top_logprobs_ps[i] for i in keep_indices] + self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices] + else: + self.top_logprobs_nums = None ++ self.top_logprobs_ps = None + self.token_ids_logprobs = None + + self.has_stream = any(req.stream for req in self.reqs) +@@ -2143,12 +2163,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.mamba_track_seqlens = None + if self.return_logprob and other.return_logprob: + self.top_logprobs_nums.extend(other.top_logprobs_nums) ++ self.top_logprobs_ps.extend(other.top_logprobs_ps) + self.token_ids_logprobs.extend(other.token_ids_logprobs) + elif self.return_logprob: + self.top_logprobs_nums.extend([0] * len(other.reqs)) ++ self.top_logprobs_ps.extend([0.0] * len(other.reqs)) + self.token_ids_logprobs.extend([None] * len(other.reqs)) + elif other.return_logprob: + self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums ++ self.top_logprobs_ps = [0.0] * len(self.reqs) + other.top_logprobs_ps + self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs + self.reqs.extend(other.reqs) + if self.multimodal_inputs is not None: +@@ -2193,6 +2216,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + seq_lens_sum=self.seq_lens_sum, + return_logprob=self.return_logprob, + top_logprobs_nums=self.top_logprobs_nums, ++ top_logprobs_ps=self.top_logprobs_ps, + token_ids_logprobs=self.token_ids_logprobs, + global_num_tokens=self.global_num_tokens, + global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, +@@ -2352,6 +2376,7 @@ class ModelWorkerBatch: + # For logprob + return_logprob: bool + top_logprobs_nums: Optional[List[int]] ++ top_logprobs_ps: Optional[List[float]] + token_ids_logprobs: Optional[List[List[int]]] + + # For DP attention diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index a9ff0ac94..c124f43bc 100644 +index a9ff0ac..264b177 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -114,6 +114,7 @@ from sglang.srt.managers.io_struct import ( @@ -2261,8 +2621,16 @@ index a9ff0ac94..c124f43bc 100644 (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), +@@ -1505,6 +1512,7 @@ class Scheduler( + recv_req.sampling_params, + return_logprob=recv_req.return_logprob, + top_logprobs_num=recv_req.top_logprobs_num, ++ top_logprobs_p=recv_req.top_logprobs_p, + token_ids_logprob=recv_req.token_ids_logprob, + stream=recv_req.stream, + lora_id=recv_req.lora_id, 
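# --- Illustrative sketch (not part of the patch above) ---------------------
# The hunks above thread a new `top_logprobs_p` field from GenerateReqInput
# through TokenizedGenerateReqInput / Req / ScheduleBatch, and the tokenizer
# manager later packs the variable-length results as base64 numpy arrays.
# The snippet below is a minimal, hedged client-side sketch of how those
# fields might be exercised and decoded. The endpoint address, prompt, and
# sampling values are hypothetical; the request/response key names follow the
# ones introduced in this patch and mirror the slime-side decoder further down.
import numpy as np
import pybase64
import requests

resp = requests.post(
    "http://localhost:30000/generate",  # assumed server launched from this patched tree
    json={
        "text": "Hello",
        "sampling_params": {"max_new_tokens": 8, "temperature": 1.0},
        "return_logprob": True,
        "top_logprobs_p": 0.95,             # variable-length top-p logprobs per position
        "return_logprobs_in_base64": True,  # packed numpy arrays instead of nested lists
    },
).json()
meta = resp["meta_info"]

# Keys are present once at least one position produced top-p candidates.
flat_vals = np.frombuffer(
    pybase64.b64decode(meta["output_top_p_logprobs_val_base64"]), dtype=np.float32
)
flat_idxs = np.frombuffer(
    pybase64.b64decode(meta["output_top_p_logprobs_idx_base64"]), dtype=np.int32
)
lengths = meta["output_top_p_logprobs_lengths"]  # -1 / 0 mark positions without a list

# Re-split the flat arrays into per-position (token_id, logprob) candidate lists,
# matching the variable-length layout written by encode_variable_length.
candidates, off = [], 0
for n in lengths:
    if n <= 0:
        candidates.append([])
    else:
        candidates.append(
            list(zip(flat_idxs[off : off + n].tolist(), flat_vals[off : off + n].tolist()))
        )
        off += n
# ---------------------------------------------------------------------------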
diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py -index 30b2732b9..68090b161 100644 +index 30b2732..68090b1 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -609,12 +609,54 @@ class SchedulerMetricsMixin: @@ -2321,10 +2689,109 @@ index 30b2732b9..68090b161 100644 def get_loads(self: Scheduler, req: GetLoadsReqInput = None) -> GetLoadsReqOutput: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index 482bc6ca6..fbc486417 100644 +index 482bc6c..9b8b5cb 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -@@ -922,6 +922,18 @@ class SchedulerOutputProcessorMixin: +@@ -494,6 +494,13 @@ class SchedulerOutputProcessorMixin: + req.output_top_logprobs_idx.append( + logits_output.next_token_top_logprobs_idx[i] + ) ++ if req.top_logprobs_p > 0.0 and logits_output.next_token_top_p_logprobs_val is not None: ++ req.output_top_p_logprobs_val.append( ++ logits_output.next_token_top_p_logprobs_val[i] ++ ) ++ req.output_top_p_logprobs_idx.append( ++ logits_output.next_token_top_p_logprobs_idx[i] ++ ) + if req.token_ids_logprob is not None: + req.output_token_ids_logprobs_val.append( + logits_output.next_token_token_ids_logprobs_val[i] +@@ -623,6 +630,32 @@ class SchedulerOutputProcessorMixin: + # Clean up temp storage + req.temp_input_top_logprobs_idx = None + req.temp_input_top_logprobs_val = None ++ def _process_input_top_p_logprobs(self: Scheduler, req: Req) -> None: ++ """Process input top-p logprobs.""" ++ if req.top_logprobs_p <= 0.0: ++ return ++ ++ is_multi_item_scoring = self._is_multi_item_scoring(req) ++ ++ req.input_top_p_logprobs_val = [] if is_multi_item_scoring else [None] ++ req.input_top_p_logprobs_idx = [] if is_multi_item_scoring else [None] ++ ++ for val, idx in zip( ++ req.temp_input_top_p_logprobs_val, ++ req.temp_input_top_p_logprobs_idx, ++ strict=True, ++ ): ++ req.input_top_p_logprobs_val.extend(val) ++ req.input_top_p_logprobs_idx.extend(idx) ++ ++ if not is_multi_item_scoring: ++ req.input_top_p_logprobs_val.pop() ++ req.input_top_p_logprobs_idx.pop() ++ ++ req.temp_input_top_p_logprobs_idx = None ++ req.temp_input_top_p_logprobs_val = None ++ ++ + + def _process_input_token_ids_logprobs(self, req: Req) -> None: + """Process input token IDs logprobs.""" +@@ -737,6 +770,10 @@ class SchedulerOutputProcessorMixin: + req.temp_input_top_logprobs_val = [] + if req.temp_input_top_logprobs_idx is None: + req.temp_input_top_logprobs_idx = [] ++ if req.temp_input_top_p_logprobs_val is None: ++ req.temp_input_top_p_logprobs_val = [] ++ if req.temp_input_top_p_logprobs_idx is None: ++ req.temp_input_top_p_logprobs_idx = [] + if req.temp_input_token_ids_logprobs_val is None: + req.temp_input_token_ids_logprobs_val = [] + if req.temp_input_token_ids_logprobs_idx is None: +@@ -761,6 +798,10 @@ class SchedulerOutputProcessorMixin: + req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i]) + req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i]) + ++ if req.top_logprobs_p > 0.0 and output.input_top_p_logprobs_val is not None: ++ req.temp_input_top_p_logprobs_val.append(output.input_top_p_logprobs_val[i]) ++ req.temp_input_top_p_logprobs_idx.append(output.input_top_p_logprobs_idx[i]) ++ + if req.token_ids_logprob is not None: + 
req.temp_input_token_ids_logprobs_val.append( + output.input_token_ids_logprobs_val[i] +@@ -780,6 +821,7 @@ class SchedulerOutputProcessorMixin: + # Process all input logprob types using helper functions + self._process_input_token_logprobs(req, input_token_logprobs) + self._process_input_top_logprobs(req) ++ self._process_input_top_p_logprobs(req) + + self._process_input_token_ids_logprobs(req) + +@@ -822,6 +864,10 @@ class SchedulerOutputProcessorMixin: + req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) + ++ if req.top_logprobs_p > 0.0 and output.next_token_top_p_logprobs_val is not None: ++ req.output_top_p_logprobs_val.append(output.next_token_top_p_logprobs_val[i]) ++ req.output_top_p_logprobs_idx.append(output.next_token_top_p_logprobs_idx[i]) ++ + if ( + req.token_ids_logprob is not None + and output.next_token_token_ids_logprobs_val is not None +@@ -852,6 +898,10 @@ class SchedulerOutputProcessorMixin: + req.input_top_logprobs_val = [] + if req.input_top_logprobs_idx is None: + req.input_top_logprobs_idx = [] ++ if req.input_top_p_logprobs_val is None: ++ req.input_top_p_logprobs_val = [] ++ if req.input_top_p_logprobs_idx is None: ++ req.input_top_p_logprobs_idx = [] + if req.input_token_ids_logprobs_val is None: + req.input_token_ids_logprobs_val = [] + if req.input_token_ids_logprobs_idx is None: +@@ -922,6 +972,18 @@ class SchedulerOutputProcessorMixin: prefill_launch_delays = [] prefill_launch_latencies = [] prefill_finished_timestamps = [] @@ -2343,7 +2810,33 @@ index 482bc6ca6..fbc486417 100644 if return_logprob: input_token_logprobs_val = [] -@@ -1037,6 +1049,40 @@ class SchedulerOutputProcessorMixin: +@@ -932,6 +994,10 @@ class SchedulerOutputProcessorMixin: + input_top_logprobs_idx = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] ++ input_top_p_logprobs_val = [] ++ input_top_p_logprobs_idx = [] ++ output_top_p_logprobs_val = [] ++ output_top_p_logprobs_idx = [] + input_token_ids_logprobs_val = [] + input_token_ids_logprobs_idx = [] + output_token_ids_logprobs_val = [] +@@ -942,8 +1008,12 @@ class SchedulerOutputProcessorMixin: + ) = output_token_logprobs_idx = input_top_logprobs_val = ( + input_top_logprobs_idx + ) = output_top_logprobs_val = output_top_logprobs_idx = ( +- input_token_ids_logprobs_val +- ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = ( ++ input_top_p_logprobs_val ++ ) = input_top_p_logprobs_idx = output_top_p_logprobs_val = ( ++ output_top_p_logprobs_idx ++ ) = input_token_ids_logprobs_val = ( ++ input_token_ids_logprobs_idx ++ ) = output_token_ids_logprobs_val = ( + output_token_ids_logprobs_idx + ) = None + +@@ -1037,6 +1107,40 @@ class SchedulerOutputProcessorMixin: prefill_finished_timestamps.append( req.time_stats.get_prefill_finished_ts() ) @@ -2384,7 +2877,51 @@ index 482bc6ca6..fbc486417 100644 if not self.spec_algorithm.is_none(): spec_verify_ct.append(req.spec_verify_ct) -@@ -1134,7 +1180,7 @@ class SchedulerOutputProcessorMixin: +@@ -1054,6 +1158,8 @@ class SchedulerOutputProcessorMixin: + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) ++ input_top_p_logprobs_val.append(req.input_top_p_logprobs_val) ++ input_top_p_logprobs_idx.append(req.input_top_p_logprobs_idx) + input_token_ids_logprobs_val.append( + req.input_token_ids_logprobs_val + ) +@@ -1066,6 +1172,8 @@ class 
SchedulerOutputProcessorMixin: + input_token_logprobs_idx.append([]) + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) ++ input_top_p_logprobs_val.append([]) ++ input_top_p_logprobs_idx.append([]) + input_token_ids_logprobs_val.append([]) + input_token_ids_logprobs_idx.append([]) + +@@ -1090,6 +1198,16 @@ class SchedulerOutputProcessorMixin: + send_output_token_logprobs_offset: + ] + ) ++ output_top_p_logprobs_val.append( ++ req.output_top_p_logprobs_val[ ++ send_output_token_logprobs_offset: ++ ] ++ ) ++ output_top_p_logprobs_idx.append( ++ req.output_top_p_logprobs_idx[ ++ send_output_token_logprobs_offset: ++ ] ++ ) + output_token_ids_logprobs_val.append( + req.output_token_ids_logprobs_val[ + send_output_token_logprobs_offset: +@@ -1108,6 +1226,8 @@ class SchedulerOutputProcessorMixin: + output_token_logprobs_idx.append([]) + output_top_logprobs_val.append([]) + output_top_logprobs_idx.append([]) ++ output_top_p_logprobs_val.append([]) ++ output_top_p_logprobs_idx.append([]) + output_token_ids_logprobs_val.append([]) + output_token_ids_logprobs_idx.append([]) + +@@ -1134,7 +1254,7 @@ class SchedulerOutputProcessorMixin: req.log_time_stats() # Send to detokenizer @@ -2393,7 +2930,7 @@ index 482bc6ca6..fbc486417 100644 if self.model_config.is_multimodal_gen: return self.send_to_detokenizer.send_output( -@@ -1149,6 +1195,17 @@ class SchedulerOutputProcessorMixin: +@@ -1149,6 +1269,17 @@ class SchedulerOutputProcessorMixin: prefill_launch_delay=prefill_launch_delays, prefill_launch_latency=prefill_launch_latencies, prefill_finished_ts=prefill_finished_timestamps, @@ -2411,7 +2948,18 @@ index 482bc6ca6..fbc486417 100644 finished_reasons=finished_reasons, decoded_texts=decoded_texts, decode_ids=decode_ids_list, -@@ -1198,6 +1255,18 @@ class SchedulerOutputProcessorMixin: +@@ -1169,6 +1300,10 @@ class SchedulerOutputProcessorMixin: + input_top_logprobs_idx=input_top_logprobs_idx, + output_top_logprobs_val=output_top_logprobs_val, + output_top_logprobs_idx=output_top_logprobs_idx, ++ input_top_p_logprobs_val=input_top_p_logprobs_val, ++ input_top_p_logprobs_idx=input_top_p_logprobs_idx, ++ output_top_p_logprobs_val=output_top_p_logprobs_val, ++ output_top_p_logprobs_idx=output_top_p_logprobs_idx, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, + output_token_ids_logprobs_val=output_token_ids_logprobs_val, +@@ -1198,6 +1333,18 @@ class SchedulerOutputProcessorMixin: prefill_launch_delays = [] prefill_launch_latencies = [] prefill_finished_timestamps = [] @@ -2430,7 +2978,7 @@ index 482bc6ca6..fbc486417 100644 retraction_counts = [] for req in reqs: if req.finished(): -@@ -1221,6 +1290,40 @@ class SchedulerOutputProcessorMixin: +@@ -1221,6 +1368,40 @@ class SchedulerOutputProcessorMixin: prefill_finished_timestamps.append( req.time_stats.get_prefill_finished_ts() ) @@ -2471,7 +3019,7 @@ index 482bc6ca6..fbc486417 100644 retraction_counts.append(req.retraction_count) self.send_to_detokenizer.send_output( BatchEmbeddingOutput( -@@ -1231,6 +1334,17 @@ class SchedulerOutputProcessorMixin: +@@ -1231,6 +1412,17 @@ class SchedulerOutputProcessorMixin: prefill_launch_delay=prefill_launch_delays, prefill_launch_latency=prefill_launch_latencies, prefill_finished_ts=prefill_finished_timestamps, @@ -2490,7 +3038,7 @@ index 482bc6ca6..fbc486417 100644 embeddings=embeddings, prompt_tokens=prompt_tokens, diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py 
-index 1a65a3c3d..f76606469 100644 +index 1a65a3c..f766064 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -20,6 +20,7 @@ from sglang.srt.layers.dp_attention import ( @@ -2687,7 +3235,7 @@ index 1a65a3c3d..f76606469 100644 rids: List = [] for poll_statuses in poll_statuses_group: diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py -index 7d08f12b3..afc045da2 100644 +index 7d08f12..afc045d 100644 --- a/python/sglang/srt/managers/scheduler_profiler_mixin.py +++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py @@ -347,7 +347,7 @@ class SchedulerProfilerMixin: @@ -2700,7 +3248,7 @@ index 7d08f12b3..afc045da2 100644 if self.profile_in_progress: # force trace flush diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py -index 293a84350..244ea4eb1 100644 +index 293a843..244ea4e 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -12,6 +12,7 @@ from sglang.srt.constants import ( @@ -2765,7 +3313,7 @@ index 293a84350..244ea4eb1 100644 def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py -index f2ffa9909..6e4d1d460 100644 +index f2ffa99..6e4d1d4 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import ( @@ -2817,10 +3365,30 @@ index f2ffa9909..6e4d1d460 100644 self, obj: InitWeightsSendGroupForRemoteInstanceReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 0914a5230..33bb3844a 100644 +index 0914a52..bae5652 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -324,8 +324,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -163,6 +163,10 @@ class ReqState: + input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) ++ input_top_p_logprobs_val: List = dataclasses.field(default_factory=list) ++ input_top_p_logprobs_idx: List = dataclasses.field(default_factory=list) ++ output_top_p_logprobs_val: List = dataclasses.field(default_factory=list) ++ output_top_p_logprobs_idx: List = dataclasses.field(default_factory=list) + input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) + input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) + output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) +@@ -173,6 +177,8 @@ class ReqState: + output_token_logprobs: List[Any] = dataclasses.field(default_factory=list) + input_top_logprobs: List[Any] = dataclasses.field(default_factory=list) + output_top_logprobs: List[Any] = dataclasses.field(default_factory=list) ++ input_top_p_logprobs: List[Any] = dataclasses.field(default_factory=list) ++ output_top_p_logprobs: List[Any] = dataclasses.field(default_factory=list) + input_token_ids_logprobs: List[Any] = 
dataclasses.field(default_factory=list) + output_token_ids_logprobs: List[Any] = dataclasses.field(default_factory=list) + +@@ -324,8 +330,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi context, zmq.PULL, port_args.tokenizer_ipc_name, True ) if self.server_args.tokenizer_worker_num == 1: @@ -2834,7 +3402,15 @@ index 0914a5230..33bb3844a 100644 ) else: from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper -@@ -1327,7 +1331,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -927,6 +937,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + obj.logprob_start_len, + obj.top_logprobs_num, + obj.token_ids_logprob, ++ obj.top_logprobs_p, + obj.stream, + rid=obj.rid, + http_worker_ipc=obj.http_worker_ipc, +@@ -1327,7 +1338,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi async with self.is_pause_cond: self.is_pause = True if obj.mode != "abort": @@ -2843,7 +3419,7 @@ index 0914a5230..33bb3844a 100644 else: # we are using the model_update_lock to check if there is still on-going requests. while True: -@@ -1341,7 +1345,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1341,7 +1352,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi async def continue_generation(self, obj: ContinueGenerationReqInput): async with self.is_pause_cond: self.is_pause = False @@ -2852,7 +3428,7 @@ index 0914a5230..33bb3844a 100644 self.is_pause_cond.notify_all() async def update_weights_from_disk( -@@ -1510,6 +1514,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1510,6 +1521,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi self._add_metric_if_present( recv_obj, "prefill_finished_ts", meta_info, i ) @@ -2893,7 +3469,186 @@ index 0914a5230..33bb3844a 100644 if getattr(state.obj, "return_logprob", False): self.convert_logprob_style( -@@ -1955,19 +1993,17 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1684,7 +1729,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + meta_info["input_top_logprobs"] = state.input_top_logprobs + meta_info["output_top_logprobs"] = state.output_top_logprobs + +- # 3. Handle token_ids_logprob ++ # 3. Handle top-p logprobs ++ top_logprobs_p = getattr(state.obj, "top_logprobs_p", 0.0) or 0.0 ++ if top_logprobs_p > 0.0: ++ if len(state.input_top_p_logprobs_val) > len(state.input_top_p_logprobs): ++ state.input_top_p_logprobs.extend( ++ self.detokenize_top_logprobs_tokens( ++ state.input_top_p_logprobs_val[ ++ len(state.input_top_p_logprobs) : ++ ], ++ state.input_top_p_logprobs_idx[ ++ len(state.input_top_p_logprobs) : ++ ], ++ return_text_in_logprobs, ++ ) ++ ) ++ if len(state.output_top_p_logprobs_val) > len( ++ state.output_top_p_logprobs ++ ): ++ state.output_top_p_logprobs.extend( ++ self.detokenize_top_logprobs_tokens( ++ state.output_top_p_logprobs_val[ ++ len(state.output_top_p_logprobs) : ++ ], ++ state.output_top_p_logprobs_idx[ ++ len(state.output_top_p_logprobs) : ++ ], ++ return_text_in_logprobs, ++ ) ++ ) ++ ++ meta_info["input_top_p_logprobs"] = state.input_top_p_logprobs ++ meta_info["output_top_p_logprobs"] = state.output_top_p_logprobs ++ ++ # 4. 
Handle token_ids_logprob + if token_ids_logprob is not None: + if len(state.input_token_ids_logprobs_val) > len( + state.input_token_ids_logprobs +@@ -1718,6 +1796,105 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + meta_info["input_token_ids_logprobs"] = state.input_token_ids_logprobs + meta_info["output_token_ids_logprobs"] = state.output_token_ids_logprobs + ++ # 5. Handle base64 encoding for top-k and top-p logprobs ++ return_logprobs_in_base64 = getattr( ++ state.obj, "return_logprobs_in_base64", False ++ ) ++ if return_logprobs_in_base64: ++ self._encode_logprobs_as_base64(meta_info, state, top_logprobs_num) ++ ++ ++ @staticmethod ++ def _encode_logprobs_as_base64( ++ meta_info: dict, state: "ReqState", top_logprobs_num: int ++ ): ++ """Encode top-k and top-p logprobs as base64-encoded numpy arrays.""" ++ import numpy as np ++ import pybase64 ++ ++ def encode_fixed_length(vals_raw, idxs_raw, key_prefix): ++ if not vals_raw: ++ return ++ filtered_vals = [] ++ filtered_idxs = [] ++ none_positions = [] ++ for pos_i, (v, x) in enumerate(zip(vals_raw, idxs_raw)): ++ if v is None: ++ none_positions.append(pos_i) ++ else: ++ filtered_vals.append(v) ++ filtered_idxs.append(x) ++ if not filtered_vals: ++ return ++ val_arr = np.array(filtered_vals, dtype=np.float32) ++ idx_arr = np.array(filtered_idxs, dtype=np.int32) ++ meta_info[f"{key_prefix}_val_base64"] = pybase64.b64encode( ++ val_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_idx_base64"] = pybase64.b64encode( ++ idx_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_shape"] = list(val_arr.shape) ++ if none_positions: ++ meta_info[f"{key_prefix}_none_positions"] = none_positions ++ if key_prefix in meta_info: ++ del meta_info[key_prefix] ++ ++ def encode_variable_length(vals_raw, idxs_raw, key_prefix): ++ if not vals_raw: ++ return ++ flat_vals = [] ++ flat_idxs = [] ++ lengths = [] ++ for v, x in zip(vals_raw, idxs_raw): ++ if v is None: ++ lengths.append(-1) ++ elif isinstance(v, list): ++ lengths.append(len(v)) ++ flat_vals.extend(v) ++ flat_idxs.extend(x) ++ else: ++ lengths.append(0) ++ if not flat_vals: ++ meta_info[f"{key_prefix}_lengths"] = lengths ++ return ++ val_arr = np.array(flat_vals, dtype=np.float32) ++ idx_arr = np.array(flat_idxs, dtype=np.int32) ++ meta_info[f"{key_prefix}_val_base64"] = pybase64.b64encode( ++ val_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_idx_base64"] = pybase64.b64encode( ++ idx_arr.tobytes() ++ ).decode("utf-8") ++ meta_info[f"{key_prefix}_lengths"] = lengths ++ if key_prefix in meta_info: ++ del meta_info[key_prefix] ++ ++ if top_logprobs_num > 0: ++ encode_fixed_length( ++ state.input_top_logprobs_val, ++ state.input_top_logprobs_idx, ++ "input_top_logprobs", ++ ) ++ encode_fixed_length( ++ state.output_top_logprobs_val, ++ state.output_top_logprobs_idx, ++ "output_top_logprobs", ++ ) ++ ++ top_logprobs_p = getattr(state.obj, "top_logprobs_p", 0.0) or 0.0 ++ if top_logprobs_p > 0.0: ++ encode_variable_length( ++ state.input_top_p_logprobs_val, ++ state.input_top_p_logprobs_idx, ++ "input_top_p_logprobs", ++ ) ++ encode_variable_length( ++ state.output_top_p_logprobs_val, ++ state.output_top_p_logprobs_idx, ++ "output_top_p_logprobs", ++ ) ++ + def convert_logprob_style( + self, + meta_info: dict, +@@ -1763,6 +1940,30 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi + recv_obj.output_top_logprobs_idx[recv_obj_index] + ) + ++ top_logprobs_p = getattr(state.obj, "top_logprobs_p", 
0.0) or 0.0 ++ if top_logprobs_p > 0.0: ++ if ( ++ recv_obj.input_top_p_logprobs_val is not None ++ and len(recv_obj.input_top_p_logprobs_val) > 0 ++ and recv_obj.input_top_p_logprobs_val[recv_obj_index] ++ ): ++ state.input_top_p_logprobs_val.extend( ++ recv_obj.input_top_p_logprobs_val[recv_obj_index] ++ ) ++ state.input_top_p_logprobs_idx.extend( ++ recv_obj.input_top_p_logprobs_idx[recv_obj_index] ++ ) ++ if ( ++ recv_obj.output_top_p_logprobs_val is not None ++ and len(recv_obj.output_top_p_logprobs_val) > 0 ++ ): ++ state.output_top_p_logprobs_val.extend( ++ recv_obj.output_top_p_logprobs_val[recv_obj_index] ++ ) ++ state.output_top_p_logprobs_idx.extend( ++ recv_obj.output_top_p_logprobs_idx[recv_obj_index] ++ ) ++ + if token_ids_logprob is not None: + if len(recv_obj.input_token_ids_logprobs_val) > 0: + state.input_token_ids_logprobs_val.extend( +@@ -1955,19 +2156,17 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi if custom_labels else self.metrics_collector.labels ) @@ -2919,7 +3674,7 @@ index 0914a5230..33bb3844a 100644 new_time = time.time() interval = new_time - state.last_time self.metrics_collector.observe_inter_token_latency( -@@ -1976,7 +2012,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi +@@ -1976,7 +2175,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi num_new_tokens, ) state.last_time = new_time @@ -2929,7 +3684,7 @@ index 0914a5230..33bb3844a 100644 if state.finished: retraction_count = ( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py -index 86b009df4..16ebd52ae 100644 +index 86b009d..16ebd52 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,6 +29,7 @@ from sglang.srt.managers.io_struct import ( @@ -2953,7 +3708,7 @@ index 86b009df4..16ebd52ae 100644 parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py -index fa08bb66a..fa539315c 100644 +index fa08bb6..fa53931 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -347,6 +347,84 @@ def alloc_decode_kernel( @@ -3060,7 +3815,7 @@ index fa08bb66a..fa539315c 100644 prefix_lens, seq_lens, diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py -index d7cd472a9..81fae740f 100644 +index d7cd472..81fae74 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -76,6 +76,7 @@ class HiRadixCache(RadixCache): @@ -3117,7 +3872,7 @@ index d7cd472a9..81fae740f 100644 self._inc_hit_count(new_node, chunked) total_prefix_length += prefix_len diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py -index 1d917137c..669e5c518 100644 +index 1d91713..669e5c5 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1777,9 +1777,12 @@ class NSATokenToKVPool(MLATokenToKVPool): @@ -3200,7 +3955,7 @@ index 1d917137c..669e5c518 100644 kv_size_bytes = super().get_kv_size_bytes() for index_k_cache in self.index_k_with_scale_buffer: diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py -index 42b169728..8e799196a 100644 +index 42b1697..8e79919 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ 
b/python/sglang/srt/mem_cache/radix_cache.py @@ -495,7 +495,17 @@ class RadixCache(BasePrefixCache): @@ -3235,7 +3990,7 @@ index 42b169728..8e799196a 100644 return delta diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py -index 255d41ccc..f93bedb4d 100644 +index 255d41c..f93bedb 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -20,7 +20,10 @@ import time @@ -3418,10 +4173,26 @@ index 255d41ccc..f93bedb4d 100644 if self.disagg_mode == DisaggregationMode.NULL: queue_duration = self.forward_entry_time - self.wait_queue_entry_time diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py -index 234523532..f5d479945 100644 +index 2345235..307b656 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py -@@ -909,6 +909,28 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): +@@ -266,6 +266,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): + # For logprob + return_logprob: bool = False + top_logprobs_nums: Optional[List[int]] = None ++ top_logprobs_ps: Optional[List[float]] = None + token_ids_logprobs: Optional[List[List[int]]] = None + + # For logits and logprobs post processing +@@ -399,6 +400,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): + orig_seq_lens=batch.orig_seq_lens, + return_logprob=batch.return_logprob, + top_logprobs_nums=batch.top_logprobs_nums, ++ top_logprobs_ps=batch.top_logprobs_ps, + token_ids_logprobs=batch.token_ids_logprobs, + is_extend_in_batch=batch.is_extend_in_batch, + can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, +@@ -909,6 +911,28 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): tokens_padded = (tokens + rank_size - 1) // rank_size * rank_size self._pad_inputs_to_size(model_runner, tokens_padded, self.batch_size) @@ -3451,7 +4222,7 @@ index 234523532..f5d479945 100644 self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 275775a73..f0bd3ebf8 100644 +index 275775a..fdeb693 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -395,7 +395,12 @@ class ModelRunner(ModelRunnerKVCacheMixin): @@ -3513,7 +4284,23 @@ index 275775a73..f0bd3ebf8 100644 # Normalize num_token_non_padded to be local to this attention TP rank if needed. if ( -@@ -2664,6 +2681,42 @@ class ModelRunner(ModelRunnerKVCacheMixin): +@@ -2553,6 +2570,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): + forward_batch.sampling_info, + forward_batch.return_logprob, + forward_batch.top_logprobs_nums, ++ forward_batch.top_logprobs_ps, + forward_batch.token_ids_logprobs, + # For prefill, we only use the position of the last token. 
+ ( +@@ -2592,6 +2610,7 @@ class ModelRunner(ModelRunnerKVCacheMixin): + forward_batch.sampling_info, + forward_batch.return_logprob, + forward_batch.top_logprobs_nums, ++ forward_batch.top_logprobs_ps, + forward_batch.token_ids_logprobs, + ) + +@@ -2664,6 +2683,42 @@ class ModelRunner(ModelRunnerKVCacheMixin): device=self.device, ) @@ -3557,7 +4344,7 @@ index 275775a73..f0bd3ebf8 100644 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): params_dict = dict(model.named_parameters()) diff --git a/python/sglang/srt/models/deepseek_common/attention_backend_handler.py b/python/sglang/srt/models/deepseek_common/attention_backend_handler.py -index cc673a9ca..06c430d2c 100644 +index cc673a9..06c430d 100644 --- a/python/sglang/srt/models/deepseek_common/attention_backend_handler.py +++ b/python/sglang/srt/models/deepseek_common/attention_backend_handler.py @@ -1,4 +1,5 @@ @@ -3576,7 +4363,7 @@ index cc673a9ca..06c430d2c 100644 return AttnForwardMethod.MHA_ONE_SHOT return AttnForwardMethod.MLA diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py -index cb13a7c67..d9669ce08 100644 +index cb13a7c..d9669ce 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -29,6 +29,7 @@ from sglang.srt.layers.attention.nsa.utils import ( @@ -3607,7 +4394,7 @@ index cb13a7c67..d9669ce08 100644 if not forward_batch.forward_mode.is_idle(): if residual is not None: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py -index 1583dd788..a35c00f96 100644 +index 1583dd7..a35c00f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1085,6 +1085,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin): @@ -3828,7 +4615,7 @@ index 1583dd788..a35c00f96 100644 if normal_end_layer != self.end_layer: hidden_states, residual = model_forward_maybe_tbo( diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py -index db8c1c7ce..53ffadf6d 100644 +index db8c1c7..53ffadf 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -678,8 +678,13 @@ class Glm4MoeDecoderLayer(nn.Module): @@ -3856,7 +4643,7 @@ index db8c1c7ce..53ffadf6d 100644 hidden_states, residual = self.layer_communicator.prepare_attn( diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py -index 1f6e75364..546cce4ab 100644 +index 1f6e753..546cce4 100644 --- a/python/sglang/srt/models/glm4_moe_nextn.py +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -103,7 +103,7 @@ class Glm4MoeModelNextN(nn.Module): @@ -3869,7 +4656,7 @@ index 1f6e75364..546cce4ab 100644 ) diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py -index 324de18b4..fc72faa03 100644 +index 324de18..fc72faa 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -52,11 +52,31 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): @@ -3974,7 +4761,7 @@ index 324de18b4..fc72faa03 100644 continue diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py -index 2cf813bce..1250c49e4 100644 +index 2cf813b..1250c49 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -17,6 +17,7 @@ @@ -4088,7 +4875,7 @@ index 2cf813bce..1250c49e4 100644 weights_out_dict = dict(weights_in) diff --git 
a/python/sglang/srt/models/kimi_k25.py b/python/sglang/srt/models/kimi_k25.py -index d8399a691..0277bc671 100644 +index d8399a6..0277bc6 100644 --- a/python/sglang/srt/models/kimi_k25.py +++ b/python/sglang/srt/models/kimi_k25.py @@ -666,25 +666,30 @@ class KimiK25ForConditionalGeneration(nn.Module): @@ -4208,7 +4995,7 @@ index d8399a691..0277bc671 100644 diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py -index 49f938a1c..8eea383bb 100644 +index 49f938a..8eea383 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -85,6 +85,11 @@ class LlamaDecoderLayer(LlamaDecoderLayer): @@ -4236,7 +5023,7 @@ index 49f938a1c..8eea383bb 100644 # idle batch diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py -index f01225487..1dad8bb8e 100644 +index f012254..1dad8bb 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -372,6 +372,7 @@ class Qwen3_5LinearDecoderLayer(nn.Module): @@ -4312,7 +5099,7 @@ index f01225487..1dad8bb8e 100644 return hidden_states, residual diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py -index d641826e3..3abc39ef3 100644 +index d641826..3abc39e 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -711,14 +711,19 @@ class Qwen3LLMModel(Qwen3Model): @@ -4340,7 +5127,7 @@ index d641826e3..3abc39ef3 100644 positions, hidden_states, diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py -index 33cce6fe2..0970c4550 100644 +index 33cce6f..0970c45 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py +++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -1,6 +1,9 @@ @@ -4400,7 +5187,7 @@ index 33cce6fe2..0970c4550 100644 self, image_data: List[Union[str, bytes]], diff --git a/python/sglang/srt/multimodal/processors/kimi_k25.py b/python/sglang/srt/multimodal/processors/kimi_k25.py -index d8bb9ceb3..9311a431b 100644 +index d8bb9ce..9311a43 100644 --- a/python/sglang/srt/multimodal/processors/kimi_k25.py +++ b/python/sglang/srt/multimodal/processors/kimi_k25.py @@ -25,6 +25,18 @@ class KimiK2_5VLImageProcessor(SGLangBaseProcessor): @@ -4423,7 +5210,7 @@ index d8bb9ceb3..9311a431b 100644 async def process_mm_data_async( self, diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py -index 4395654e4..f9b5ea4ab 100644 +index 4395654..f9b5ea4 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -317,7 +317,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor): @@ -4436,7 +5223,7 @@ index 4395654e4..f9b5ea4ab 100644 image_data=image_data, video_data=request_obj.video_data, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index b080aeb16..b0322fef4 100644 +index b080aeb..b0322fe 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -580,6 +580,7 @@ class ServerArgs: @@ -4554,7 +5341,7 @@ index b080aeb16..b0322fef4 100644 return PortArgs( tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py -index 5fe45086c..b283d2e9b 100644 +index 5fe4508..b283d2e 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py 
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner: @@ -4585,7 +5372,7 @@ index 5fe45086c..b283d2e9b 100644 self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py -index ac629c7ee..904f54b4a 100644 +index ac629c7..904f54b 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -337,7 +337,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): @@ -4637,7 +5424,7 @@ index ac629c7ee..904f54b4a 100644 @dataclass diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py -index 32b3a520a..d7f940147 100644 +index 32b3a52..d7f9401 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -234,7 +234,10 @@ class EAGLEWorker(TpModelWorker): @@ -4653,7 +5440,7 @@ index 32b3a520a..d7f940147 100644 Device2DraftCudaGraphRunner = { diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py -index 4636128fa..a9b61df39 100644 +index 4636128..a9b61df 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2359,6 +2359,8 @@ class SafeUnpickler(pickle.Unpickler): @@ -4666,7 +5453,7 @@ index 4636128fa..a9b61df39 100644 DENY_CLASSES = { diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py -index 3be16446e..1b2371c83 100644 +index 3be1644..1b2371c 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -69,6 +69,9 @@ def _check_tensors( diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index f68d665537..7d28193e42 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -229,7 +229,7 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) - for key in ["rollout_log_probs", "teacher_log_probs"]: + for key in ["rollout_log_probs", "teacher_log_probs", "sampling_logprob_sum"]: if key not in rollout_data: continue rollout_data[key] = [ @@ -253,6 +253,28 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: ) ) ] + # sampling_token_ids: variable-length nested lists, apply zigzag CP slicing but keep as lists. + # For allgather_cp, skip zigzag slicing — contiguous slicing is done at training time + # inside get_masked_log_probs_for_token_ids to match the allgather logits layout. 
+ if "sampling_token_ids" in rollout_data and not self.args.allgather_cp: + key = "sampling_token_ids" + rollout_data[key] = [ + slice_log_prob_with_cp( + token_ids, + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + ) + for i, (token_ids, total_length, response_length) in enumerate( + zip( + rollout_data[key], + rollout_data["total_lengths"], + rollout_data["response_lengths"], + strict=False, + ) + ) + ] if "rollout_routed_experts" in rollout_data: rollout_data["rollout_routed_experts"] = [ torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"] diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 8a7f768b3b..46f5760602 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -396,7 +396,8 @@ def log_rollout_data( - Tensor-valued lists are concatenated and averaged. For token-level metrics like log-probs/returns/advantages/values, computes a CP-correct sample mean using `loss_masks` and total/response lengths. - - Non-tensor lists are averaged elementwise. + - Non-tensor lists are averaged elementwise. ``sampling_token_ids`` is + summarized by average candidate count per position. - Scalars are converted to Python numbers. """ if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage(): @@ -422,6 +423,8 @@ def log_rollout_data( # There are the following assumptions: # - Each dp rank has the same number of samples if isinstance(val, (list, tuple)): + if len(val) == 0: + continue if isinstance(val[0], torch.Tensor): # NOTE: Here we have to do the clone().detach(), otherwise the tensor will be # modified in place and will cause problem for the next rollout. @@ -447,6 +450,19 @@ def log_rollout_data( else: val = torch.cat(val).clone().detach() val = val.mean() * cp_size + elif key == "sampling_token_ids": + num_positions = sum(len(sample) for sample in val) + val = ( + sum(len(token_ids) for sample in val for token_ids in sample) / num_positions + if num_positions > 0 + else 0.0 + ) + elif key == "sampling_logprob_sum": + # Per-token tensor; skip generic aggregation — already + # summarized indirectly via sampling_token_ids metrics. + continue + elif isinstance(val[0], (list, tuple)): + continue else: val = sum(val) / len(val) elif isinstance(val, torch.Tensor): diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index db7b6098dd..e221b4113d 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -14,12 +14,14 @@ calculate_log_probs_and_entropy, compute_approx_kl, compute_gspo_kl, + compute_log_probs, compute_opsm_mask, compute_policy_loss, get_advantages_and_returns_batch, get_grpo_returns, get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, + mask_logits_for_token_ids, ) from slime.utils.types import RolloutBatch @@ -297,6 +299,141 @@ def get_log_probs_and_entropy( return torch.empty((0,), device=logits.device), res +def get_masked_log_probs_for_token_ids( + logits: torch.Tensor, + *, + args: Namespace, + unconcat_tokens: list[torch.Tensor], + total_lengths: list[int], + response_lengths: list[int], + sampling_token_ids: list[list[list[int]]], + max_seq_lens: list[int] | None = None, +) -> list[torch.Tensor]: + """Compute per-token log-probabilities restricted to sampling token subsets. 
+ + For each sample, masks logits to keep only the tokens in ``sampling_token_ids`` + (setting others to ``-inf``), then computes ``log softmax`` over the + restricted set. + + Args: + logits: Policy logits with shape ``[1, T, V]``. + args: Configuration (temperature applied in ``get_responses``). + unconcat_tokens: Per-sample token tensors. + total_lengths: Total sequence lengths per sample. + response_lengths: Response segment lengths per sample. + sampling_token_ids: Per-sample, per-position list of global token IDs to + keep. Shape: ``[num_samples][response_length_or_cp_chunk][variable]``. + max_seq_lens: Optional max sequence lengths per sample (for bshd). + + Returns: + List of ``[R]`` tensors — masked log-probabilities per sample. + """ + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_group = mpu.get_tensor_model_parallel_group() + + # For allgather_cp, sampling_token_ids is NOT pre-sliced (full response length). + # Compute contiguous CP offsets to extract positions matching each logits chunk. + _allgather_cp = getattr(args, "allgather_cp", False) and mpu.get_context_parallel_world_size() > 1 + if _allgather_cp: + _logits_local_len = logits.view(-1, logits.size(-1)).size(0) + _cp_rank = mpu.get_context_parallel_rank() + _chunk_start = _cp_rank * _logits_local_len + _chunk_end = _chunk_start + _logits_local_len + + masked_log_probs_list = [] + _seq_start = 0 + for i, (logits_chunk, tokens_chunk) in enumerate( + get_responses( + logits, + args=args, + unconcat_tokens=unconcat_tokens, + total_lengths=total_lengths, + response_lengths=response_lengths, + max_seq_lens=max_seq_lens, + ) + ): + # Determine per-position token IDs for masking. + if _allgather_cp: + prompt_length = total_lengths[i] - response_lengths[i] + logit_global_start = _seq_start + prompt_length - 1 + logit_global_end = _seq_start + total_lengths[i] - 1 + s = max(logit_global_start, _chunk_start) + e = min(logit_global_end, _chunk_end) + if e <= s: + per_pos_ids = [] + else: + resp_start = s - logit_global_start + resp_end = e - logit_global_start + per_pos_ids = sampling_token_ids[i][resp_start:resp_end] + else: + per_pos_ids = sampling_token_ids[i] + _seq_start += total_lengths[i] + + vocab_shard_size = logits_chunk.size(-1) + masked_logits = mask_logits_for_token_ids(logits_chunk, per_pos_ids, vocab_shard_size, tp_rank) + # Clone before compute_log_probs: fused_vocab_parallel_cross_entropy + # modifies its input in-place (subtract max, exp, div), which would + # corrupt the autograd graph of masked_logits. 
+ masked_lp = compute_log_probs(masked_logits.clone(), tokens_chunk, tp_group) + masked_log_probs_list.append(masked_lp.squeeze(-1)) + + if args.allgather_cp: + res = {"log_probs": masked_log_probs_list} + _allgather_cp_redistribute( + res, + logits=logits, + args=args, + total_lengths=total_lengths, + response_lengths=response_lengths, + max_seq_lens=max_seq_lens, + ) + masked_log_probs_list = res["log_probs"] + + return masked_log_probs_list + + +def apply_sampling_mask_to_log_probs( + args: Namespace, + batch: RolloutBatch, + logits: torch.Tensor, + log_probs: list[torch.Tensor], + old_log_probs: list[torch.Tensor], + total_lengths: list[int], + response_lengths: list[int], + max_seq_lens: list[int] | None = None, +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Apply rollout sampling-mask normalization to train-time log-probabilities.""" + sampling_token_ids = batch.get("sampling_token_ids") + mask_logprob_sum = batch.get("sampling_logprob_sum") + if ( + not getattr(args, "use_topp_mask", False) and not getattr(args, "use_topk_mask", False) + ) or sampling_token_ids is None: + return log_probs, old_log_probs + + if mask_logprob_sum is None: + raise ValueError( + "batch['sampling_logprob_sum'] must be provided when sampling masks are enabled " + "and 'sampling_token_ids' is present." + ) + if len(mask_logprob_sum) != len(old_log_probs): + raise ValueError( + f"sampling_logprob_sum has {len(mask_logprob_sum)} samples but " f"old_log_probs has {len(old_log_probs)}" + ) + + masked_log_probs = get_masked_log_probs_for_token_ids( + logits, + args=args, + unconcat_tokens=batch["unconcat_tokens"], + total_lengths=total_lengths, + response_lengths=response_lengths, + sampling_token_ids=sampling_token_ids, + max_seq_lens=max_seq_lens, + ) + + masked_old_log_probs = [olp - tlse for olp, tlse in zip(old_log_probs, mask_logprob_sum, strict=True)] + return masked_log_probs, masked_old_log_probs + + def get_values( logits: torch.Tensor, *, @@ -659,6 +796,18 @@ def policy_loss_function( log_probs = log_probs_and_entropy["log_probs"] + if getattr(args, "use_topp_mask", False) or getattr(args, "use_topk_mask", False): + log_probs, old_log_probs = apply_sampling_mask_to_log_probs( + args, + batch, + logits, + log_probs, + old_log_probs, + total_lengths=total_lengths, + response_lengths=response_lengths, + max_seq_lens=max_seq_lens, + ) + # Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering need_full_log_probs = args.use_opsm or args.advantage_estimator == "gspo" diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 39f6125fcf..7e282246fe 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -372,6 +372,8 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "rollout_log_probs", "max_seq_lens", "teacher_log_probs", + "sampling_token_ids", + "sampling_logprob_sum", ], args.data_pad_size_multiplier, args.qkv_format, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index d7a208753b..d5b860fc91 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -743,6 +743,12 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl if samples[0].teacher_log_probs is not None: train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples] + if samples[0].sampling_token_ids is not None: + train_data["sampling_token_ids"] = [sample.sampling_token_ids for sample in samples] + + if 
samples[0].sampling_logprob_sum is not None: + train_data["sampling_logprob_sum"] = [sample.sampling_logprob_sum for sample in samples] + return train_data def set_train_parallel_config(self, config: dict): @@ -782,6 +788,8 @@ def _split_train_data_by_dp(self, data, dp_size): "rollout_routed_experts", "prompt", "teacher_log_probs", + "sampling_token_ids", + "sampling_logprob_sum", ]: if key not in data: continue @@ -1200,6 +1208,7 @@ def compute_metrics_from_samples(args, samples): log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/") log_dict |= _compute_zero_std_metrics(args, samples) log_dict |= _compute_reward_cat_metrics(args, samples) + log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item() log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item() return log_dict diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 72b42b0752..7b6229d183 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -36,6 +36,121 @@ logger = logging.getLogger(__name__) +def _logsumexp(values: np.ndarray) -> float: + """Numerically stable log-sum-exp for a 1-D numpy array.""" + if values.size == 0: + return 0.0 + max_v = values.max() + return float(max_v + np.log(np.sum(np.exp(values - max_v)))) + + +def _extract_top_p_candidates(meta_info: dict[str, Any]) -> list[list[tuple[int, float]]] | None: + """Extract top-p candidates from sglang response (base64 or list format). + + Returns per-position list of ``(token_id, logprob)`` pairs, or ``None``. + """ + # Base64 variable-length format + val_b64 = meta_info.get("output_top_p_logprobs_val_base64") + idx_b64 = meta_info.get("output_top_p_logprobs_idx_base64") + lengths = meta_info.get("output_top_p_logprobs_lengths") + if val_b64 and idx_b64 and lengths: + flat_vals = np.frombuffer(pybase64.b64decode(val_b64), dtype=np.float32) + flat_idxs = np.frombuffer(pybase64.b64decode(idx_b64), dtype=np.int32) + result: list[list[tuple[int, float]]] = [] + offset = 0 + for length in lengths: + if length <= 0: + result.append([]) + else: + result.append( + list( + zip( + flat_idxs[offset : offset + length].tolist(), + flat_vals[offset : offset + length].tolist(), + strict=True, + ) + ) + ) + offset += length + return result + + # List format: [(logprob, token_id, text), ...] per position + top_p_logprobs = meta_info.get("output_top_p_logprobs") + if top_p_logprobs: + return [[(tid, lp) for lp, tid, *_ in entries] for entries in top_p_logprobs] + + return None + + +def _extract_topk_candidates(meta_info: dict[str, Any], top_k: int) -> list[list[tuple[int, float]]] | None: + """Extract top-k candidates from sglang response (base64 or list format). + + Returns per-position list of ``(token_id, logprob)`` pairs, or ``None``. + """ + # Base64 fixed-length format (present when return_logprobs_in_base64=True) + val_b64 = meta_info.get("output_top_logprobs_val_base64") + idx_b64 = meta_info.get("output_top_logprobs_idx_base64") + shape = meta_info.get("output_top_logprobs_shape") + if val_b64 and idx_b64 and shape: + vals = np.frombuffer(pybase64.b64decode(val_b64), dtype=np.float32).reshape(shape) + idxs = np.frombuffer(pybase64.b64decode(idx_b64), dtype=np.int32).reshape(shape) + k = min(top_k, shape[1]) + return [list(zip(idxs[i, :k].tolist(), vals[i, :k].tolist(), strict=True)) for i in range(shape[0])] + + # List format: [(logprob, token_id, text), ...] 
per position + top_logprobs = meta_info.get("output_top_logprobs") + if not top_logprobs: + return None + return [ + [(tid, lp) for lp, tid, *_ in sorted(entries, key=lambda x: x[0], reverse=True)[:top_k]] + for entries in top_logprobs + ] + + +def append_sampling_mask_to_sample( + sample: Sample, + *, + meta_info: dict[str, Any], + args: Namespace, +) -> None: + use_topp = getattr(args, "use_topp_mask", False) + use_topk = getattr(args, "use_topk_mask", False) + if not use_topp and not use_topk: + return + + topp_candidates = _extract_top_p_candidates(meta_info) if use_topp else None + topk_candidates = _extract_topk_candidates(meta_info, args.rollout_top_k) if use_topk else None + + if topp_candidates is not None and topk_candidates is not None: + # Both enabled — take intersection per position + candidates: list[list[tuple[int, float]]] = [] + for topp_pos, topk_pos in zip(topp_candidates, topk_candidates, strict=True): + topk_set = {tid for tid, _ in topk_pos} + candidates.append([(tid, lp) for tid, lp in topp_pos if tid in topk_set]) + elif topp_candidates is not None: + candidates = topp_candidates + elif topk_candidates is not None: + candidates = topk_candidates + else: + return + + # Build token_ids and logsumexp per position + new_token_ids: list[list[int]] = [] + new_logprob_sums: list[float] = [] + for pos in candidates: + ids = [tid for tid, _ in pos] + lps = np.array([lp for _, lp in pos], dtype=np.float32) + new_token_ids.append(ids) + new_logprob_sums.append(_logsumexp(lps)) + + if sample.sampling_token_ids is None: + sample.sampling_token_ids = [] + if sample.sampling_logprob_sum is None: + sample.sampling_logprob_sum = [] + sample.sampling_token_ids += new_token_ids + sample.sampling_logprob_sum += new_logprob_sums + + def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") -> str: """Return the router URL for a named model. @@ -146,6 +261,13 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } + if getattr(args, "use_topp_mask", False) or getattr(args, "use_topk_mask", False): + payload["return_logprobs_in_base64"] = True + if getattr(args, "use_topk_mask", False): + payload["top_logprobs_num"] = args.rollout_top_k + if getattr(args, "use_topp_mask", False): + payload["top_logprobs_p"] = args.rollout_top_p + if args.use_rollout_routing_replay: payload["return_routed_experts"] = True @@ -199,6 +321,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A sample.rollout_log_probs = [] sample.rollout_log_probs += new_response_log_probs + # Record the exact rollout candidate set so training can reuse the same normalization domain. + append_sampling_mask_to_sample(sample, meta_info=output["meta_info"], args=args) + if "routed_experts" in output["meta_info"]: sample.rollout_routed_experts = np.frombuffer( pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index a634d1f003..3e8f3f61a3 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -258,6 +258,26 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-top-k", type=int, default=-1, help="the top-k for the inference engine during rollout." ) + parser.add_argument( + "--use-topp-mask", + action="store_true", + default=False, + help=( + "Enable top-p sampling mask for training-inference consistency. 
" + "Records per-position sampling token candidates during rollout and re-normalizes " + "log probabilities over the same token subset during training." + ), + ) + parser.add_argument( + "--use-topk-mask", + action="store_true", + default=False, + help=( + "Enable top-k sampling mask for training-inference consistency. " + "Records per-position sampling token candidates during rollout and re-normalizes " + "log probabilities over the same token subset during training." + ), + ) parser.add_argument( "--rollout-max-context-len", type=int, @@ -1494,6 +1514,20 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: return eval_datasets +def _validate_sampling_mask_args(args): + if not args.use_topp_mask and not getattr(args, "use_topk_mask", False): + return + + if args.use_topp_mask: + assert args.rollout_top_p < 1.0, ( + "--use-topp-mask requires rollout_top_p < 1.0. " f"Current value: rollout_top_p={args.rollout_top_p}" + ) + if getattr(args, "use_topk_mask", False): + assert args.rollout_top_k > 0, ( + "--use-topk-mask requires rollout_top_k > 0. " f"Current value: rollout_top_k={args.rollout_top_k}" + ) + + def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) @@ -1590,6 +1624,8 @@ def slime_validate_args(args): if args.use_rollout_logprobs: assert not args.use_tis, "use_rollout_logprobs and use_tis cannot be set at the same time." + _validate_sampling_mask_args(args) + if args.get_mismatch_metrics: assert ( args.custom_tis_function_path is not None diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 2404883ab3..f1b50f3a97 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -148,6 +148,41 @@ def compute_policy_loss( return pg_losses, clipfrac +def mask_logits_for_token_ids( + logits: torch.Tensor, + sampling_token_ids: list[list[int]], + vocab_shard_size: int, + tp_rank: int, +) -> torch.Tensor: + """Mask logits to keep only the sampling token subset, setting others to -inf. + + During training, this restricts the softmax normalization domain to an + externally provided sampling token subset for each position. + + Uses ``torch.where`` (not in-place ``masked_fill_``) so the returned tensor + has a clean autograd graph that is safe for downstream custom autograd + functions (e.g. ``fused_vocab_parallel_cross_entropy``) which modify their + input in-place. + + Args: + logits: Logits tensor of shape ``[seq_len, vocab_shard_size]``. + sampling_token_ids: Per-position list of *global* token IDs to keep. + vocab_shard_size: Size of the local vocabulary shard on this TP rank. + tp_rank: Tensor-parallel rank (to map global IDs to local indices). + + Returns: + A **new** tensor with non-selected entries replaced by ``-inf``. 
+ """ + vocab_start = tp_rank * vocab_shard_size + mask = torch.zeros_like(logits, dtype=torch.bool) + for t, ids in enumerate(sampling_token_ids): + local_ids = [gid - vocab_start for gid in ids if vocab_start <= gid < vocab_start + vocab_shard_size] + if local_ids: + idx = torch.tensor(local_ids, dtype=torch.long, device=logits.device) + mask[t].scatter_(0, idx, True) + return torch.where(mask, logits, torch.tensor(float("-inf"), device=logits.device, dtype=logits.dtype)) + + def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None): # TODO: when megatron is not installed, fall back to naive implementation from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy diff --git a/slime/utils/types.py b/slime/utils/types.py index 0681c184b0..864ab6d1b2 100644 --- a/slime/utils/types.py +++ b/slime/utils/types.py @@ -27,6 +27,9 @@ class Sample: rollout_routed_experts: list[list[int]] | None = None # Routed experts from rollout engine remove_sample: bool = False teacher_log_probs: list[float] | None = None # Log probabilities from teacher model for OPD + # Sampling support used to reconstruct the rollout normalization domain during training. + sampling_token_ids: list[list[int]] | None = None + sampling_logprob_sum: list[float] | None = None class Status(Enum): PENDING = "pending"