
feat(gemma4): add Gemma4 26B-A4B MoE and 31B dense support #1855

Open
leofan-lab wants to merge 1 commit into THUDM:main from leofan-lab:gemma4-pr

Conversation

@leofan-lab
Contributor

@leofan-lab leofan-lab commented Apr 24, 2026

Summary

Adds Gemma4 (26B-A4B MoE + 31B dense) model support to slime for RL training. Covers model architecture, HF↔Megatron weight conversion, retool integration, and 10 unit tests.

Test validation: parity tests across TP/PP/DP/CP/EP/sliding window all pass (see table below), and RL training runs on Gemma4 26B-A4B MoE and 31B dense for 200+ rollout steps with the retool recipe (see End-to-end validation below).


What's in this PR

Model plugin (slime_plugins/models/gemma4.py, gemma4_provider.py)

  • Heterogeneous attention. Sliding layers use head_dim=256 and go through flash-attn with a (sw-1, 0) left-window mask. Global layers use head_dim=512 — flash-attn 2.7.4 doesn't support head_dim > 256, so they go through a PyTorch SDPA path (see the dispatch sketch after this list).

  • Context parallelism. SDPACoreAttention._forward_cp_subseq_mask is a unified CP>1 path for both global and sliding layers:

    1. All-gather K/V across the CP group with a differentiable gather (so grads flow back to the originating rank).
    2. Un-zigzag the gathered K/V via a permutation so mask indices line up with pure global order.
    3. Build per-sub-sequence causal (+ optional sliding-window) masks from slime's zig-zag global Q positions.
  • Dual RoPE. DualRotaryEmbedding wraps (local, global) ropes and emits a single concatenated tensor per call. The layer slices per-layer based on is_sliding. Concat (not tuple) so Megatron's existing 2-tuple (self_attn, cross_attn) rope plumbing doesn't misread it. Global layers use partial-rotary via zeroed inv_freq tail entries (requires apply_rope_fusion=False).

  • MoE via Megatron's dispatcher. Gemma4MoELayer subclasses MoELayer to reuse the alltoall dispatcher, grouped-GEMM experts, and EP sharding — but swaps in Gemma4Router (no-scale RMSNorm → learnable scale → proj → softmax → topk → renormalize → per-expert scale). The renormalize-then-scale order is load-bearing and guarded by test_router_matches_hf_reference_equation; see the router sketch after this list.

  • Per-layer scalars loaded from the HF safetensors checkpoint as non-learnable buffers and applied after the FFN residual add. Rank-0 reads + broadcast_object_list to avoid O(world_size) filesystem hits. Mandatory by default; GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1 downgrades to a warning.

  • attention_k_eq_v on global layers: linear_qkv emits [q, k] only with v_proj_weight == k_proj_weight, and _split_qkv_global_k_eq_v derives V = v_norm(raw_k), K = k_norm(raw_k) without mutating self.k_layernorm.
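
A minimal sketch of the head_dim dispatch from the heterogeneous-attention bullet above — function and argument names here are illustrative, not the actual slime_plugins API:

```python
import torch.nn.functional as F

FLASH_ATTN_MAX_HEAD_DIM = 256  # flash-attn 2.7.4 upper limit on head_dim

def attention_forward(q, k, v, *, is_sliding: bool, sliding_window: int):
    """q/k/v: [batch, heads, seq, head_dim] (illustrative layout)."""
    head_dim = q.shape[-1]
    if is_sliding and head_dim <= FLASH_ATTN_MAX_HEAD_DIM:
        # Sliding layers (head_dim=256): flash-attn with a left-only window.
        from flash_attn import flash_attn_func
        out = flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),  # [b, s, h, d]
            causal=True,
            window_size=(sliding_window - 1, 0),
        )
        return out.transpose(1, 2)
    # Global layers (head_dim=512): flash-attn can't go above 256, so fall back to SDPA.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```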

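And a minimal sketch of the router equation from the MoE bullet, with the renormalize-before-scale order; tensor and parameter names are illustrative, not the real Gemma4Router attributes:

```python
import torch
import torch.nn.functional as F

def route(hidden, norm_scale, router_proj, expert_scales, top_k, eps=1e-6):
    # hidden: [tokens, hidden]; expert_scales: [num_experts]
    # 1. no-scale RMSNorm followed by a learnable elementwise scale
    x = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + eps)
    x = x * norm_scale
    # 2. project to expert logits and softmax over all experts
    probs = F.softmax(router_proj(x), dim=-1)              # [tokens, num_experts]
    # 3. take top-k and renormalize the selected probabilities first...
    top_probs, top_idx = probs.topk(top_k, dim=-1)         # [tokens, top_k]
    top_probs = top_probs / top_probs.sum(-1, keepdim=True)
    # 4. ...then apply the per-expert scale (this order is what the test pins)
    return top_probs * expert_scales[top_idx], top_idx
```
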
HF↔Megatron conversion (slime/backends/megatron_utils/megatron_to_hf/gemma4.py, slime_plugins/mbridge/gemma4.py)

  • convert_gemma4_to_hf emits stacked 3D expert tensors (E, 2I, H) / (E, H, I) for sglang, drops the .weight suffix on stacked keys, and handles the PP layer offset via get_transformer_layer_offset (see the stacking sketch below).
  • Gemma4Bridge._build_config explicitly sets activation_func = gelu_pytorch_tanh + bias_activation_fusion = False — Gemma uses GeGLU, not SwiGLU.
  • QKV roundtrip is bit-exact for both sliding and K=V global layers (test_gemma4_qkv_roundtrip.py).
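
An illustrative sketch of the stacked-expert emission (the shapes follow the bullet above; the key names are hypothetical, not the exact converter output):

```python
import torch

def stack_expert_weights(per_expert, layer_idx):
    """per_expert: list of length E, each a dict of 2D HF-style expert weights."""
    gate_up = torch.stack([e["gate_up_proj.weight"] for e in per_expert])  # (E, 2I, H)
    down = torch.stack([e["down_proj.weight"] for e in per_expert])        # (E, H, I)
    prefix = f"model.layers.{layer_idx}.mlp.experts."
    # Stacked keys intentionally drop the trailing '.weight'.
    return {prefix + "gate_up_proj": gate_up, prefix + "down_proj": down}
```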

Retool integration (examples/retool/generate_with_retool_gemma4.py)

  • Uses tokenizer.apply_chat_template on Gemma4's native <|turn>role\n...<turn|> format instead of the hardcoded Qwen ChatML in the stock retool generate.
  • Keeps the <tool_call>{json}</tool_call> parsing contract so the shared reward_func / postprocess_predictions regex still works. Gemma4 instruct follows the system-prompt instruction despite its native <|tool_call> format being different.
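
Roughly, the generate path lets the tokenizer render its own template and keeps the shared parsing contract; the regex and helpers below are an illustrative sketch, not the exact retool code:

```python
import json
import re

TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def build_prompt(tokenizer, messages):
    # Let the Gemma4 tokenizer render its own native turn format.
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def extract_tool_call(text):
    """Return the first <tool_call>{json}</tool_call> payload, or None."""
    m = TOOL_CALL_RE.search(text)
    if m is None:
        return None
    try:
        return json.loads(m.group(1))
    except json.JSONDecodeError:
        return None
```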

Configs (scripts/models/gemma4-26B-A4B.sh, gemma4-31B.sh)

MODEL_ARGS templates; --swiglu intentionally omitted (activation is set by get_gemma4_spec).

Tests (tests/gemma4/)

10 test files:

  • test_gemma4_attention.py — K=V global split, sliding delegation, V=v_norm()
  • test_gemma4_router.py — Gemma4Router equation, renorm-then-scale order, MoELayer.route adapter
  • test_gemma4_dual_rope.py — concat/split semantics, real Megatron RotaryEmbedding integration (CUDA-gated)
  • test_gemma4_provider.py — _install_hooks, _load_layer_scalars, PP offset translation
  • test_gemma4_qkv_roundtrip.py — HF↔Mcore QKV bit-exact roundtrip
  • test_gemma4_bridge.py — forward-only Megatron→HF conversion
  • test_gemma4_layer_integration.py — real layer build + forward (CUDA-gated)
  • test_gemma4_cp_attention.py — SDPACoreAttention CP production paths (4 tests exercising the forward dispatch + zig-zag global indices)
  • test_gemma4_hf_key_contract.py — asserts our Megatron→HF converter emits every key HF's Gemma4ForConditionalGeneration state_dict expects; pins the sglang / HF loader contract against future drift.
  • test_gemma4_layer_scalar_broadcast.py — 2-rank gloo test exercising the real rank-0-read + broadcast_object_list path for layer scalars (single-process tests can't catch a regression that only affects rank>0).
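
As context for that last test, a sketch of the rank-0-read + broadcast pattern it exercises (the per-layer scalars bullet above); read_scalars_from_checkpoint is a hypothetical stand-in for the real loader:

```python
import torch.distributed as dist

def load_layer_scalars(ckpt_path, num_layers):
    payload = [None]
    if dist.get_rank() == 0:
        # Only rank 0 touches the safetensors files on disk.
        payload[0] = read_scalars_from_checkpoint(ckpt_path, num_layers)  # hypothetical helper
    # Every other rank receives the object, avoiding O(world_size) filesystem hits.
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
```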

50/50 pass on an H200 GPU host via pytest tests/gemma4/.


Correctness: parity test results

All Gemma4 parallel dims verified via standalone parity harnesses. Summary:

| Parallel dim | Test | Result |
| --- | --- | --- |
| TP | TP=2 EP=4 vs TP=1 EP=1, fp32, bisect-24 | 100% argmax agreement |
| PP | PP=2 vs PP=1 (Megatron scheduler), fp32 | bit-exact (0.000 abs diff) |
| DP | DP=2 vs DP=1 gradnorm, 2-sample batch, bisect-10 | rel diff 0.008% |
| CP | CP=2 vs CP=1 gradnorm, bisect-10, force_cp_subseq_mask | rel diff 0.526% |
| EP | TP=2 EP=1 vs TP=2 EP=4 logits | 100% argmax |
| SW | Gemma4 sliding-window invariance at layer 0 (perturb pos < 400, probe pos ≥ 1424, bisect-1) | rel diff 6.7e-7 |

End-to-end validation

Trained three concurrent configurations via the retool recipe on dapo-math-17k. All runs: 48 H200 GPUs (16 actor + 32 rollout), global_batch_size=256, n_samples_per_prompt=8, GRPO, lr=5e-6, bf16, 0 pod restarts.

  • Gemma4 31B dense (CP=1) — actor: TP=4, DP=4
  • Gemma4 31B dense (CP=2) — actor: TP=4, CP=2, DP=2
  • Gemma4 26B-A4B MoE — actor: TP=2, PP=2, CP=2, EP=2, DP=2

Reference

Hugging Face Transformers Gemma4 source:

SGLang Gemma4 model loader:

Plugin (model, provider, mbridge), HF<->Megatron converter, retool
integration, and 10 unit tests.

Highlights:

- Heterogeneous attention. Sliding layers go through flash-attn
  (head_dim=256); global layers go through a PyTorch SDPA path because
  flash-attn 2.x doesn't support head_dim=512.

- Unified CP>1 path. SDPACoreAttention._forward_cp_subseq_mask
  all-gathers K/V, un-zigzags to global order, and builds
  per-sub-sequence causal (+ sliding-window) masks from slime's zig-zag
  global indices. Covers both global and sliding layers.

- Dual RoPE emits a single concatenated tensor (not a tuple) so it
  doesn't collide with Megatron's (self_attn, cross_attn) rope plumbing.

- Gemma4MoELayer subclasses MoELayer to reuse dispatcher + grouped-GEMM
  + EP, then swaps in Gemma4Router (RMSNorm -> scale -> proj -> softmax
  -> topk -> renormalize -> per-expert scale; order is load-bearing).

- Per-layer scalars loaded from the HF checkpoint as non-trainable
  buffers; rank-0 reads + broadcast to avoid O(world_size) FS hits.

- Converter emits stacked 3D expert tensors for sglang and handles PP
  offset via get_transformer_layer_offset.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@EthanChen1234

@leofan-lab thanks for your great work.

Could you please share your training script? When I start the SGLang server for the Gemma4 26B-A4B model, it raises a TP error.

@leofan-lab
Contributor Author

> @leofan-lab thanks for your great work.
>
> Could you please share your training script? When I start the SGLang server for the Gemma4 26B-A4B model, it raises a TP error.

sure:

python3 train_async.py \
      --actor-num-nodes 2 \
      --actor-num-gpus-per-node 8 \
      --rollout-num-gpus 32 \
      ${MODEL_ARGS[@]} \
      --hf-checkpoint $HF_CKPT \
      --custom-model-provider-path "slime_plugins.models.gemma4_provider.model_provider" \
      --ref-load $TORCH_DIST \
      --save $SAVE_DIR \
      --save-interval 100 \
      --no-save-optim \
      --prompt-data $DATASET \
      --input-key prompt \
      --label-key label \
      --apply-chat-template-kwargs '{"enable_thinking": true}' \
      --rollout-shuffle \
      --reward-key score \
      --num-rollout 1000 \
      --rollout-batch-size 64 \
      --n-samples-per-prompt 8 \
      --rollout-max-response-len 8192 \
      --rollout-max-context-len 16384 \
      --rollout-temperature 1 \
      --global-batch-size 256 \
      --balance-data \
      --eval-interval 9999 \
      --eval-prompt-data aime $AIME_PATH \
      --n-samples-per-eval-prompt 16 \
      --eval-max-response-len 16384 \
      --eval-top-p 1 \
      --tensor-model-parallel-size 2 \
      --expert-model-parallel-size 2 \
      --expert-tensor-parallel-size 1 \
      --sequence-parallel \
      --pipeline-model-parallel-size 2 \
      --context-parallel-size 2 \
      --use-distributed-optimizer \
      --recompute-granularity full \
      --recompute-method uniform \
      --recompute-num-layers 1 \
      --use-dynamic-batch-size \
      --max-tokens-per-gpu 8192 \
      --advantage-estimator grpo \
      --use-kl-loss \
      --kl-loss-coef 0.00 \
      --kl-loss-type low_var_kl \
      --entropy-coef 0.001 \
      --eps-clip 0.2 \
      --eps-clip-high 0.28 \
      --use-tis \
      --update-weights-interval 4 \
      --optimizer adam \
      --lr 5e-6 \
      --lr-decay-style constant \
      --weight-decay 0.1 \
      --adam-beta1 0.9 \
      --adam-beta2 0.98 \
      --rollout-num-gpus-per-engine 2 \
      --sglang-mem-fraction-static 0.7 \
      --sglang-cuda-graph-max-bs 64 \
      --sglang-disable-cuda-graph \
      --sglang-enable-metrics \
      --attention-dropout 0.0 \
      --hidden-dropout 0.0 \
      --accumulate-allreduce-grads-in-fp32 \
      --attention-softmax-in-fp32 \
      --custom-generate-function-path generate_with_retool_gemma4.generate \
      --rollout-function-path fully_async_rollout.generate_rollout_fully_async \
      --eval-function-path slime.rollout.sglang_rollout.generate_rollout \
      --custom-rm-path generate_with_retool_gemma4.reward_func \
      2>&1 | tee "$SAVE_DIR/train.log"

The SGLang error you hit is most likely because you need to upgrade to a version that includes Gemma4 support.
Sharing my Dockerfile:

# transformers >= 5.5.3 is the first release with Gemma4 model configs plus an
# internal Qwen3-VL dict-typed vision_config fix.
RUN pip install --no-cache-dir "transformers>=5.5.3"

# Pin SGLang to the commit that lands Gemma4 support (upstream PR #21952).
# The `python3 -c ...` block writes a fake sglang_kernel-0.4.1.dist-info/
# METADATA so `pip install -e` doesn't fail resolving the kernel package
# name mismatch against the prebuilt .so that ships with the base image.
RUN cd /sgl-workspace/sglang && \
    git fetch --depth=1 origin 2813cb6d9a5b6e8fb02435647917e6f1652d7940 && \
    git checkout -f FETCH_HEAD && \
    pip install -e "python/[all]" --no-deps && \
    python3 -c "\
import pathlib, sysconfig; \
d = pathlib.Path(sysconfig.get_path('purelib')) / 'sglang_kernel-0.4.1.dist-info'; \
d.mkdir(exist_ok=True); \
(d / 'METADATA').write_text('Metadata-Version: 2.1\nName: sglang-kernel\nVersion: 0.4.1\n'); \
(d / 'INSTALLER').write_text('pip\n')"

# The pinned SGLang commit predates transformers >= 5.5's move to dict-typed
# vision_config, so qwen3_vl.py breaks with "AttributeError: dict has no
# attribute 'num_hidden_layers'". This two-line sed promotes the dict to a
# SimpleNamespace before it reaches Qwen3VLMoeVisionModel. Delete when we
# bump SGLang past the fix.
RUN sed -i '/self.visual = Qwen3VLMoeVisionModel(/i\        _vc = config.vision_config\n        if isinstance(_vc, dict):\n            from types import SimpleNamespace\n            _vc = SimpleNamespace(**_vc)' \
    /sgl-workspace/sglang/python/sglang/srt/models/qwen3_vl.py && \
    sed -i 's/            config.vision_config,/            _vc,/' \
    /sgl-workspace/sglang/python/sglang/srt/models/qwen3_vl.py

@EthanChen1234

EthanChen1234 commented May 6, 2026


Hi,

I'm currently trying to upgrade sglang to include Gemma4 support (as in upstream PR #21952). However, I encountered an error while applying the sglang patches:
docker/patch/v0.5.9/sglang.patch
docker/patch/latest/sglang.patch

Could you provide some advice on how to resolve this patch issue? Any guidance would be greatly appreciated.
