feat(gemma4): add Gemma4 26B-A4B MoE and 31B dense support #1855
leofan-lab wants to merge 1 commit into THUDM:main
Conversation
Plugin (model, provider, mbridge), HF<->Megatron converter, retool integration, and 10 unit tests. Highlights:

- Heterogeneous attention. Sliding layers go through flash-attn (head_dim=256); global layers go through a PyTorch SDPA path because flash-attn 2.x doesn't support head_dim=512.
- Unified CP>1 path. SDPACoreAttention._forward_cp_subseq_mask all-gathers K/V, un-zigzags to global order, and builds per-sub-sequence causal (+ sliding-window) masks from slime's zig-zag global indices. Covers both global and sliding layers.
- Dual RoPE emits a single concatenated tensor (not a tuple) so it doesn't collide with Megatron's (self_attn, cross_attn) rope plumbing.
- Gemma4MoELayer subclasses MoELayer to reuse dispatcher + grouped-GEMM + EP, then swaps in Gemma4Router (RMSNorm -> scale -> proj -> softmax -> topk -> renormalize -> per-expert scale; order is load-bearing).
- Per-layer scalars loaded from the HF checkpoint as non-trainable buffers; rank-0 reads + broadcast to avoid O(world_size) FS hits.
- Converter emits stacked 3D expert tensors for sglang and handles PP offset via get_transformer_layer_offset.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@leofan-lab thanks for your great work. Could you please share your training script? When I start the SGLang server for the Gemma4 26B-A4B model, it raises a TP error.
Sure. The SGLang error you came across is most likely because you need to upgrade SGLang to a version that includes Gemma4 support.
Hi, I'm currently trying to upgrade sglang to include Gemma4 support (as in upstream PR #21952). However, I encountered an error while applying the sglang patches. Could you provide some advice on how to resolve this patch issue? Any guidance would be greatly appreciated.
Summary
Adds Gemma4 (26B-A4B MoE + 31B dense) model support to slime for RL training. Covers model architecture, HF↔Megatron weight conversion, retool integration, and 10 unit tests.
Test validation: parity tests across TP/PP/DP/CP/EP/Sliding Window all pass (see table below), and RL training runs on Gemma4 26B-A4B MoE and 31B dense for 200+ rollout steps using the retool recipe (see End-to-end validation below).
What's in this PR
Model plugin (slime_plugins/models/gemma4.py, gemma4_provider.py)

Heterogeneous attention. Sliding layers use head_dim=256 + flash-attn with a (sw-1, 0) left-window mask. Global layers use head_dim=512; flash-attn 2.7.4 doesn't support head dims above 256, so they go through a PyTorch SDPA path.
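For readers skimming the plugin, the per-layer dispatch boils down to roughly this shape. This is a minimal sketch, not the actual SDPACoreAttention code; is_sliding, window_size, and the tensor layout are illustrative assumptions.

```python
# Minimal sketch of the per-layer attention backend dispatch (illustrative,
# not the actual SDPACoreAttention implementation).
import torch
import torch.nn.functional as F

def attention_forward(q, k, v, is_sliding: bool, window_size: int):
    """q/k/v: [batch, heads, seq, head_dim]."""
    if is_sliding and q.shape[-1] <= 256:
        # Sliding layers (head_dim=256): flash-attn with a left-only window.
        from flash_attn import flash_attn_func
        return flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            causal=True,
            window_size=(window_size - 1, 0),  # look back sw-1 tokens, none forward
        ).transpose(1, 2)
    # Global layers (head_dim=512): flash-attn 2.7.4 rejects head_dim > 256,
    # so fall back to PyTorch SDPA with a plain causal mask.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```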
Context parallelism. SDPACoreAttention._forward_cp_subseq_mask is a unified CP>1 path for both global and sliding layers: it all-gathers K/V across CP ranks, un-zigzags them back to global order, and builds per-sub-sequence causal (+ sliding-window) masks from slime's zig-zag global indices.
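The un-zigzag step and mask construction look roughly like this. It is a sketch under the usual zig-zag CP assumption (the sequence is split into 2*cp_size chunks and rank r holds chunks r and 2*cp_size-1-r), not the actual slime implementation.

```python
# Illustrative sketch of un-zigzagging all-gathered K/V and building a
# causal (+ optional sliding-window) mask from global token indices.
import torch

def unzigzag(kv_gathered: torch.Tensor, cp_size: int) -> torch.Tensor:
    """kv_gathered: [cp_size, local_seq, ...] after all-gather over the CP group."""
    first, second = kv_gathered.chunk(2, dim=1)   # each rank holds two chunks
    ordered = [None] * (2 * cp_size)
    for rank in range(cp_size):
        ordered[rank] = first[rank]                        # chunk r -> position r
        ordered[2 * cp_size - 1 - rank] = second[rank]     # mirrored position
    return torch.cat(ordered, dim=0)              # back to global sequence order

def sliding_causal_mask(q_pos: torch.Tensor, k_pos: torch.Tensor, window: int | None):
    """Boolean mask from global indices: causal, plus a left window for sliding layers."""
    diff = q_pos[:, None] - k_pos[None, :]
    mask = diff >= 0                               # causal
    if window is not None:
        mask &= diff < window                      # only look back `window` tokens
    return mask
```

Because the mask is built from the global indices of each packed sub-sequence, the same routine serves both global and sliding layers under CP.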
Dual RoPE. DualRotaryEmbedding wraps the (local, global) ropes and emits a single concatenated tensor per call. The layer slices per-layer based on is_sliding. Concat (not a tuple) so Megatron's existing 2-tuple (self_attn, cross_attn) rope plumbing doesn't misread it. Global layers use partial rotary via zeroed inv_freq tail entries (requires apply_rope_fusion=False).
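The concat-then-slice contract, roughly. This is an illustrative sketch; the real DualRotaryEmbedding's constructor arguments and shapes may differ.

```python
# Sketch of the dual-RoPE concat/slice contract (illustrative).
import torch

class DualRotaryEmbedding(torch.nn.Module):
    def __init__(self, local_rope, global_rope):
        super().__init__()
        self.local_rope = local_rope    # sliding layers, local theta
        self.global_rope = global_rope  # global layers, long-context theta

    def forward(self, max_seq_len: int):
        local = self.local_rope(max_seq_len)    # e.g. [seq, 1, 1, dim]
        global_ = self.global_rope(max_seq_len)
        # One tensor, not a tuple: Megatron interprets a 2-tuple as
        # (self_attn, cross_attn) rotary embeddings, which would silently
        # route the "global" half to cross-attention.
        return torch.cat([local, global_], dim=-1)

def slice_rope(rotary_pos_emb: torch.Tensor, is_sliding: bool) -> torch.Tensor:
    # Each layer takes its half back out based on is_sliding.
    half = rotary_pos_emb.shape[-1] // 2
    return rotary_pos_emb[..., :half] if is_sliding else rotary_pos_emb[..., half:]
```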
MoE via Megatron's dispatcher. Gemma4MoELayer subclasses MoELayer to reuse the alltoall dispatcher, grouped-GEMM experts, and EP sharding, but swaps in Gemma4Router (no-scale RMSNorm → learnable scale → proj → softmax → topk → renormalize → per-expert scale). The renormalize-then-scale order is load-bearing and guarded by test_router_matches_hf_reference_equation.
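The routing equation in the order listed above, as a rough sketch. router_scale and expert_scales are assumed parameter names, and the real Gemma4Router plugs into Megatron's MoELayer machinery rather than being a free function.

```python
# Sketch of the Gemma4 routing equation (illustrative, not the real Gemma4Router).
import torch
import torch.nn.functional as F

def gemma4_route(hidden, weight, router_scale, expert_scales, top_k, eps=1e-6):
    """hidden: [tokens, hidden]; weight: [num_experts, hidden]."""
    # 1. no-scale RMSNorm (no learnable gamma)...
    x = hidden * torch.rsqrt(hidden.float().pow(2).mean(-1, keepdim=True) + eps)
    # 2. ...followed by a separate learnable scale
    x = x * router_scale
    # 3. projection to expert logits, 4. softmax over all experts
    probs = F.softmax(x @ weight.t(), dim=-1, dtype=torch.float32)
    # 5. top-k selection, 6. renormalize the selected probs,
    # 7. then apply per-expert scales; renormalize BEFORE scaling is load-bearing.
    top_probs, top_idx = probs.topk(top_k, dim=-1)
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
    top_probs = top_probs * expert_scales[top_idx]
    return top_probs, top_idx
```

Swapping steps 6 and 7 would change the relative weighting whenever the selected experts carry different scales (the normalization would partially cancel them), which is why the order is pinned by a test.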
Per-layer scalars loaded from the HF safetensors checkpoint as non-learnable buffers and applied after the FFN residual add. Rank-0 reads + broadcast_object_list to avoid O(world_size) filesystem hits (see the sketch below). Mandatory by default; GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1 downgrades to a warning.

attention_k_eq_v on global layers: linear_qkv emits [q, k] only, with v_proj_weight == k_proj_weight, and _split_qkv_global_k_eq_v derives V = v_norm(raw_k), K = k_norm(raw_k) without mutating self.k_layernorm.
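The rank-0-read + broadcast pattern for the layer scalars looks roughly like this. Sketch only; the safetensors key name and the single-file checkpoint layout are assumptions, not what _load_layer_scalars actually reads.

```python
# Sketch of the rank-0-read + broadcast_object_list pattern for per-layer scalars.
import os
import torch
import torch.distributed as dist
from safetensors import safe_open

def load_layer_scalars(ckpt_dir: str, num_layers: int) -> list[float]:
    scalars = [None]
    if dist.get_rank() == 0:
        values = []
        # Only rank 0 touches the filesystem; everyone else waits for the broadcast,
        # so the checkpoint sees O(1) readers instead of O(world_size).
        with safe_open(os.path.join(ckpt_dir, "model.safetensors"), framework="pt") as f:
            for i in range(num_layers):
                key = f"model.layers.{i}.ffn_scalar"  # hypothetical key name
                values.append(f.get_tensor(key).item())
        scalars = [values]
    dist.broadcast_object_list(scalars, src=0)
    return scalars[0]
```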
HF↔Megatron conversion (slime/backends/megatron_utils/megatron_to_hf/gemma4.py, slime_plugins/mbridge/gemma4.py)

convert_gemma4_to_hf emits stacked 3D expert tensors (E, 2I, H) / (E, H, I) for sglang, drops the .weight suffix on stacked keys, and handles the PP layer offset via get_transformer_layer_offset. Gemma4Bridge._build_config explicitly sets activation_func = gelu_pytorch_tanh + bias_activation_fusion = False, since Gemma uses GeGLU, not SwiGLU. The QKV conversion round-trips bit-exactly (test_gemma4_qkv_roundtrip.py).
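The expert stacking amounts to roughly the following. Sketch only; the output key names are hypothetical, and the real converter additionally applies the PP layer offset from get_transformer_layer_offset when numbering layers.

```python
# Sketch of stacking per-expert MoE weights into the 3D tensors sglang expects.
import torch

def stack_experts(per_expert_w13: list[torch.Tensor],
                  per_expert_w2: list[torch.Tensor]) -> dict[str, torch.Tensor]:
    """per_expert_w13[e]: [2I, H] fused gate/up proj; per_expert_w2[e]: [H, I]."""
    # sglang's fused MoE path loads one 3D tensor per projection rather than
    # E separate 2D weights, and the stacked keys carry no trailing ".weight".
    w13 = torch.stack(per_expert_w13, dim=0)  # (E, 2I, H)
    w2 = torch.stack(per_expert_w2, dim=0)    # (E, H, I)
    return {
        "mlp.experts.w13": w13,  # hypothetical key
        "mlp.experts.w2": w2,    # hypothetical key
    }
```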
Retool integration (examples/retool/generate_with_retool_gemma4.py)

Uses tokenizer.apply_chat_template with Gemma4's native <|turn>role\n...<turn|> format instead of the hardcoded Qwen ChatML in the stock retool generate. Keeps the <tool_call>{json}</tool_call> parsing contract so the shared reward_func / postprocess_predictions regex still works; Gemma4 instruct follows the system-prompt instruction despite its native <|tool_call> format being different.
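The generate/parse contract looks roughly like this. Sketch only; the repository id, system-prompt wording, and helper name are assumptions rather than what the retool script actually uses.

```python
# Sketch of "native chat template in, <tool_call> contract out" (illustrative).
import json
import re
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma4-26b-a4b-it")  # hypothetical repo id

messages = [
    # The system prompt asks the model to emit <tool_call>{json}</tool_call>
    # even though Gemma4's native convention is <|tool_call>, so the shared
    # retool reward/postprocess regex keeps working unchanged.
    {"role": "system", "content": "When calling a tool, wrap the JSON call in <tool_call>...</tool_call>."},
    {"role": "user", "content": "What is 17 * 23?"},
]
# Let the tokenizer render Gemma4's native turn format instead of hardcoding
# Qwen ChatML the way the stock retool generate does.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def extract_tool_calls(completion: str) -> list[dict]:
    # Same parsing contract as the shared reward_func / postprocess_predictions path.
    return [json.loads(m) for m in re.findall(r"<tool_call>(.*?)</tool_call>", completion, re.S)]
```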
scripts/models/gemma4-26B-A4B.sh,gemma4-31B.sh)MODEL_ARGS templates;
--swigluintentionally omitted (activation is set byget_gemma4_spec).Tests (
Tests (tests/gemma4/)

10 test files:

- test_gemma4_attention.py: K=V global split, sliding delegation, V=v_norm()
- test_gemma4_router.py: Gemma4Router equation, renorm-then-scale order, MoELayer.route adapter
- test_gemma4_dual_rope.py: concat/split semantics, real Megatron RotaryEmbedding integration (CUDA-gated)
- test_gemma4_provider.py: _install_hooks, _load_layer_scalars, PP offset translation
- test_gemma4_qkv_roundtrip.py: HF↔Mcore QKV bit-exact roundtrip
- test_gemma4_bridge.py: forward-only Megatron→HF conversion
- test_gemma4_layer_integration.py: real layer build + forward (CUDA-gated)
- test_gemma4_cp_attention.py: SDPACoreAttention CP production paths (4 tests exercising the forward dispatch + zig-zag global indices)
- test_gemma4_hf_key_contract.py: asserts our Megatron→HF converter emits every key HF's Gemma4ForConditionalGeneration state_dict expects; pins the sglang / HF loader contract against future drift
- test_gemma4_layer_scalar_broadcast.py: 2-rank gloo test exercising the real rank-0-read + broadcast_object_list path for layer scalars (single-process tests can't catch a regression that only affects rank>0)
50/50 pass on an H200 GPU host via pytest tests/gemma4/.

Correctness: parity test results
All Gemma4 parallel dims verified via standalone parity harnesses. Summary:
[Parity results table not reproduced here; it covered the TP/PP/DP/CP/EP/Sliding Window configurations, including the force_cp_subseq_mask path.]

End-to-end validation
Trained three concurrent configurations via the retool recipe on dapo-math-17k. All runs: 48 H200 GPUs (16 actor + 32 rollout), global_batch_size=256, n_samples_per_prompt=8, GRPO, lr=5e-6, bf16, 0 pod restarts.
Reference
Hugging Face Transformers Gemma4 source:
SGLang Gemma4 model loader: