diff --git a/examples/retool/generate_with_retool_gemma4.py b/examples/retool/generate_with_retool_gemma4.py
new file mode 100644
index 0000000000..eb58462e3c
--- /dev/null
+++ b/examples/retool/generate_with_retool_gemma4.py
@@ -0,0 +1,320 @@
+# Gemma4-compatible retool generate function.
+#
+# Uses tokenizer.apply_chat_template() instead of the hardcoded Qwen ChatML
+# Jinja template in generate_with_retool.py. Everything else (postprocessing,
+# tool execution, scoring rules) is reused.
+#
+# Why a Gemma4-specific version:
+#   generate_with_retool.py wraps the prompt in a Qwen ChatML template
+#   (<|im_start|>/<|im_end|>). Gemma4's tokenizer doesn't recognize those
+#   tokens as specials, and Gemma4's native turn format is
+#   <|turn>role\n.... Feeding a ChatML-wrapped prompt to Gemma4
+#   produces mangled input.
+#
+# Design choice:
+#   Fix the chat framing via apply_chat_template, but keep the Qwen-style
+#   <tool_call>{json}</tool_call> contract in the system prompt. This lets us
+#   reuse postprocess_predictions / postprocess_responses / execute_predictions
+#   unchanged. Switching to Gemma4's native <|tool_call>call:...
+#   format would require rewriting those parsers and is deferred.
+#
+# Note: the companion yaml must drop --apply-chat-template so sample.prompt
+# stays as the raw message list; this function re-templates once with a
+# custom system message.
+#
+# Usage in training args:
+#   --custom-generate-function-path generate_with_retool_gemma4.generate
+#   --custom-rm-path generate_with_retool_gemma4.reward_func
+import json
+from typing import Any
+
+from generate_with_retool import (
+    execute_predictions,
+    postprocess_predictions,  # noqa: F401 - re-exported for external callers
+    postprocess_responses,  # noqa: F401 - re-exported for external callers
+)
+from tool_sandbox import TOOL_CONFIGS, tool_registry
+
+from slime.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score
+from slime.rollout.sglang_rollout import GenerateState
+from slime.utils.http_utils import post
+from slime.utils.types import Sample
+
+_dropped_system_warned = {"v": False}
+
+DEFAULT_SYSTEM_PROMPT = (
+    "You are a helpful assistant that can use Python tools to solve "
+    "mathematical problems. When you need to perform calculations, use "
+    "the code_interpreter tool to execute code and get results."
+)
+
+
+def _build_tool_instructions(tools: list[dict]) -> str:
+    """Append Qwen-style tool instructions to the system message.
+
+    We keep the <tool_call>{json}</tool_call> contract (not Gemma4's native
+    <|tool_call>call:...) so postprocess_predictions' regex and
+    the reward function stay unchanged. Gemma4-it is strong enough at
+    instruction-following to emit this format on request.
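+
+    Rendered shape for one hypothetical tool (illustrative only; real specs
+    come from tool_registry.get_tool_specs()):
+
+        <tools>
+        {"name": "code_interpreter", "parameters": {...}}
+        </tools>
+        ...
+        <tool_call>
+        {"name": "code_interpreter", "arguments": {"code": "print(1 + 1)"}}
+        </tool_call>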
+ """ + if not tools: + return "" + tool_specs = "\n".join(json.dumps(tool) for tool in tools) + return ( + "\n\n# Tools\n\n" + "You may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + "\n" + f"{tool_specs}\n" + "\n\n" + "For each function call, return a json object with function name and arguments " + "within XML tags:\n" + "\n" + '{"name": , "arguments": }\n' + "" + ) + + +def _coerce_to_messages(raw_prompt) -> list[dict]: + """Normalize sample.prompt into a list of {role, content} dicts.""" + if isinstance(raw_prompt, list): + return list(raw_prompt) + if isinstance(raw_prompt, str): + return [{"role": "user", "content": raw_prompt}] + raise TypeError(f"Unsupported sample.prompt type: {type(raw_prompt)}") + + +def format_conversation_with_tools( + raw_prompt, + tools: list[dict[str, Any]] | None = None, + system_prompt: str | None = None, + tokenizer=None, +) -> str: + """Render the chat-templated prompt using Gemma4's native template. + + We do NOT pass `tools=` to apply_chat_template — that would trigger + Gemma4's native <|tool>declaration:... tool-spec format, which + downstream postprocess_predictions can't parse. Instead we inline tool + info as text inside the system message (Qwen-style contract). + """ + system_content = system_prompt or DEFAULT_SYSTEM_PROMPT + system_content += _build_tool_instructions(tools or []) + + user_messages = _coerce_to_messages(raw_prompt) + # If the dataset already contains a system message, prefer our system + # prompt (which carries the tool instructions) and drop theirs. + dataset_system = [m for m in user_messages if m.get("role") == "system"] + if dataset_system and not _dropped_system_warned["v"]: + # One-shot log — useful during dataset migrations, silent thereafter. + print( + "[retool-gemma4] dataset supplied a system message; overriding " + "with tool-instruction system prompt. " + f"(dropped: {dataset_system[0].get('content', '')[:120]!r})", + flush=True, + ) + _dropped_system_warned["v"] = True + user_messages = [m for m in user_messages if m.get("role") != "system"] + + messages = [{"role": "system", "content": system_content}, *user_messages] + + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + + +async def generate(args, sample: Sample, sampling_params) -> Sample: + """Custom generation function supporting tool calls (Gemma4 version).""" + assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment." + + # Retried samples (previously aborted / partial) arrive here with stale + # rollout state from the first attempt. Clear it so this generation starts + # clean; otherwise the concatenation below appends new tokens to old ones + # and downstream `slice_log_prob_with_cp` sees a length mismatch. 
+    sample.rollout_log_probs = None
+    sample.response = ""
+    sample.response_length = 0
+    sample.loss_mask = None
+
+    state = GenerateState(args)
+    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
+
+    # Set up the initial prompt with system prompt and tools
+    tool_specs = tool_registry.get_tool_specs()
+    prompt = format_conversation_with_tools(
+        raw_prompt=sample.prompt,
+        tools=tool_specs,
+        tokenizer=state.tokenizer,
+    )
+
+    prompt_tokens_ids = state.tokenizer(prompt, add_special_tokens=False)["input_ids"]
+    response = ""
+    response_token_ids = []
+    loss_masks = []
+    tool_call_count = 0
+
+    if args.rollout_max_context_len is not None:
+        max_context_length = args.rollout_max_context_len
+    else:
+        max_context_length = args.context_parallel_size * args.max_tokens_per_gpu
+
+    for turn in range(TOOL_CONFIGS["max_turns"]):
+        # Check if total length exceeds max context length
+        total_length = len(prompt_tokens_ids) + len(response_token_ids)
+        if total_length >= max_context_length:
+            sample.status = Sample.Status.TRUNCATED
+            break
+
+        # Clamp per-turn max_new_tokens to the remaining context budget so a
+        # single turn cannot push total_length past max_context_length.
+        remaining_budget = max_context_length - total_length
+        per_turn_sampling_params = dict(sampling_params)
+        per_turn_sampling_params["max_new_tokens"] = min(
+            sampling_params.get("max_new_tokens", remaining_budget),
+            remaining_budget,
+        )

+        current_token_ids = prompt_tokens_ids + response_token_ids
+        payload = {
+            "input_ids": current_token_ids,
+            "sampling_params": per_turn_sampling_params,
+            "return_logprob": True,
+        }
+
+        try:
+            import wandb
+
+            if wandb.run is not None:
+                wandb.log(
+                    {
+                        "debug/payload_length": len(prompt_tokens_ids) + len(response_token_ids),
+                        "debug/available_tools": len(tool_specs),
+                        "debug/tools_used": response.count("<tool_call>"),
+                        "debug/turn": turn,
+                    }
+                )
+        except ImportError:
+            pass
+
+        output = await post(url, payload)
+
+        if output["meta_info"]["finish_reason"]["type"] == "abort":
+            sample.status = Sample.Status.ABORTED
+            return sample
+
+        if "output_token_logprobs" in output["meta_info"]:
+            cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
+            cur_response = state.tokenizer.decode(cur_response_token_ids)
+            cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
+            if sample.rollout_log_probs is None:
+                sample.rollout_log_probs = []
+            sample.rollout_log_probs += cur_log_probs
+        else:
+            # sglang returned text but no output_token_logprobs — we cannot
+            # recover per-token logprobs for this turn, which would desync
+            # rollout_log_probs from response_token_ids and blow up
+            # slice_log_prob_with_cp downstream. Abort so the rollout manager
+            # returns the group to the buffer for retry instead of poisoning
+            # the trainer.
+            sample.status = Sample.Status.ABORTED
+            return sample
+
+        response += cur_response
+        response_token_ids += cur_response_token_ids
+        loss_masks += [1] * len(cur_response_token_ids)
+
+        if output["meta_info"]["finish_reason"]["type"] == "length":
+            break
+
+        next_obs, done = await execute_predictions(cur_response)
+        if done:
+            break
+
+        if "<tool_response>" in next_obs:
+            tool_call_count += 1
+
+        assert next_obs != "", "Next observation should not be empty."
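+        # Masking example (token counts hypothetical): a 12-token
+        # <tool_call>...</tool_call> emitted by the model got loss_mask [1]*12
+        # above; the 5-token <tool_response>...</tool_response> appended below
+        # gets [0]*5, so tool output is context, never a training target.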
+ obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"] + response += next_obs + response_token_ids += obs_tokens_ids + loss_masks += [0] * len(obs_tokens_ids) + + if sample.rollout_log_probs is not None: + sample.rollout_log_probs += [0.0] * len(obs_tokens_ids) + assert len(response_token_ids) == len(sample.rollout_log_probs), ( + f"Token/logp length mismatch at turn {turn}: " + f"{len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps" + ) + + # Tool output is appended verbatim and can push total_length past + # max_context_length. Trim tail tokens so the final sample fits the + # training budget exactly. + overflow = len(prompt_tokens_ids) + len(response_token_ids) - max_context_length + if overflow > 0: + response_token_ids = response_token_ids[:-overflow] + loss_masks = loss_masks[:-overflow] + if sample.rollout_log_probs is not None: + sample.rollout_log_probs = sample.rollout_log_probs[:-overflow] + response = state.tokenizer.decode(response_token_ids) + sample.status = Sample.Status.TRUNCATED + break + + if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]: + break + + sample.tokens = prompt_tokens_ids + response_token_ids + sample.response_length = len(response_token_ids) + sample.response = response + sample.loss_mask = loss_masks + # Overwrite raw list prompt with the rendered string. Upstream slime + # (e.g. fully_async_rollout.py:215) does sample.prompt + sample.response + # in log statements and assumes a string; with --apply-chat-template off, + # sample.prompt arrives as a list of message dicts and the concat raises + # TypeError. We've already rendered the string above, so reuse it. + sample.prompt = prompt + + sample.payload_text = prompt + response + sample.payload_has_system = True + sample.payload_has_tools = "# Tools" in prompt + + sample.tool_call_count = tool_call_count + + match output["meta_info"]["finish_reason"]["type"]: + case "length": + sample.status = Sample.Status.TRUNCATED + case "abort": + sample.status = Sample.Status.ABORTED + case "stop": + sample.status = Sample.Status.COMPLETED + + return sample + + +async def reward_func(args, sample, **kwargs): + """Tool-call reward function for Gemma4. + + Mirrors generate_with_retool.reward_func but scores on sample.response + alone — with --apply-chat-template disabled, sample.prompt is a list of + message dicts and cannot be string-concatenated. math_dapo_compute_score + only looks for an Answer: \\boxed{...} pattern, which lives in the + response, so dropping the prompt from the solution string is safe. 
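+
+    Scoring sketch (numbers hypothetical): a response ending in
+    "Answer: \\boxed{42}" against label "42" keeps whatever positive score
+    math_dapo_compute_score assigns; a wrong answer scores negative, and the
+    tool-call bonus below can soften it, but never above the -0.6 cap.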
+ """ + if not isinstance(sample, Sample): + raise TypeError("Sample must be an instance of Sample class.") + + solution_str = sample.response + ground_truth = sample.label if sample.label is not None else "" + num_turns = getattr(sample, "tool_call_count", 0) + + result = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=True) + + if result["score"] < 0: + tool_call_reward = (num_turns - 2) / 2 * 0.1 + result["score"] = min(-0.6, result["score"] + tool_call_reward) + + if result["pred"] is None: + result["pred"] = "" + + return result diff --git a/scripts/models/gemma4-26B-A4B.sh b/scripts/models/gemma4-26B-A4B.sh new file mode 100644 index 0000000000..8168393750 --- /dev/null +++ b/scripts/models/gemma4-26B-A4B.sh @@ -0,0 +1,35 @@ +# Gemma4 26B-A4B MoE model configuration +# Based on google/gemma-4-26B-A4B-it +# 30 layers, 2816 hidden, 16 heads (8 kv), 128 experts top-8 +# Features: SWA (window=1024, every 6th layer full attention), gelu_pytorch_tanh + +MODEL_ARGS=( + --spec "slime_plugins.models.gemma4" "get_gemma4_spec" + # Gemma4 uses GeGLU (gated GELU-tanh), not SwiGLU. Activation is set by + # get_gemma4_spec; --swiglu is intentionally omitted. + --num-layers 30 + --hidden-size 2816 + --ffn-hidden-size 2112 + --num-attention-heads 16 + --group-query-attention + --num-query-groups 8 + --kv-channels 256 + --use-rotary-position-embeddings + --disable-bias-linear + --normalization "RMSNorm" + --norm-epsilon 1e-6 + --rotary-base 10000 + --rotary-percent 1.0 + --vocab-size 262144 + --qk-layernorm + # MoE + --num-experts 128 + --moe-ffn-hidden-size 704 + --moe-router-topk 8 + --moe-router-dtype fp32 + --moe-router-score-function softmax + --moe-router-load-balancing-type none + --moe-aux-loss-coeff 0.0 + --moe-token-dispatcher-type alltoall + --moe-grouped-gemm +) diff --git a/scripts/models/gemma4-31B.sh b/scripts/models/gemma4-31B.sh new file mode 100644 index 0000000000..e03867724f --- /dev/null +++ b/scripts/models/gemma4-31B.sh @@ -0,0 +1,20 @@ +MODEL_ARGS=( + --spec "slime_plugins.models.gemma4" "get_gemma4_spec" + # Gemma4 uses GeGLU (gated GELU-tanh), not SwiGLU. Activation is set by + # get_gemma4_spec; --swiglu is intentionally omitted. 
+ --num-layers 60 + --hidden-size 5376 + --ffn-hidden-size 21504 + --num-attention-heads 32 + --group-query-attention + --num-query-groups 16 + --kv-channels 256 + --use-rotary-position-embeddings + --disable-bias-linear + --normalization "RMSNorm" + --norm-epsilon 1e-6 + --rotary-base 10000 + --rotary-percent 1.0 + --vocab-size 262144 + --qk-layernorm +) diff --git a/slime/backends/megatron_utils/megatron_to_hf/__init__.py b/slime/backends/megatron_utils/megatron_to_hf/__init__.py index d6cccc23f3..06ea95254b 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/slime/backends/megatron_utils/megatron_to_hf/__init__.py @@ -1,4 +1,5 @@ from .deepseekv3 import convert_deepseekv3_to_hf +from .gemma4 import convert_gemma4_to_hf from .glm4 import convert_glm4_to_hf from .glm4moe import convert_glm4moe_to_hf from .gpt_oss import convert_gpt_oss_to_hf @@ -52,6 +53,8 @@ def _convert_to_hf_core(args, model_name, name, param): converted_named_tensors = convert_qwen3vl_to_hf(args, name, param) elif "qwen2" in model_name or "qwen3" in model_name: converted_named_tensors = convert_qwen2_to_hf(args, name, param) + elif "gemma4" in model_name: + converted_named_tensors = convert_gemma4_to_hf(args, name, param) elif "llama" in model_name: converted_named_tensors = convert_llama_to_hf(args, name, param) elif "mimo" in model_name: diff --git a/slime/backends/megatron_utils/megatron_to_hf/gemma4.py b/slime/backends/megatron_utils/megatron_to_hf/gemma4.py new file mode 100644 index 0000000000..947371a72c --- /dev/null +++ b/slime/backends/megatron_utils/megatron_to_hf/gemma4.py @@ -0,0 +1,189 @@ +import re +import torch + + +_config_cache = {} + +# Per-layer buffers for stacked expert tensors. sglang's Gemma4 loader expects +# `experts.gate_up_proj` as a single 3D tensor of shape [E, 2I, H] and +# `experts.down_proj` as [E, H, I] — it walks all experts inside the loader +# and would silently drop per-expert 2D inputs. We accumulate expert tensors +# as they stream through and emit the stacked form once all num_experts arrive. +_expert_buffers: dict = {} + + +def reset_expert_buffers() -> None: + """Drop any partial expert buckets. Callers that drive the converter from a + long-lived process (tests, repeated conversions) should invoke this between + runs so an interrupted prior conversion doesn't leak its partial state.""" + _expert_buffers.clear() + + +def _get_config(args): + if "config" not in _config_cache: + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) + hf_text = hf_config.text_config if hasattr(hf_config, "text_config") else hf_config + _config_cache["config"] = { + "global_attn_layers": {i for i, t in enumerate(hf_text.layer_types) if t == "full_attention"}, + "local_head_dim": hf_text.head_dim, + "global_head_dim": hf_text.global_head_dim, + "num_attention_heads": hf_text.num_attention_heads, + "local_num_kv_heads": hf_text.num_key_value_heads, + "global_num_kv_heads": hf_text.num_global_key_value_heads, + "hidden_size": hf_text.hidden_size, + "num_experts": getattr(hf_text, "num_experts", 0), + } + return _config_cache["config"] + + +def convert_gemma4_to_hf(args, name, param): + cfg = _get_config(args) + prefix = "model.language_model." 
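+    # Example of the rename performed below (dense 31B path):
+    #   module.module.decoder.layers.3.self_attention.linear_proj.weight
+    #     -> model.language_model.layers.3.self_attn.o_proj.weight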
+ + if name == "module.module.embedding.word_embeddings.weight": + return [(f"{prefix}embed_tokens.weight", param)] + if name == "module.module.output_layer.weight": + return [(f"{prefix}embed_tokens.weight", param)] # tied embeddings + if name == "module.module.decoder.final_layernorm.weight": + return [(f"{prefix}norm.weight", param)] + + match = re.match(r"module\.module\.decoder\.layers\.(\d+)\.(.+)", name) + if match: + layer_idx = int(match.group(1)) + rest = match.group(2) + L = f"{prefix}layers.{layer_idx}" + is_global = layer_idx in cfg["global_attn_layers"] + + if rest == "self_attention.linear_proj.weight": + return [(f"{L}.self_attn.o_proj.weight", param)] + elif rest == "self_attention.linear_qkv.weight": + if is_global: + head_dim = cfg["global_head_dim"] + num_kv_heads = cfg["global_num_kv_heads"] + else: + head_dim = cfg["local_head_dim"] + num_kv_heads = cfg["local_num_kv_heads"] + + q_heads_per_kv = cfg["num_attention_heads"] // num_kv_heads + # Megatron packs QKV as [num_kv_heads, (q_heads_per_kv + 2) * head_dim, hidden] + hidden_size = cfg["hidden_size"] + param = param.view(num_kv_heads, (q_heads_per_kv + 2) * head_dim, hidden_size) + q_dim = q_heads_per_kv * head_dim + q_param = param[:, :q_dim, :].reshape(-1, hidden_size) + k_param = param[:, q_dim:q_dim + head_dim, :].reshape(-1, hidden_size) + + if is_global: + # Global layers: K=V shared, only emit q and k + return [ + (f"{L}.self_attn.q_proj.weight", q_param), + (f"{L}.self_attn.k_proj.weight", k_param), + ] + else: + v_param = param[:, q_dim + head_dim:, :].reshape(-1, hidden_size) + return [ + (f"{L}.self_attn.q_proj.weight", q_param), + (f"{L}.self_attn.k_proj.weight", k_param), + (f"{L}.self_attn.v_proj.weight", v_param), + ] + elif rest == "self_attention.linear_qkv.layer_norm_weight": + return [(f"{L}.input_layernorm.weight", param)] + elif rest == "self_attention.q_layernorm.weight": + return [(f"{L}.self_attn.q_norm.weight", param)] + elif rest == "self_attention.k_layernorm.weight": + return [(f"{L}.self_attn.k_norm.weight", param)] + # Dense MLP paths. For the 31B dense variant this is the single `.mlp` + # submodule; for the 26B-A4B MoE variant Megatron's `.mlp` slot holds + # the MoE block and the parallel dense feed-forward lives at + # `.dense_mlp` (see Gemma4TransformerLayer). Both map to HF's + # `mlp.gate_proj/up_proj/down_proj` since HF calls it `mlp` regardless. + elif rest in ("mlp.linear_fc1.weight", "dense_mlp.linear_fc1.weight"): + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"{L}.mlp.gate_proj.weight", gate_weight), + (f"{L}.mlp.up_proj.weight", up_weight), + ] + elif rest in ("mlp.linear_fc2.weight", "dense_mlp.linear_fc2.weight"): + return [(f"{L}.mlp.down_proj.weight", param)] + elif rest in ("mlp.linear_fc1.layer_norm_weight", "dense_mlp.linear_fc1.layer_norm_weight"): + return [(f"{L}.pre_feedforward_layernorm.weight", param)] + elif rest == "pre_mlp_layernorm.weight": + return [(f"{L}.pre_feedforward_layernorm.weight", param)] + elif rest == "post_attention_layernorm.weight": + return [(f"{L}.post_attention_layernorm.weight", param)] + elif rest == "post_feedforward_layernorm.weight": + return [(f"{L}.post_feedforward_layernorm.weight", param)] + elif rest == "layer_scalar": + # Non-trainable per-layer scalar buffer; HF stores it at + # `layers.N.layer_scalar` (see HF Gemma4TextDecoderLayer). + return [(f"{L}.layer_scalar", param)] + # MoE weights (26B-A4B). 
Under the MoE variant the MoE block is + # `self.mlp = Gemma4MoELayer`, so router lives at `.mlp.router.*` and + # per-expert TEGroupedLinear weights are at + # `.mlp.experts.linear_fc{1,2}.weight{E}` where E is the GLOBAL + # expert index (already remapped from local→global by callers). + elif rest == "mlp.router.proj.weight": + return [(f"{L}.router.proj.weight", param)] + elif rest == "mlp.router.scale": + return [(f"{L}.router.scale", param)] + elif rest == "mlp.router.per_expert_scale": + return [(f"{L}.router.per_expert_scale", param)] + # Per-expert weights → buffer and emit stacked 3D tensors once all experts + # in the layer have arrived. sglang's Gemma4 loader expects + # `experts.gate_up_proj` shape [E, 2I, H] + # `experts.down_proj` shape [E, H, I] + # as single 3D tensors (unlike qwen3_moe which takes per-expert 2D). + # Rather than patching sglang, we match sglang's expectation here. + else: + expert_match = re.match(r"mlp\.experts\.linear_fc([12])\.weight(\d+)", rest) + if expert_match: + fc, expert_idx = expert_match.group(1), int(expert_match.group(2)) + return _buffer_expert_and_maybe_flush( + layer_idx, fc, expert_idx, param, L, + num_experts=cfg["num_experts"], + ) + + if rest == "pre_feedforward_layernorm_2.weight": + # Legacy: pre_feedforward_layernorm_2 used to live on the layer. + return [(f"{L}.pre_feedforward_layernorm_2.weight", param)] + elif rest == "mlp.pre_feedforward_layernorm_2.weight": + # Current: pre_feedforward_layernorm_2 is owned by Gemma4MoELayer so + # the Megatron state-dict path is `.mlp.pre_feedforward_layernorm_2. + # weight`. HF still expects it at the decoder-layer level (sglang's + # Gemma4DecoderLayer also keeps it there), so emit without the `mlp.`. + return [(f"{L}.pre_feedforward_layernorm_2.weight", param)] + elif rest == "post_feedforward_layernorm_2.weight": + return [(f"{L}.post_feedforward_layernorm_2.weight", param)] + elif rest == "post_feedforward_layernorm_1.weight": + return [(f"{L}.post_feedforward_layernorm_1.weight", param)] + + raise ValueError(f"Unknown Gemma4 parameter name: {name}") + + +def _buffer_expert_and_maybe_flush(layer_idx, fc, expert_idx, param, L_prefix, num_experts): + """Buffer per-expert tensor; emit stacked 3D `experts.gate_up_proj` / `experts.down_proj` + once the bucket for (layer, fc) has all `num_experts` experts.""" + assert num_experts and num_experts > 0, ( + f"num_experts must be known for MoE layer expert conversion, got {num_experts}" + ) + key = (layer_idx, fc) + bucket = _expert_buffers.setdefault(key, {}) + # Deliberately allow re-entry (EP all-gather may re-stream): overwrite. + bucket[expert_idx] = param + + if len(bucket) < num_experts: + return [] + + # Stack in expert-index order. + ordered = [bucket[i] for i in range(num_experts)] + stacked = torch.stack(ordered, dim=0).contiguous() + del _expert_buffers[key] + + if fc == "1": + # Per-expert linear_fc1 comes in as [2*I, H]; stacked is [E, 2I, H]. + # HF stores these WITHOUT a `.weight` suffix; sglang's gemma4 loader + # relies on exact-name lookup after `experts.gate_up_proj → experts.w13_weight`. + return [(f"{L_prefix}.experts.gate_up_proj", stacked)] + else: + # Per-expert linear_fc2 comes in as [H, I]; stacked is [E, H, I]. 
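+        # With the 26B-A4B numbers (E=128, H=2816, I=704) this flushes a
+        # single [128, 2816, 704] tensor once the 128th expert's weight lands.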
+ return [(f"{L_prefix}.experts.down_proj", stacked)] diff --git a/slime_plugins/mbridge/__init__.py b/slime_plugins/mbridge/__init__.py index 9a918fdbdc..91477ee902 100644 --- a/slime_plugins/mbridge/__init__.py +++ b/slime_plugins/mbridge/__init__.py @@ -1,4 +1,5 @@ from .deepseek_v32 import DeepseekV32Bridge +from .gemma4 import Gemma4Bridge from .glm4 import GLM4Bridge from .glm4moe import GLM4MoEBridge from .glm4moe_lite import GLM4MoELiteBridge @@ -16,4 +17,5 @@ "Qwen3_5Bridge", "MimoBridge", "DeepseekV32Bridge", + "Gemma4Bridge", ] diff --git a/slime_plugins/mbridge/gemma4.py b/slime_plugins/mbridge/gemma4.py new file mode 100644 index 0000000000..77d43d4d37 --- /dev/null +++ b/slime_plugins/mbridge/gemma4.py @@ -0,0 +1,311 @@ +import functools +import re + +import torch +import torch.nn.functional as F +from mbridge.core import register_model +from mbridge.models import Gemma3Bridge + +from slime_plugins.models.gemma4 import get_rope_local_base_freq as _rope_local_base_freq + +# Gemma uses GeGLU (GELU with tanh approximation + gated linear unit), not SwiGLU. +# See: https://developers.googleblog.com/en/gemma-explained-new-in-gemma-2/ +_gelu_tanh = functools.partial(F.gelu, approximate="tanh") + + +@register_model("gemma4") +class Gemma4Bridge(Gemma3Bridge): + """ + Bridge for Gemma 4 dense 31B. + + Megatron-side keys have NO language_model. prefix (text-only model). + HF-side values have model.language_model. prefix (Gemma4ForConditionalGeneration). + """ + + _ATTENTION_MAPPING = { + "decoder.layers.{layer_number}.self_attention.linear_qkv.weight": [ + "model.language_model.layers.{layer_number}.self_attn.q_proj.weight", + "model.language_model.layers.{layer_number}.self_attn.k_proj.weight", + "model.language_model.layers.{layer_number}.self_attn.v_proj.weight", + ], + "decoder.layers.{layer_number}.self_attention.linear_proj.weight": [ + "model.language_model.layers.{layer_number}.self_attn.o_proj.weight", + ], + "decoder.layers.{layer_number}.self_attention.linear_qkv.layer_norm_weight": [ + "model.language_model.layers.{layer_number}.input_layernorm.weight", + ], + "decoder.layers.{layer_number}.self_attention.q_layernorm.weight": [ + "model.language_model.layers.{layer_number}.self_attn.q_norm.weight", + ], + "decoder.layers.{layer_number}.self_attention.k_layernorm.weight": [ + "model.language_model.layers.{layer_number}.self_attn.k_norm.weight", + ], + } + + # Dense MLP entries. For the 31B dense variant these map the single `.mlp` + # submodule directly. For the 26B-A4B MoE variant `.mlp` is the MoE block + # and the dense feed-forward lives at `.dense_mlp` — we map both so state- + # dict round-trips work regardless of variant. + _MLP_MAPPING = { + # 31B dense variant: `.mlp` is the dense MLP. + "decoder.layers.{layer_number}.mlp.linear_fc1.weight": [ + "model.language_model.layers.{layer_number}.mlp.gate_proj.weight", + "model.language_model.layers.{layer_number}.mlp.up_proj.weight", + ], + "decoder.layers.{layer_number}.mlp.linear_fc2.weight": [ + "model.language_model.layers.{layer_number}.mlp.down_proj.weight", + ], + "decoder.layers.{layer_number}.mlp.linear_fc1.layer_norm_weight": [ + "model.language_model.layers.{layer_number}.pre_feedforward_layernorm.weight", + ], + "decoder.layers.{layer_number}.pre_mlp_layernorm.weight": [ + "model.language_model.layers.{layer_number}.pre_feedforward_layernorm.weight", + ], + # 26B-A4B MoE variant: `.dense_mlp` is the parallel dense feed-forward. 
+ "decoder.layers.{layer_number}.dense_mlp.linear_fc1.weight": [ + "model.language_model.layers.{layer_number}.mlp.gate_proj.weight", + "model.language_model.layers.{layer_number}.mlp.up_proj.weight", + ], + "decoder.layers.{layer_number}.dense_mlp.linear_fc2.weight": [ + "model.language_model.layers.{layer_number}.mlp.down_proj.weight", + ], + "decoder.layers.{layer_number}.dense_mlp.linear_fc1.layer_norm_weight": [ + "model.language_model.layers.{layer_number}.pre_feedforward_layernorm.weight", + ], + # MoE router weights (live under `.mlp.router.*` since self.mlp is the + # Gemma4MoELayer in the MoE variant). + "decoder.layers.{layer_number}.mlp.router.proj.weight": [ + "model.language_model.layers.{layer_number}.router.proj.weight", + ], + "decoder.layers.{layer_number}.mlp.router.scale": [ + "model.language_model.layers.{layer_number}.router.scale", + ], + "decoder.layers.{layer_number}.mlp.router.per_expert_scale": [ + "model.language_model.layers.{layer_number}.router.per_expert_scale", + ], + # pre_feedforward_layernorm_2 now owned by Gemma4MoELayer so the Megatron + # path is `.mlp.pre_feedforward_layernorm_2.weight`. HF still expects it + # at the decoder-layer level. See slime_plugins/models/gemma4.py — this + # ownership change aligns slime's router semantic with HF (router sees + # un-normed residual, experts see pre_ff_norm_2(residual)). + "decoder.layers.{layer_number}.mlp.pre_feedforward_layernorm_2.weight": [ + "model.language_model.layers.{layer_number}.pre_feedforward_layernorm_2.weight", + ], + } + + _OTHER_MAPPING = { + "decoder.layers.{layer_number}.post_attention_layernorm.weight": [ + "model.language_model.layers.{layer_number}.post_attention_layernorm.weight", + ], + "decoder.layers.{layer_number}.post_feedforward_layernorm.weight": [ + "model.language_model.layers.{layer_number}.post_feedforward_layernorm.weight", + ], + "decoder.layers.{layer_number}.layer_scalar": [ + "model.language_model.layers.{layer_number}.layer_scalar", + ], + # MoE variant extra layernorms that wrap the dense + MoE paths before + # summing. `pre_feedforward_layernorm_2` moved to + # `.mlp.pre_feedforward_layernorm_2.weight` (see `_MLP_MAPPING` above); + # `post_feedforward_layernorm_1/_2` still live on the layer directly. + "decoder.layers.{layer_number}.post_feedforward_layernorm_2.weight": [ + "model.language_model.layers.{layer_number}.post_feedforward_layernorm_2.weight", + ], + "decoder.layers.{layer_number}.post_feedforward_layernorm_1.weight": [ + "model.language_model.layers.{layer_number}.post_feedforward_layernorm_1.weight", + ], + } + + # Matches per-expert linear weights emitted by TEGroupedLinear: + # decoder.layers..mlp.experts.linear_fc1.weight + # decoder.layers..mlp.experts.linear_fc2.weight + # where is the layer number and is the GLOBAL expert index + # (after mbridge base's `_weight_name_mapping_mcore_local_to_global` + # has remapped local→global across EP ranks — its built-in logic + # handles the `.mlp.experts.linear_fc` pattern automatically). 
+ _RE_MOE_EXPERT = re.compile( + r"^decoder\.layers\.(\d+)\.mlp\.experts\.linear_fc([12])\.weight(\d+)$" + ) + + _DIRECT_MAPPING = { + "embedding.word_embeddings.weight": "model.language_model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.language_model.norm.weight", + "output_layer.weight": "model.language_model.embed_tokens.weight", # tied embeddings + } + + _BUFFER_NAMES = [ + "model.language_model.layers.{layer_number}.layer_scalar", + ] + + _GLOBAL_ATTN_LAYERS = None # derived from HF config in __init__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + hf_text = self.hf_config.text_config if hasattr(self.hf_config, "text_config") else self.hf_config + layer_types = getattr(hf_text, "layer_types", []) + self._GLOBAL_ATTN_LAYERS = {i for i, t in enumerate(layer_types) if t == "full_attention"} + + def _weight_name_mapping_attention(self, name: str) -> list[str]: + split_name = name.split(".") + layer_number = int(split_name[2]) + split_name[2] = "{layer_number}" + key = ".".join(split_name) + + # For global layers with K=V, linear_qkv maps to only [q_proj, k_proj] + if key == "decoder.layers.{layer_number}.self_attention.linear_qkv.weight": + if layer_number in self._GLOBAL_ATTN_LAYERS: + return [ + f"model.language_model.layers.{layer_number}.self_attn.q_proj.weight", + f"model.language_model.layers.{layer_number}.self_attn.k_proj.weight", + ] + + return [x.format(layer_number=layer_number) for x in self._ATTENTION_MAPPING[key]] + + def _weight_name_mapping_mcore_local_to_global( + self, model, consider_ep: bool = True + ): + """Restore the GPT-style local→global mapping for text-only Gemma4. + + Gemma3Bridge (our base class) assumes a VLM structure where + ``model.language_model.decoder.layers`` exists, and only applies the + PP layer-offset remap when that attribute is present. Our Gemma4 + model provider builds a plain ``GPTModel`` (text-only) with + ``model.decoder.layers``, so the Gemma3 check fails silently and all + PP ranks end up mapping their local layer index i → global index i — + which means every PP rank loads HF layers ``0..N/PP-1`` into its + local slots. The result is that, post-conversion, the torch_dist + checkpoint has layer weights cyclically duplicated with period + (num_layers / pp_size). + + We override to delegate to ``Bridge._weight_name_mapping_mcore_local_to_global`` + from the top-level mbridge base class, which walks ``model.decoder.layers`` + directly — matching our GPT-style layout. + """ + from mbridge.core.bridge import Bridge + return Bridge._weight_name_mapping_mcore_local_to_global( + self, model, consider_ep=consider_ep + ) + + def _weight_name_mapping_mlp(self, name: str) -> list[str]: + # Per-expert MoE weight: Megatron names the per-expert tensors + # `mlp.experts.linear_fc{1,2}.weight{E}`. HF stores the 3D stacked + # tensors `experts.gate_up_proj` / `experts.down_proj`; we slice the + # expert row out in `_weight_to_mcore_format`. 
+ m = self._RE_MOE_EXPERT.match(name) + if m: + layer_number, fc = m.group(1), m.group(2) + hf_tensor = "gate_up_proj" if fc == "1" else "down_proj" + return [ + f"model.language_model.layers.{layer_number}.experts.{hf_tensor}", + ] + + split_name = name.split(".") + layer_number = split_name[2] + split_name[2] = "{layer_number}" + key = ".".join(split_name) + return [x.format(layer_number=layer_number) for x in self._MLP_MAPPING[key]] + + def _weight_name_mapping_other(self, name: str) -> list[str]: + split_name = name.split(".") + layer_number = split_name[2] + split_name[2] = "{layer_number}" + key = ".".join(split_name) + return [x.format(layer_number=layer_number) for x in self._OTHER_MAPPING[key]] + + def _weight_to_mcore_format(self, mcore_weights_name, hf_weights): + # Per-expert MoE weight: slice the global 3D HF tensor down to one + # expert row. The expert index is encoded in the mcore name by + # `_weight_name_mapping_mcore_local_to_global`, which rewrites local + # weight{j} → weight{global_expert_idx}. + m = self._RE_MOE_EXPERT.match(mcore_weights_name) + if m: + expert_idx = int(m.group(3)) + assert len(hf_weights) == 1, ( + f"expected exactly one HF tensor for expert weight, got {len(hf_weights)}" + ) + # HF shape: [num_experts, out_dim, in_dim]. Slice to [out_dim, in_dim]. + return hf_weights[0][expert_idx].contiguous() + + if ( + "self_attention.linear_qkv." in mcore_weights_name + and "layer_norm" not in mcore_weights_name + ): + m = re.search(r"layers\.(\d+)\.", mcore_weights_name) + layer_num = int(m.group(1)) if m else -1 + is_global = layer_num in self._GLOBAL_ATTN_LAYERS + + hf_text = self.hf_config.text_config if hasattr(self.hf_config, "text_config") else self.hf_config + num_attention_heads = hf_text.num_attention_heads + if is_global: + head_dim = hf_text.global_head_dim + num_kv_heads = hf_text.num_global_key_value_heads + else: + head_dim = hf_text.head_dim + num_kv_heads = hf_text.num_key_value_heads + + # For K=V global layers the HF checkpoint ships `[q, k]` (no + # v_proj); reconstruct V by duplicating K so the Mcore linear_qkv + # weight has the standard `[q, k, v]` layout with v_proj == k_proj. 
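+            # Shape walk-through with the 26B-A4B global-layer numbers
+            # (16 q heads, 4 global kv heads, head_dim 512): group_dim =
+            # 512 * 16 // 4 = 2048, q rows = 4 * 2048 = 8192, and
+            # k rows = v rows = 4 * 512 = 2048.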
+ if len(hf_weights) == 2: + assert is_global and getattr(hf_text, "attention_k_eq_v", True), ( + f"layer {layer_num}: got 2 HF weights ([q, k]) but this is " + f"not a K=V global layer (is_global={is_global}, " + f"attention_k_eq_v=" + f"{getattr(hf_text, 'attention_k_eq_v', None)})" + ) + q, k = hf_weights + hf_weights = [q, k, k.clone()] + + q, k, v = hf_weights + group_dim = head_dim * num_attention_heads // num_kv_heads + assert q.shape[0] == num_kv_heads * group_dim, ( + f"layer {layer_num}: q_proj rows ({q.shape[0]}) must equal " + f"num_kv_heads ({num_kv_heads}) * group_dim ({group_dim}); " + f"check head_dim/num_attention_heads/num_kv_heads consistency" + ) + assert k.shape[0] == num_kv_heads * head_dim, ( + f"layer {layer_num}: k_proj rows ({k.shape[0]}) must equal " + f"num_kv_heads ({num_kv_heads}) * head_dim ({head_dim})" + ) + assert v.shape[0] == num_kv_heads * head_dim, ( + f"layer {layer_num}: v_proj rows ({v.shape[0]}) must equal " + f"num_kv_heads ({num_kv_heads}) * head_dim ({head_dim})" + ) + q = q.view(num_kv_heads, group_dim, -1) + k = k.view(num_kv_heads, head_dim, -1) + v = v.view(num_kv_heads, head_dim, -1) + return torch.cat([q, k, v], dim=1).view(-1, hf_text.hidden_size).contiguous() + + if "linear_fc1.weight" in mcore_weights_name: + assert len(hf_weights) == 2, ( + f"MLP linear_fc1.weight expects [gate_proj, up_proj] from HF " + f"(2 tensors); got {len(hf_weights)}" + ) + gate, up = hf_weights + return torch.cat([gate, up], dim=0) + + # Generic 1:1 passthrough for everything else (layernorm weights, + # single-tensor projections). Placed after the MLP/QKV-specific + # branches so a malformed 1-tensor input for linear_fc1 / linear_qkv + # hits its assertion instead of being silently returned rotated. + if len(hf_weights) == 1: + return hf_weights[0] + + raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}") + + def _build_config(self): + text_config_key = "text_config" if hasattr(self.hf_config, "text_config") else None + hf_text = self.hf_config.text_config if text_config_key else self.hf_config + + return self._build_base_config( + text_config_key=text_config_key, + use_cpu_initialization=False, + add_qkv_bias=False, + qk_layernorm=True, + layernorm_zero_centered_gamma=False, + normalization="RMSNorm", + persist_layer_norm=True, + activation_func=_gelu_tanh, + bias_activation_fusion=False, + bias_dropout_fusion=True, + rope_local_base_freq=_rope_local_base_freq(hf_text), + ) diff --git a/slime_plugins/models/gemma4.py b/slime_plugins/models/gemma4.py new file mode 100644 index 0000000000..2a36601b97 --- /dev/null +++ b/slime_plugins/models/gemma4.py @@ -0,0 +1,1219 @@ +"""Native Megatron Gemma4 transformer layer and config. + +Extends the Gemma3 implementation from mbridge with Gemma4-specific features: +- Heterogeneous attention: global layers use head_dim=512, num_kv_heads=4; + sliding layers use head_dim=256, num_kv_heads=16. +- attention_k_eq_v: global layers reuse K output as V (no v_proj). +- v_norm: RMSNorm without learnable scale applied to V states. +- layer_scalar: buffer multiplied after residual (not learned). +- final_logit_softcapping: applied to output logits in the model wrapper. +- MoE block (26B-A4B): Gemma4's custom router (with per-expert scale) plugged + into Megatron's MoE infrastructure for proper expert-parallel sharding. + The router is still custom (see Gemma4Router); dispatching + grouped-GEMM + come from Megatron's MoELayer + TEGroupedMLP. 
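+
+Layer pattern example (26B-A4B script: 30 layers, every 6th layer full
+attention): with sliding_window_pattern=6, the 1-indexed global layers are
+6, 12, 18, 24 and 30; all other layers slide with window 1024.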
+""" + +import functools +import logging +from dataclasses import dataclass, replace as dc_replace + +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.moe.moe_layer import MoELayer, BaseMoELayer +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + HAVE_TE = True +except ImportError: + HAVE_TE = False + +from mbridge.models.gemma3.transformer_config import Gemma3TransformerConfig + + +# Gemma uses GeGLU, not SwiGLU. +_gelu_tanh = functools.partial(F.gelu, approximate="tanh") + + +@dataclass +class Gemma4TransformerConfig(Gemma3TransformerConfig): + """Gemma4-specific config extending Gemma3.""" + # Heterogeneous attention: global layers use different head_dim and num_kv_heads + global_kv_channels: int = 512 + global_num_query_groups: int = 4 + global_partial_rotary_factor: float = 0.25 # fraction of global head_dim that gets RoPE + attention_k_eq_v: bool = True # global layers: V = K (no v_proj) + enable_moe_block: bool = False # 26B-A4B MoE variant + + +class VNorm(nn.Module): + """RMSNorm without learnable scale, matching Gemma4's v_norm.""" + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + x = x.float() + return (x * torch.pow(x.pow(2).mean(-1, keepdim=True) + self.eps, -0.5)).to(dtype) + + +@dataclass +class Gemma4TransformerLayerSubmodules(TransformerLayerSubmodules): + post_attention_layernorm: ModuleSpec | type = IdentityOp + post_feedforward_layernorm: ModuleSpec | type = IdentityOp + # For MoE-enabled variants (26B-A4B), the primary `mlp` submodule is swapped + # to a Gemma4MoELayer and the original dense MLP moves to `dense_mlp`. This + # keeps the `.mlp.experts.linear_fc...` naming that mbridge's EP auto-handling + # expects while preserving Gemma4's dense+MoE-in-parallel structure. + dense_mlp: ModuleSpec | type = IdentityOp + + +class Gemma4Router(nn.Module): + """Gemma4 MoE router. + + The router equation (mirroring HF ``Gemma4TextTopkRouter``) is: + + h_norm = RMSNorm_no_scale(h) # VNorm: no learnable scale + h_scaled = h_norm * scale / sqrt(H) # learnable per-hidden scale + logits = proj(h_scaled) # [T, E] + probs = softmax(logits, dim=-1) + top_w, top_i = topk(probs, k=top_k) + top_w = top_w / top_w.sum(dim=-1, keepdim=True) # renormalize + top_w = top_w * per_expert_scale[top_i] # per-expert scale + + The renormalise-then-scale order is load-bearing and must match HF: it + produces ``top_w.sum() == per_expert_scale.mean_over_selected`` rather + than a renormalised-back-to-1 distribution. Reversing the order (scale + first, then renormalise) would cancel ``per_expert_scale``. + ``test_router_matches_hf_reference_equation`` guards this. 
+ """ + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_moe_experts + self.top_k = config.moe_router_topk + self.scalar_root_size = self.hidden_size ** -0.5 + self.norm = VNorm(self.hidden_size, eps=config.layernorm_epsilon) + self.proj = nn.Linear(self.hidden_size, self.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(self.num_experts)) + + def forward(self, hidden_states): + # hidden_states: [tokens, hidden_size] + h = self.norm(hidden_states) + h = h * self.scale * self.scalar_root_size + logits = self.proj(h) + probs = torch.softmax(logits, dim=-1) + top_k_weights, top_k_index = torch.topk(probs, k=self.top_k, dim=-1) + # Order matters: renormalize FIRST, then apply per-expert scale. See + # class docstring. + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + top_k_weights = top_k_weights * self.per_expert_scale[top_k_index] + return top_k_weights, top_k_index + + def set_layer_number(self, layer_number): + # Megatron's ``MoELayer.set_layer_number`` delegates to + # ``self.router.set_layer_number``. ``TopKRouter`` uses it for + # aux-loss scoping / logging; Gemma4Router has no aux loss and + # doesn't log, so this is a no-op — but the method MUST exist or + # ``MoELayer.set_layer_number`` raises AttributeError on first call. + pass + + +class Gemma4MoELayer(MoELayer): + """Gemma4 MoE block: Megatron's MoELayer with Gemma4's custom router. + + Megatron's MoELayer hardcodes its own ``TopKRouter`` which uses a + softmax-with-expert-bias scheme. Gemma4 has its own router semantics + (no-scale RMSNorm → learnable per-hidden scale → proj → softmax → topk → + per-expert scale multiplier). We reuse all of Megatron's infrastructure + for dispatching (alltoall), expert parallelism, and grouped-GEMM expert + computation — but swap in our ``Gemma4Router`` and convert its compact + (top_k_weights [T, K], top_k_index [T, K]) output into Megatron's + expected (probs [T, E], routing_map [T, E]) format inside ``route()``. + """ + + def __init__(self, config, submodules=None, layer_number=None, pg_collection=None): + # Bypass MoELayer.__init__ so we can avoid building Megatron's TopKRouter, + # then run the rest of MoELayer's setup ourselves. The parts we need: + # - self.ep_group / num_local_experts / local_expert_indices (from BaseMoELayer) + # - token_dispatcher, experts (alltoall path, GroupedMLP) + # Anything shared-expert related is disabled: Gemma4 has no shared experts in + # this sense — its "dense MLP" lives outside the MoE block in the parent layer. + # + # Fall back to Megatron's global parallel_state when pg_collection isn't + # explicitly passed. TransformerLayer only forwards pg_collection when + # submodules.mlp.module is *exactly* one of + # (MoELayer, GroupedMLP, TEGroupedMLP, SequentialMLP) — an identity check + # via `in`, so Gemma4MoELayer (a MoELayer subclass) slips through and + # receives None. BaseMoELayer.__init__ then crashes on `pg_collection.ep`. + # Same fallback MoELayer.__init__ uses when invoked directly. + # TODO(gemma4): remove this fallback and the matching + # _Gemma4MoELayerWarningFilter when Megatron widens the MLP-type check + # to use issubclass() (search TransformerLayer.__init__ for + # "Unknown MLP type"); both the fallback and the filter will become + # dead code at that point. 
+ if pg_collection is None: + from megatron.core.transformer.moe.moe_utils import get_default_pg_collection + pg_collection = get_default_pg_collection() + BaseMoELayer.__init__( + self, config=config, layer_number=layer_number, pg_collection=pg_collection + ) + # Disable Megatron-checkpoint paths that don't apply here. + self.moe_layer_recompute = False + self.shared_experts_recompute = False + self.submodules = submodules + + # --- Router: Gemma4's custom router, not TopKRouter. --- + self.router = Gemma4Router(config) + + # --- Token dispatcher (identical to MoELayer.__init__). --- + from megatron.core.transformer.moe.token_dispatcher import ( + MoEAllGatherTokenDispatcher, + MoEAlltoAllTokenDispatcher, + MoEFlexTokenDispatcher, + ) + if config.moe_token_dispatcher_type == "allgather": + self.token_dispatcher = MoEAllGatherTokenDispatcher( + self.num_local_experts, self.local_expert_indices, + config=self.config, pg_collection=pg_collection, + ) + elif config.moe_token_dispatcher_type == "alltoall": + self.token_dispatcher = MoEAlltoAllTokenDispatcher( + self.num_local_experts, self.local_expert_indices, + config=self.config, pg_collection=pg_collection, + ) + elif config.moe_token_dispatcher_type == "flex": + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, self.local_expert_indices, + config=self.config, pg_collection=pg_collection, + ) + else: + raise ValueError(f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}") + + # --- Experts: Megatron's GroupedMLP / TEGroupedMLP. --- + self.experts = build_module( + self.submodules.experts, + self.num_local_experts, + self.config, + pg_collection=pg_collection, + ) + + # Gemma4 doesn't use shared experts in Megatron's sense. + self.shared_experts = None + + # cudagraph tensor store (required by MoELayer.forward decorators) + from megatron.core.transformer.moe.moe_utils import MoECudaGraphTensorStore + self.cudagraph_tensor_store = MoECudaGraphTensorStore() + + # pre_feedforward_layernorm_2: applied to experts' input ONLY (router + # input stays un-normed). Matches HF Gemma4TextDecoderLayer: + # hidden_states_flat = residual # router input (un-normed) + # hidden_states_2 = pre_feedforward_layernorm_2(hidden_states_flat) + # hidden_states_2 = experts(hidden_states_2, top_k_index, top_k_weights) + self.pre_feedforward_layernorm_2 = TENorm( + config=config, hidden_size=config.hidden_size, eps=config.layernorm_epsilon, + ) + + def route(self, hidden_states: torch.Tensor): + """Call ``Gemma4Router`` and pack its output into Megatron's + ``(probs, routing_map)`` format. + + ``Gemma4Router`` emits compact top-k tensors: + top_k_weights: [T, K] — routing weights (already scaled by per_expert_scale) + top_k_index: [T, K] — which experts each token routes to + Megatron's dispatcher wants: + probs: [T, E] — weight per (token, expert), 0 where not routed + routing_map: [T, E] — boolean mask + """ + flat = hidden_states.reshape(-1, hidden_states.shape[-1]) + top_k_weights, top_k_index = self.router(flat) + + num_tokens = flat.shape[0] + num_experts = self.config.num_moe_experts + probs = torch.zeros( + num_tokens, num_experts, + dtype=top_k_weights.dtype, device=top_k_weights.device, + ) + probs.scatter_(1, top_k_index, top_k_weights) + routing_map = probs != 0 + return probs, routing_map + + def forward( + self, hidden_states: torch.Tensor, router_input: torch.Tensor | None = None, + ): + """Gemma4 MoE forward with split router / experts inputs. 
+ + HF's ``Gemma4TextDecoderLayer`` routes based on the *un-normed* residual + but feeds the experts the *pre-ff-norm-2'd* residual: + + hidden_states_flat = residual # un-normed + _, tk_w, tk_i = self.router(hidden_states_flat) + experts_input = self.pre_feedforward_layernorm_2(hidden_states_flat) + output = self.experts(experts_input, tk_i, tk_w) + + We take the un-normed residual in ``hidden_states`` and apply + ``pre_feedforward_layernorm_2`` internally to obtain the experts + input. The router path uses the un-normed residual directly. Callers + may pass a different ``router_input`` for tests or ablations; when + ``router_input is None`` (the normal case) the router sees the same + un-normed residual the layer was called with. + + We inline the Megatron parent's ``forward`` body here — rather than + calling ``super().forward`` with a side-channel stash — so the + router input is passed explicitly end-to-end and the code is safe + under activation checkpointing / recomputation. + """ + if self.training and self.attn_tp_group.size() > 1 and not self.config.sequence_parallel: + raise ValueError( + "During training, performance may degrade if MoE and tensor " + "parallelism are enabled without also enabling sequence parallelism." + ) + + router_in = router_input if router_input is not None else hidden_states + experts_in = self.pre_feedforward_layernorm_2(hidden_states) + + def custom_forward(experts_in, router_in): + # Gemma4 has no shared experts; shared_experts_compute returns None. + shared_expert_output = self.shared_experts_compute(experts_in) + probs, routing_map = self.route(router_in) + experts_in2, probs, residual = self.preprocess(experts_in, probs, routing_map) + dispatched_input, probs = self.dispatch(experts_in2, probs) + output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual) + output = self.combine(output, shared_expert_output) + return output, mlp_bias + + # moe_layer_recompute is forced to False in __init__; call directly. + return custom_forward(experts_in, router_in) + + +class Gemma4TransformerLayer(TransformerLayer): + """Gemma4 transformer layer with heterogeneous attention and layer_scalar.""" + + def __init__( + self, + config: Gemma4TransformerConfig, + submodules: Gemma4TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + **kwargs, + ): + from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + global_layer_number = layer_number + get_transformer_layer_offset(config) + # Megatron passes `layer_number` as 1-indexed (default 1), so in 0-indexed + # HF space a global layer is `(i+1) % pattern == 0` → `i % pattern == pattern-1`. + # Equivalently: `is_sliding` when `global_layer_number % pattern != 0`. + self.is_sliding = bool(global_layer_number % config.sliding_window_pattern) + self._is_global = not self.is_sliding + + # Global layers have different head_dim (kv_channels) and num_kv_heads + # (num_query_groups). Build the layer against a *cloned* config with + # those overrides so we never mutate the shared transformer config. + # Mutation would be reentrant-unsafe under concurrent layer + # construction and leak global-layer shapes into sibling sliding + # layers if an exception were raised during super().__init__. 
+ layer_config = ( + dc_replace( + config, + kv_channels=config.global_kv_channels, + num_query_groups=config.global_num_query_groups, + ) + if self._is_global + else config + ) + super().__init__( + config=layer_config, submodules=submodules, + layer_number=layer_number, hidden_dropout=hidden_dropout, + **kwargs, + ) + + # Tell the attention module whether this is a global layer + self.self_attention._is_global = self._is_global + + # Replace TE core attention with PyTorch SDPA for all layers. + # Global layers require this because head_dim=512 exceeds flash attention's limit (256). + # Local layers also use SDPA for consistency. + self.self_attention.core_attention = SDPACoreAttention( + config=config, + layer_number=self.layer_number, + attn_mask_type=AttnMaskType.causal, + softmax_scale=config.softmax_scale, + ) + self.self_attention.core_attention._is_sliding = self.is_sliding + + # Post-attention and post-feedforward layernorms (Gemma-specific) + self.post_attention_layernorm = build_module( + submodules.post_attention_layernorm, + config=self.config, hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + self.post_feedforward_layernorm = build_module( + submodules.post_feedforward_layernorm, + config=self.config, hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + # Layer scalar (buffer, not learned). Kept in fp32 intentionally — + # HF stores this scalar in fp32 and relies on the implicit upcast of + # ``bf16_hidden * fp32_scalar`` at multiply time (see HF Gemma4 + # ``Gemma4TextDecoderLayer.__init__`` at modeling_gemma4.py:1331). + # Don't switch to ``dtype=self.config.params_dtype``; that would + # silently change the arithmetic. + self.register_buffer("layer_scalar", torch.ones(1)) + + # MoE block (26B-A4B): super().__init__ already built self.mlp from the + # layer spec, which when enable_moe_block=True is a Gemma4MoELayer (not + # a dense MLP). We also build a parallel `dense_mlp` for Gemma4's + # dense + MoE combined-FFN pattern. The two outputs are summed in + # forward(). + self.enable_moe_block = getattr(config, 'enable_moe_block', False) + if self.enable_moe_block: + # Parallel dense MLP branch (sibling to self.mlp, which is the MoE). + self.dense_mlp = build_module( + submodules.dense_mlp, + config=config, + ) + # Gemma4 pre/post layernorms that wrap the two FFN paths. + self.post_feedforward_layernorm_1 = TENorm( + config=config, hidden_size=config.hidden_size, eps=config.layernorm_epsilon, + ) + # pre_feedforward_layernorm_2 now lives INSIDE Gemma4MoELayer + # (matching HF Gemma4TextDecoderLayer semantics: router sees un-normed + # residual, experts see pre_feedforward_layernorm_2(residual)). This + # attribute is kept on the MoE block so mbridge/state-dict paths + # don't change. + self.post_feedforward_layernorm_2 = TENorm( + config=config, hidden_size=config.hidden_size, eps=config.layernorm_epsilon, + ) + + def _forward_dense_ffn(self, pre_mlp_ln): + """Run the dense MLP. ``self.mlp`` is the dense MLP directly for the + 31B variant.""" + out, bias = self.mlp(pre_mlp_ln) + return out + bias if bias is not None else out + + def _forward_moe_ffn(self, residual, pre_mlp_ln): + """Run dense + MoE in parallel and sum (26B-A4B variant). 
+ + Mirrors HF ``Gemma4TextDecoderLayer.forward`` (transformers + modeling_gemma4.py:1376-1391): dense branch goes through + ``post_feedforward_layernorm_1``, MoE branch through + ``post_feedforward_layernorm_2``, the two are summed, and the outer + ``Gemma4TransformerLayer.forward`` applies ``post_feedforward_layernorm`` + to the sum — 3 post-FFN LNs total for MoE layers is correct. + + HF routes on the un-normed residual but feeds experts the + ``pre_feedforward_layernorm_2``'d residual; Gemma4MoELayer applies + that norm internally, so we pass the un-normed residual directly. + """ + dense_out, dense_bias = self.dense_mlp(pre_mlp_ln) + if dense_bias is not None: + dense_out = dense_out + dense_bias + mlp_output = self.post_feedforward_layernorm_1(dense_out) + + moe_output, _ = self.mlp(residual) + moe_output = self.post_feedforward_layernorm_2(moe_output) + + return mlp_output + moe_output + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + inference_context=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + **kwargs, + ): + # Select per-layer rotary embeddings and attention mask + # DualRotaryEmbedding returns concatenated [seq, 1, global_dim + local_dim] tensor. + # Split and select based on layer type. + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + global_dim = getattr(self.config, 'dual_rope_global_dim', 0) + if global_dim > 0 and rotary_pos_emb.shape[-1] > global_dim: + if self.is_sliding: + rotary_pos_emb = rotary_pos_emb[..., global_dim:] # local part + else: + rotary_pos_emb = rotary_pos_emb[..., :global_dim] # global part + elif isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb[1] if self.is_sliding else rotary_pos_emb[0] + if isinstance(attention_mask, tuple): + attention_mask = attention_mask[1] if self.is_sliding else attention_mask[0] + + # Global layers use partial RoPE (25% of head_dim=512 = 128 dims) + # Local layers use full RoPE (100% of head_dim=256 = 256 dims) + # With DualRotaryEmbedding, global RoPE is full-size (512 dims) with zero-padded + # non-rotated dims, so no truncation needed. + # With single RoPE (local only, 256 dims), truncate for global layers. 
+ if not self.is_sliding and rotary_pos_emb is not None: + global_rope_dim = int(self.config.global_kv_channels * self.config.global_partial_rotary_factor) + if rotary_pos_emb.shape[-1] != self.config.global_kv_channels and rotary_pos_emb.shape[-1] > global_rope_dim: + rotary_pos_emb = rotary_pos_emb[..., :global_rope_dim] + + residual = hidden_states + + extra_kwargs = {} + if inference_context is not None: + extra_kwargs["inference_context"] = inference_context + elif inference_params is not None: + extra_kwargs["inference_params"] = inference_params + + # Input layernorm + input_layernorm_output = self.input_layernorm(hidden_states) + + # Self attention + hidden_states, hidden_states_bias = self.self_attention( + input_layernorm_output, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **extra_kwargs, + ) + + if hidden_states_bias is not None: + hidden_states = hidden_states + hidden_states_bias + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # FFN path: dense-only (31B) vs dense + MoE (26B-A4B). + residual = hidden_states + pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + if self.enable_moe_block: + hidden_states = self._forward_moe_ffn(residual, pre_mlp_layernorm_output) + else: + hidden_states = self._forward_dense_ffn(pre_mlp_layernorm_output) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Layer scalar + hidden_states = hidden_states * self.layer_scalar + + output = make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True, + ) + + if self.config.external_cuda_graph and self.training: + return output + return output, context + + +class SDPACoreAttention(nn.Module): + """Gemma4 core attention. + + Replaces TE's DotProductAttention because: + - Global layers have head_dim=512, which flash-attn 2.x doesn't support. + - Sliding-window layers need an explicit left-window mask (HF behavior). + - Context-parallelism on the global layers needs an all-gather+full-attn + path with a differentiable K/V gather. + + Dispatch at call time (packed / thd shape): + - CP > 1 (any layer) : all-gather K/V, apply causal + optional + sliding-window mask computed from slime zig-zag global indices. + - global + CP == 1 : sub-sequence causal SDPA (no O(T²) mask alloc). + - sliding + CP == 1 : flash_attn_varlen_func with (sw-1, 0) window. + """ + + def __init__(self, config, layer_number, attn_mask_type, attention_type="self", + attention_dropout=None, softmax_scale=None, **kwargs): + super().__init__() + # Megatron's SelfAttention.__init__ passes a few kwargs (e.g. cp_comm_type, + # model_comm_pgs) intended for TE's DotProductAttention. We accept-and-ignore + # by name rather than asserting empty; a strict assert breaks whenever + # Megatron/TE add a new kwarg. If a kwarg shows up here that we *should* + # honor (e.g. a new softmax dtype), it will surface as a behavioral bug + # in parity, which is what the test suite covers. 
+ del kwargs + self.config = config + self.softmax_scale = softmax_scale + self.dropout_p = config.attention_dropout if attention_dropout is None else attention_dropout + self._is_sliding = False # set by Gemma4TransformerLayer + + def _resolve_scale(self, hn: int) -> float: + # `0.0 or fallback` would silently mask a misconfigured scale; be explicit. + return self.softmax_scale if self.softmax_scale is not None else (hn ** -0.5) + + @staticmethod + def _zigzag_global_indices(local_len, cp_rank, cp_size, device): + """Global positions of this rank's local Q tokens under slime's + zig-zag CP layout (matches cp_utils.slice_with_cp). + + Local tokens on rank r occupy two global sub-ranges: + [r*cs, (r+1)*cs) and [(2*cp-r-1)*cs, (2*cp-r)*cs) + where cs = local_len / 2 = seq_len / (2*cp_size). + """ + cs = local_len // 2 + first = torch.arange(cp_rank * cs, (cp_rank + 1) * cs, device=device) + second = torch.arange( + (2 * cp_size - cp_rank - 1) * cs, + (2 * cp_size - cp_rank) * cs, + device=device, + ) + return torch.cat([first, second]) + + def _forward_cp_subseq_mask(self, query, key, value, packed_seq_params, + sliding_window=None): + """CP>1 path for any layer: all-gather K/V, then loop over sub-seqs + and apply a per-sub-seq attention mask built from zig-zag global + positions. Supports causal-only (global layers) and causal + + sliding-window (sliding layers). + + Under slime's CP convention, ``packed_seq_params.cu_seqlens_q`` holds + GLOBAL boundaries: each packed sub-sequence on this rank represents + ``(cu[i+1] - cu[i])`` tokens globally but only ``(cu[i+1] - cu[i]) // + cp_size`` tokens locally (the zig-zag slice of this rank's two + chunks, concatenated as [first, second]). + """ + from megatron.core import parallel_state + from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region + + cp_group = parallel_state.get_context_parallel_group() + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + + t_local = query.shape[0] + np_q, hn = query.shape[1], query.shape[2] + nk = key.shape[1] + scale = self._resolve_scale(hn) + + # Differentiable all-gather along the token dim. forward: AG, + # backward: RS — so K/V grads on non-owning ranks flow back to the + # originating rank. The raw `dist.all_gather_into_tensor` has no + # autograd rule and PyTorch prints a "silently incorrect behavior" + # warning + drops those grads. + k_full = gather_from_sequence_parallel_region(key.contiguous(), group=cp_group) + v_full = gather_from_sequence_parallel_region(value.contiguous(), group=cp_group) + # gather_from_sequence_parallel_region stacks each rank's chunk + # consecutively in rank order. Under zig-zag, each rank's [2*cs] + # local tokens are [chunk_r_first, chunk_r_second]. So the gathered + # tensor layout is [r0_first, r0_second, r1_first, r1_second, ...]. + # We need to un-zig-zag into pure global order so mask indices line + # up. Build a permutation that maps gathered index -> global index. + device = query.device + dtype = query.dtype + cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None + + # Sanity: for each packed sub-seq, the GLOBAL length must be + # divisible by 2*cp_size so chunk_size is integer. With cp_size=1 this + # reduces to even-length, which the CP=1 parity-test harness may + # violate (no zig-zag pre-slicing). Skip the check there; permutation + # is identity under cp_size=1 so odd length is harmless. 
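+        # Worked example (illustrative): cp_size=2, one packed sub-seq of
+        # global length 8 -> cs = 2. Rank 0 locally holds globals [0,1] and
+        # [6,7]; rank 1 holds [2,3] and [4,5]. The gathered K/V is therefore
+        # ordered [0,1,6,7, 2,3,4,5] and must be permuted back to
+        # [0,1,2,3,4,5,6,7] before masking.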
+ if cu_seqlens is not None and cp_size > 1: + expected_t_local = 0 + for s_idx in range(len(cu_seqlens) - 1): + s_len = (cu_seqlens[s_idx + 1] - cu_seqlens[s_idx]).item() + assert s_len % (2 * cp_size) == 0, ( + f"sub-sequence {s_idx} global length ({s_len}) is not " + f"divisible by 2*cp_size ({2 * cp_size}); `slice_with_cp` " + "should pad before packing" + ) + expected_t_local += s_len // cp_size + assert expected_t_local == t_local, ( + f"packed-seq local length mismatch: sum(seq_len // cp_size) = " + f"{expected_t_local}, but query.shape[0] = {t_local}" + ) + + # Un-zig-zag: build a permutation over k_full of length `t_full` so + # that `k_full[perm]` is in pure global order. Each sub-sequence + # contributes its own block of size `s_len` to the permutation. + # + # For each sub-sequence, the gathered block layout is: + # [r=0: first_cs, second_cs, r=1: first_cs, second_cs, ...] + # where first_r is global [r*cs, (r+1)*cs) and + # second_r is global [(2*cp-r-1)*cs, (2*cp-r)*cs). + # We build `perm[global_pos] = gathered_pos` within each sub-seq. + if cu_seqlens is None: + # Single-sequence fallback: treat the full t_full as one sub-seq. + t_full_total = k_full.shape[0] + cu_seqlens_list = [0, t_full_total] + else: + cu_seqlens_list = cu_seqlens.tolist() + + # With cp_size=1 the zigzag degenerates to identity and all-gather is + # a no-op; skip the permutation (and the floor-div that would drop the + # trailing odd token for seq_len_global % 2 == 1). + if cp_size > 1: + perm_parts = [] + for s_idx in range(len(cu_seqlens_list) - 1): + seq_start = cu_seqlens_list[s_idx] + seq_len_global = cu_seqlens_list[s_idx + 1] - seq_start + cs = seq_len_global // (2 * cp_size) + # For this sub-seq, build inverse permutation: for each global + # position g in [0, seq_len_global), find its gathered offset. + # gathered_offset(g) = (rank of g) * 2*cs (offset to start of + # that rank's block within the sub-seq) + local_offset_in_rank(g). + # Rank ownership of global position g in [0, 2*cp*cs): + # r = g // cs (if g < cp*cs -> "first half", r = g//cs, + # local_offset = g - r*cs) + # r = 2*cp-1 - g//cs (if g >= cp*cs -> "second half", + # local_offset = cs + g - (2*cp-r-1)*cs) + g = torch.arange(seq_len_global, device=device) + q_idx = g // cs # chunk index 0..2*cp-1 + # rank for chunk q_idx: first half (q_idx < cp) owned by q_idx, + # second half (q_idx >= cp) owned by 2*cp - 1 - q_idx. 
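+            # Worked check (cp_size=2, cs=1, seq_len_global=4): gathered order
+            # is [g0, g3, g1, g2], so the formula below yields
+            # perm = [0, 2, 3, 1] and k_full[perm] restores [g0, g1, g2, g3].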
+ owner = torch.where(q_idx < cp_size, q_idx, 2 * cp_size - 1 - q_idx) + # local offset within that rank's 2*cs block: + # first half (q_idx < cp): local = g - owner*cs + # second half (q_idx >= cp): local = cs + (g - (2*cp-1-owner)*cs) + local_in_rank = torch.where( + q_idx < cp_size, + g - owner * cs, + cs + (g - (2 * cp_size - 1 - owner) * cs), + ) + gathered_offset = owner * (2 * cs) + local_in_rank + perm_parts.append(gathered_offset + seq_start) # add sub-seq base + perm = torch.cat(perm_parts) + k_full = k_full.index_select(0, perm) + v_full = v_full.index_select(0, perm) + + out = torch.empty(t_local, np_q * hn, dtype=dtype, device=device) + + local_offset = 0 + for s_idx in range(len(cu_seqlens_list) - 1): + seq_start = cu_seqlens_list[s_idx] + seq_len_global = cu_seqlens_list[s_idx + 1] - seq_start + local_len = seq_len_global // cp_size # this sub-seq's local Q count + + q_seq = query[local_offset:local_offset + local_len] + k_seq = k_full[seq_start:seq_start + seq_len_global] + v_seq = v_full[seq_start:seq_start + seq_len_global] + + q4 = q_seq.unsqueeze(0).transpose(1, 2) # [1, np, local_len, hn] + k4 = k_seq.unsqueeze(0).transpose(1, 2) # [1, nk, seq_len, hn] + v4 = v_seq.unsqueeze(0).transpose(1, 2) + + # Global positions of local Q tokens. cp_size=1 degenerates to + # identity; use arange to preserve odd-length seqs (zigzag helper + # floor-divides, dropping the trailing token). + if cp_size > 1: + row_idx = self._zigzag_global_indices(local_len, cp_rank, cp_size, device) + else: + row_idx = torch.arange(local_len, device=device) + col_idx = torch.arange(seq_len_global, device=device) + # Causal: K pos <= Q pos. Sliding: also K pos > Q pos - sw. + forbid_future = col_idx[None, :] > row_idx[:, None] + if sliding_window is not None and sliding_window > 0: + forbid_past = col_idx[None, :] < (row_idx[:, None] - (sliding_window - 1)) + forbid = forbid_future | forbid_past + else: + forbid = forbid_future + mask = torch.where( + forbid, torch.finfo(dtype).min, 0.0, + ).to(dtype=dtype) + + o = F.scaled_dot_product_attention( + q4, k4, v4, attn_mask=mask[None, None, :, :], + dropout_p=self.dropout_p if self.training else 0.0, + scale=scale, enable_gqa=(np_q != nk), + ) + out[local_offset:local_offset + local_len] = o.transpose(1, 2).reshape(local_len, -1) + local_offset += local_len + + return out + + def _forward_thd_flash(self, query, key, value, cu_seqlens): + """Sliding-window or head_dim<=256 path via flash_attn_varlen_func. + + CP==1 only. For CP>1, `_forward_cp_subseq_mask` handles zig-zag. + + Sliding-window layers must pass `window_size=(sliding_window-1, 0)` so + only tokens within `sliding_window` positions back are attended to — + this matches HF's `sliding_window_mask_function`. Global layers and + dense-attention sliding layers use the default full-causal window. 
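+
+        For example, sliding_window=1024 gives window_size=(1023, 0): each
+        query attends to itself plus the 1023 tokens before it within its
+        sub-sequence.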
+ """ + from flash_attn import flash_attn_varlen_func + + window_size = (-1, -1) # full causal when causal=True + if self._is_sliding: + sw = getattr(self.config, "sliding_window", None) + if sw and sw > 0: + window_size = (int(sw) - 1, 0) + + cu = cu_seqlens.to(torch.int32) + max_seqlen = (cu[1:] - cu[:-1]).max().item() + out = flash_attn_varlen_func( + query.contiguous(), key.contiguous(), value.contiguous(), + cu_seqlens_q=cu, cu_seqlens_k=cu, + max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + dropout_p=self.dropout_p if self.training else 0.0, + softmax_scale=self._resolve_scale(query.shape[2]), + causal=True, + window_size=window_size, + ) + return out.reshape(query.shape[0], -1) + + def _forward_thd_sdpa_per_subseq(self, query, key, value, cu_seqlens): + """Per-sub-sequence causal SDPA — used when flash-attn can't handle + head_dim (global layer w/o CP). Avoids materializing a [T, T] mask. + """ + np_q, hn = query.shape[1], query.shape[2] + nk = key.shape[1] + scale = self._resolve_scale(hn) + out = torch.empty(query.shape[0], np_q * hn, dtype=query.dtype, device=query.device) + for i in range(len(cu_seqlens) - 1): + s = cu_seqlens[i].item() + e = cu_seqlens[i + 1].item() + q4 = query[s:e].unsqueeze(0).transpose(1, 2) # [1, np, L, hn] + k4 = key[s:e].unsqueeze(0).transpose(1, 2) + v4 = value[s:e].unsqueeze(0).transpose(1, 2) + o = F.scaled_dot_product_attention( + q4, k4, v4, + dropout_p=self.dropout_p if self.training else 0.0, + scale=scale, is_causal=True, enable_gqa=(np_q != nk), + ) + out[s:e] = o.transpose(1, 2).reshape(e - s, -1) + return out + + def forward(self, query, key, value, attention_mask=None, attn_mask_type=None, + packed_seq_params=None, **kwargs): + cp_size = getattr(self.config, "context_parallel_size", 1) or 1 + is_thd = query.dim() == 3 + + # Parity-test hook: force CP=1 to take the same _forward_cp_subseq_mask + # code path as CP>1 so parity comparisons don't confound kernel choice + # (flash-attn vs SDPA) with CP-correctness. NEVER set in production — + # flash-attn is faster than SDPA-with-mask. + force_cp_path = getattr(self.config, "force_cp_subseq_mask", False) + + if is_thd: + # CP>1 (any layer): all-gather KV, apply zig-zag-aware mask. + # Sliding layers also need sliding_window masking in addition to + # causal — `_forward_cp_subseq_mask` handles both. + if cp_size > 1 or force_cp_path: + sw = None + if self._is_sliding: + sw_cfg = getattr(self.config, "sliding_window", None) + if sw_cfg and sw_cfg > 0: + sw = int(sw_cfg) + return self._forward_cp_subseq_mask( + query, key, value, packed_seq_params, sliding_window=sw, + ) + + # CP==1: use the existing per-layer-type paths. + cu_seqlens = None + if packed_seq_params is not None: + cu_seqlens = packed_seq_params.cu_seqlens_q + + hn = query.shape[2] + if cu_seqlens is not None: + if hn <= 256: + return self._forward_thd_flash(query, key, value, cu_seqlens) + # Global layer, no CP, packed — flash-attn won't take head_dim>256. + return self._forward_thd_sdpa_per_subseq(query, key, value, cu_seqlens) + + # Un-packed thd (rare: eval/smoke test with batch=1). Plain SDPA. + q = query.unsqueeze(0).transpose(1, 2) + k = key.unsqueeze(0).transpose(1, 2) + v = value.unsqueeze(0).transpose(1, 2) + nq, nk = q.shape[1], k.shape[1] + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.dropout_p if self.training else 0.0, + scale=self._resolve_scale(hn), is_causal=True, + enable_gqa=(nq != nk), + ) + return out.transpose(1, 2).reshape(query.shape[0], -1) + + # bshd path: 4D input [seq, batch, np, hn]. 
Used by smoke tests and + # potential eval with --qkv-format bshd. + q = query.permute(1, 2, 0, 3) + k = key.permute(1, 2, 0, 3) + v = value.permute(1, 2, 0, 3) + nq, nk = q.shape[1], k.shape[1] + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.dropout_p if self.training else 0.0, + scale=self._resolve_scale(query.shape[3]), is_causal=True, + enable_gqa=(nq != nk), + ) + return out.permute(2, 0, 1, 3).reshape(out.size(2), out.size(0), -1) + + +class Gemma4SelfAttention(SelfAttention): + """SelfAttention with Gemma4-specific modifications: + - v_norm: RMSNorm without learnable scale applied to value states. + - attention_k_eq_v: on global layers the linear_qkv projection emits + ``[q, k]`` only (no v_proj) and V is derived from K — specifically + ``V = v_norm(raw_k)`` while ``K = k_norm(raw_k)``. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_global = False # set by Gemma4TransformerLayer after construction + self.v_norm = VNorm(self.hidden_size_per_attention_head, eps=self.config.layernorm_epsilon) + + def _split_qkv_global_k_eq_v(self, hidden_states): + """Split linear_qkv output for global K=V layers. + + The Mcore linear_qkv weight for a K=V global layer is built with + ``v_proj_weight == k_proj_weight`` (see Gemma4Bridge + convert_gemma4_to_hf), + so ``linear_qkv(h)`` emits Q/K/V with ``raw_k == raw_v``. Gemma4's + per-head norms then apply as ``key = k_norm(raw_k)`` and + ``value = v_norm(raw_k)`` — *not* ``v_norm(k_norm(raw_k))``. We + reimplement the split here rather than calling the parent so we + don't have to mutate ``self.k_layernorm`` mid-forward. + + Returns (query[sq,b,np,hn], key[sq,b,ng,hn], value[sq,b,ng,hn]). + """ + mixed_qkv, _ = self.linear_qkv(hidden_states) + num_query_heads_per_group = ( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition + ) + new_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + (num_query_heads_per_group + 2) * self.hidden_size_per_attention_head, + ) + mixed_qkv = mixed_qkv.view(*new_shape) + + q_width = num_query_heads_per_group * self.hidden_size_per_attention_head + hn = self.hidden_size_per_attention_head + # _raw_value is bit-identical to raw_key (K=V linear_qkv layout); + # use raw_key directly and skip the extra tensor. 
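+        # Shape sketch (31B global layer: 32 Q heads, 4 KV groups, hn=512):
+        # mixed_qkv is viewed as [sq, b, 4, (8 + 2) * 512] and the split
+        # widths below are [q_width=4096, hn=512, hn=512] along dim 3.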
+ query, raw_key, _raw_value = torch.split(mixed_qkv, [q_width, hn, hn], dim=3) + query = query.reshape(query.size(0), query.size(1), -1, hn) + + if self.q_layernorm is not None: + query = self.q_layernorm(query) + + value = self.v_norm(raw_key) + key = self.k_layernorm(raw_key) if self.k_layernorm is not None else raw_key + return query, key, value + + def get_query_key_value_tensors( + self, hidden_states, key_value_states=None, output_gate=False, split_qkv=True + ): + if self._is_global and self.config.attention_k_eq_v and split_qkv: + if output_gate: + raise NotImplementedError( + "output_gate is not supported together with attention_k_eq_v" + ) + return self._split_qkv_global_k_eq_v(hidden_states) + + result = super().get_query_key_value_tensors( + hidden_states, key_value_states, output_gate=output_gate, split_qkv=split_qkv + ) + if not split_qkv: + return result + + if output_gate: + query, key, value, gate = result + value = self.v_norm(value) + return query, key, value, gate + + query, key, value = result + value = self.v_norm(value) + return query, key, value + + +def _build_moe_submodule_spec(config): + """Build the MoE submodule spec (Gemma4MoELayer + TE GroupedMLP experts).""" + from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend + from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider + + # Reuse Megatron's canonical TE-backed MoE spec factory to get + # TEColumnParallelGroupedLinear / TERowParallelGroupedLinear etc. wired up + # properly for GroupedMLP experts. Then swap the top-level module from + # Megatron's MoELayer to our Gemma4MoELayer, which keeps all that wiring + # but plugs in Gemma4Router. + base_spec = get_moe_module_spec_for_backend( + backend=TESpecProvider(), + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + use_te_activation_func=False, # use plain F.gelu(approximate='tanh') from config.activation_func + ) + return ModuleSpec( + module=Gemma4MoELayer, + submodules=base_spec.submodules, + metainfo=base_spec.metainfo, + ) + + +def get_gemma4_layer_spec_te(config=None) -> ModuleSpec: + """Layer spec for Gemma4 using native Megatron attention with TE. + + If ``config.enable_moe_block`` is set, the main ``mlp`` submodule is a + :class:`Gemma4MoELayer` (so that the state-dict path + ``.mlp.experts.linear_fc*.weight*`` matches mbridge's EP auto-handling), + and the original dense MLP moves to a sibling ``dense_mlp`` submodule that + the layer forward sums with the MoE output. For the 31B dense variant, + ``enable_moe_block=False`` and ``mlp`` stays as the normal Megatron MLP. + """ + # dense_mlp: use a plain (non-fused-layernorm) linear_fc1 so our explicit + # `pre_mlp_layernorm` in the layer forward is the sole norm applied to the + # MLP input. Using TELayerNormColumnParallelLinear here would apply a + # SECOND layernorm inside fc1, resulting in double-normalization and + # ~8× inflated MLP outputs. 
+ dense_mlp_spec = ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ) + if config is not None and getattr(config, "enable_moe_block", False): + mlp_spec = _build_moe_submodule_spec(config) + dense_spec = dense_mlp_spec + else: + mlp_spec = dense_mlp_spec + dense_spec = IdentityOp + + submods = Gemma4TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=Gemma4SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=IdentityOp, + mlp=mlp_spec, + mlp_bda=get_bias_dropout_add, + post_attention_layernorm=TENorm, + post_feedforward_layernorm=TENorm, + dense_mlp=dense_spec, + ) + return ModuleSpec(module=Gemma4TransformerLayer, submodules=submods) + + +@functools.lru_cache(maxsize=4) +def _load_hf_text_config(hf_checkpoint): + """Load HF config and unwrap `text_config` if it's a multimodal wrapper. + + Cached via lru_cache so repeated callers (model provider, mbridge, weight + converter) all share the same parsed object. + """ + from transformers import AutoConfig + cfg = AutoConfig.from_pretrained(hf_checkpoint, trust_remote_code=True) + return cfg.text_config if hasattr(cfg, "text_config") else cfg + + +class _Gemma4MoELayerWarningFilter(logging.Filter): + """Silence the once-per-layer Megatron warning: + 'Unknown MLP type: . Using default kwargs.' + Megatron's TransformerLayer.__init__ recognizes a hardcoded tuple of MLP + classes via `==` (not issubclass), so Gemma4MoELayer (a MoELayer subclass) + falls through to the default-kwargs branch. That branch is correct for us + — Gemma4MoELayer.__init__ fetches its own pg_collection via + get_default_pg_collection — but the warning spams 30 lines per layer at + init and confuses log readers. See gemma4_provider.py install hook. + """ + + def filter(self, record: logging.LogRecord) -> bool: + msg = record.getMessage() + return not ("Unknown MLP type" in msg and "Gemma4MoELayer" in msg) + + +def _install_moe_warning_filter(): + """Silence the per-layer "Unknown MLP type: Gemma4MoELayer" warning. + + Megatron's TransformerLayer compares MLP class identity via ``==``, so + MoELayer subclasses hit the default-kwargs branch and log a warning. + The default-kwargs branch is correct for us (Gemma4MoELayer fetches + pg_collection itself); filter the noise. + """ + tl_logger = logging.getLogger("megatron.core.transformer.transformer_layer") + if getattr(tl_logger, "_gemma4_moe_filter_installed", False): + return + tl_logger.addFilter(_Gemma4MoELayerWarningFilter()) + tl_logger._gemma4_moe_filter_installed = True + + +def _assert_hf_features_supported(hf_text): + """Fail loudly on Gemma4 HF features this plugin doesn't implement.""" + if getattr(hf_text, "hidden_size_per_layer_input", 0): + raise NotImplementedError( + "Gemma4 per-layer input mechanism " + f"(hidden_size_per_layer_input={hf_text.hidden_size_per_layer_input}) " + "is not implemented. See Gemma4TextDecoderLayer.per_layer_input_gate in HF." + ) + if getattr(hf_text, "num_kv_shared_layers", 0): + raise NotImplementedError( + "Gemma4 KV-sharing across the last N layers " + f"(num_kv_shared_layers={hf_text.num_kv_shared_layers}) is not implemented." 
+ ) + if getattr(hf_text, "use_double_wide_mlp", False): + raise NotImplementedError("Gemma4 use_double_wide_mlp is not implemented.") + # Text-only training assumes causal attention; HF's "all" mode disables it. + if getattr(hf_text, "use_bidirectional_attention", "vision") == "all": + raise NotImplementedError( + "Gemma4 use_bidirectional_attention='all' disables causal masking; not supported." + ) + + +def _apply_core_config(config, hf_text): + """Set Gemma4's non-MoE, non-RoPE config fields. + + Mutates ``config`` in place. Promotes its ``__class__`` to + ``Gemma4TransformerConfig`` so the new dataclass fields are reachable + from downstream Megatron code. + """ + # Gemma uses GeGLU (gated gelu-tanh), not SwiGLU. + config.gated_linear_unit = True + config.activation_func = _gelu_tanh + config.bias_activation_fusion = False + + # No MoE-vs-dense layer scheduling: every layer is our Gemma4TransformerLayer + # and the MoE block lives inside its forward. An all-zero list keeps + # transformer_block's non_homogeneous_layers=True branch active (correct for + # 26B's differing global vs sliding head_dim / num_kv_heads). + # Rationale for using moe_layer_freq as the flag: Megatron's + # TransformerBlock.__init__ sets ``non_homogeneous_layers = True`` iff + # ``config.moe_layer_freq is not None``. We only need that flag on — + # the actual dense/MoE dispatch happens inside + # Gemma4TransformerLayer.forward, so the list contents are never + # consulted by TransformerBlock itself. If a future Megatron refactor + # starts reading the list per-layer, we need a Gemma4-specific schedule + # instead. + config.moe_layer_freq = [0] * config.num_layers + + # Mirror Megatron's own misspelling (`hetereogenous_*`) — correcting it + # would silently no-op on Megatron's read path. + # TODO(gemma4): rename to ``heterogeneous_dist_checkpoint`` when Megatron + # fixes the spelling upstream. + config.hetereogenous_dist_checkpoint = True + + config.__class__ = Gemma4TransformerConfig + config.global_kv_channels = hf_text.global_head_dim + config.global_num_query_groups = hf_text.num_global_key_value_heads + config.attention_k_eq_v = getattr(hf_text, "attention_k_eq_v", True) + config.final_logit_softcapping = getattr(hf_text, "final_logit_softcapping", 30.0) + config.sliding_window = hf_text.sliding_window + + # `sliding_window_pattern` isn't in Gemma4 HF configs — infer from + # layer_types (first full_attention layer's 1-indexed position). + layer_types = list(getattr(hf_text, "layer_types", [])) + try: + config.sliding_window_pattern = layer_types.index("full_attention") + 1 + except ValueError: + config.sliding_window_pattern = 6 + + # Q/K norms handle softmax scaling; Megatron's default of 1/sqrt(hn) is wrong. + config.softmax_scale = 1.0 + # Fused RoPE ignores zeroed inv_freq tails; we need unfused for partial-rotary. + config.apply_rope_fusion = False + + +def _apply_moe_config(config, hf_text): + """Set MoE fields if this is a MoE variant (26B-A4B).""" + config.enable_moe_block = getattr(hf_text, "enable_moe_block", False) + if not config.enable_moe_block: + return + + config.num_moe_experts = hf_text.num_experts + config.moe_router_topk = hf_text.top_k_experts + config.moe_ffn_hidden_size = hf_text.moe_intermediate_size + # Megatron MoE infrastructure reads these even though our custom router + # bypasses its scoring logic; defaults mirror a working Qwen3.5-A3B config. 
+    config.moe_token_dispatcher_type = (
+        getattr(config, "moe_token_dispatcher_type", None) or "alltoall"
+    )
+    # Boolean default needs an explicit None check: `or True` would silently
+    # override an explicit `moe_grouped_gemm = False`.
+    if getattr(config, "moe_grouped_gemm", None) is None:
+        config.moe_grouped_gemm = True
+    config.moe_aux_loss_coeff = 0.0  # Gemma4 router has no aux loss
+    config.moe_router_load_balancing_type = (
+        getattr(config, "moe_router_load_balancing_type", None) or "none"
+    )
+    config.moe_router_score_function = (
+        getattr(config, "moe_router_score_function", None) or "softmax"
+    )
+    config.moe_router_topk_scaling_factor = (
+        getattr(config, "moe_router_topk_scaling_factor", None) or 1.0
+    )
+    config.moe_router_pre_softmax = False
+
+
+def get_rope_local_base_freq(hf_text) -> float:
+    """Extract sliding-attention RoPE theta from an HF Gemma4 text config.
+
+    Single source of truth for both the model provider and the mbridge
+    config builder — otherwise the 10000.0 default would drift between
+    call sites.
+    """
+    return (
+        getattr(hf_text, "rope_parameters", {}) or {}
+    ).get("sliding_attention", {}).get("rope_theta", 10000.0)
+
+
+def _apply_rope_config(config, hf_text):
+    rope_params = getattr(hf_text, "rope_parameters", {}) or {}
+    config.rope_local_base_freq = get_rope_local_base_freq(hf_text)
+    config.global_partial_rotary_factor = (
+        rope_params.get("full_attention", {}).get("partial_rotary_factor", 0.25)
+    )
+
+
+def _guard_cp_sliding_window(args, config):
+    """Fail if per-rank CP token cap is smaller than the sliding window.
+
+    Strong signal of a miscounted CP sizing — we'd train on truncated
+    attention windows otherwise.
+    """
+    cp_size = getattr(args, "context_parallel_size", 1) or 1
+    if cp_size <= 1:
+        return
+    max_tokens = getattr(args, "max_tokens_per_gpu", None)
+    if max_tokens is not None and max_tokens < config.sliding_window:
+        raise ValueError(
+            f"context_parallel_size={cp_size} with max_tokens_per_gpu={max_tokens} "
+            f"< sliding_window={config.sliding_window}: per-rank CP chunk cap is "
+            "smaller than the sliding window. Reduce CP or raise max_tokens_per_gpu."
+        )
+
+
+def get_gemma4_spec(args, config, vp_stage):
+    """Return the native Gemma4 layer spec with proper config overrides."""
+    hf_text = _load_hf_text_config(args.hf_checkpoint)
+
+    _install_moe_warning_filter()
+    _assert_hf_features_supported(hf_text)
+    _apply_core_config(config, hf_text)
+    _apply_moe_config(config, hf_text)
+    _apply_rope_config(config, hf_text)
+    _guard_cp_sliding_window(args, config)
+
+    # Build layer spec AFTER MoE config (so the MoE submodule spec attaches
+    # when enable_moe_block is True), then override MLP specs for HF numerics.
+    # The linear_fc1 override applies only to the dense MLP — on the MoE path
+    # the top-level `mlp` is Gemma4MoELayer and its `submodules` is an
+    # MoESubmodules (experts / shared_experts), which has no `linear_fc1`.
+    # The dense_mlp sibling already uses TEColumnParallelLinear (see
+    # `dense_mlp_spec` above), so on MoE layers this block is a no-op by
+    # design.
+ spec = get_gemma4_layer_spec_te(config) + from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider + if not getattr(config, "enable_moe_block", False): + spec.submodules.mlp.submodules.linear_fc1 = TEColumnParallelLinear + spec.submodules.mlp.metainfo = {"fuse_pre_mlp_layernorm": False} + spec.submodules.pre_mlp_layernorm = TESpecProvider().layer_norm() + return spec diff --git a/slime_plugins/models/gemma4_provider.py b/slime_plugins/models/gemma4_provider.py new file mode 100644 index 0000000000..e14a46ce60 --- /dev/null +++ b/slime_plugins/models/gemma4_provider.py @@ -0,0 +1,295 @@ +"""Custom model provider for Gemma4. + +Installs Gemma4-specific behaviors that sit outside the transformer layer: +- embedding scaling (multiply embeddings by sqrt(hidden_size)) +- logit softcapping (`final_logit_softcapping`) +- dual-RoPE (different rope_theta + partial-rotary for global vs sliding layers) +- layer_scalar buffers loaded from the HF checkpoint +""" +import json +import logging +import os + +import torch +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.transformer.spec_utils import import_module +from megatron.training import get_args +from megatron.training.arguments import core_transformer_config_from_args + +from slime_plugins.models.gemma4 import _load_hf_text_config + +logger = logging.getLogger(__name__) + + +def _is_rank_zero() -> bool: + """True on a single-process run, or on rank 0 when distributed.""" + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return True + return torch.distributed.get_rank() == 0 + + +def model_provider(pre_process=True, post_process=True, vp_stage=None): + args = get_args() + config = core_transformer_config_from_args(args) + + transformer_layer_spec = import_module(args.spec) + if callable(transformer_layer_spec): + transformer_layer_spec = transformer_layer_spec(args, config, vp_stage) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + ) + + _install_hooks(model, args, config, pre_process, post_process) + return model + + +class DualRotaryEmbedding(torch.nn.Module): + """Wraps a (global, local) pair of RotaryEmbedding modules and emits a + single concatenated tensor (global part first). ``Gemma4TransformerLayer`` + slices it per-layer based on ``is_sliding``. Concat (not tuple) because + Megatron's ``SelfAttention.forward`` reads a 2-tuple as + ``(self_attn, cross_attn)`` RoPE and would misread our pair. + """ + + def __init__(self, local_rope, global_rope, global_dim: int): + super().__init__() + self.local_rope = local_rope + self.global_rope = global_rope + self.global_dim = global_dim + + def get_rotary_seq_len(self, *args, **kwargs): + # Both ropes share the same sequence-length logic (they only differ in + # theta and partial-rotary); delegate to the local one. 
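+        # (The length is derived from the transformer input / inference
+        # params, not from inv_freq, so both ropes would return the same value.)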
+ return self.local_rope.get_rotary_seq_len(*args, **kwargs) + + def forward(self, seq_len, **kwargs): + global_emb = self.global_rope(seq_len, **kwargs) + local_emb = self.local_rope(seq_len, **kwargs) + return torch.cat([global_emb, local_emb], dim=-1) + + +def _install_hooks(model, args, config, pre_process, post_process): + """Install Gemma4-specific pre/post-process hooks on a built GPTModel. + + We use ``register_forward_hook`` rather than subclassing GPTModel + because: + - Two independent behaviors (embed scale, softcap) on two different + submodules. Subclassing would require overriding + ``GPTModel.forward`` and branching on pp/vp stage. + - The hooks are shape- and dtype-preserving, so they compose cleanly + with PP (only first-stage runs embedding, only last-stage runs + output_layer) — we gate registration on ``pre_process`` / + ``post_process`` accordingly. + - Keeps the diff local to this plugin: we don't need to shadow any + Megatron-maintained class. + """ + hf_text = _load_hf_text_config(args.hf_checkpoint) + hidden_size = config.hidden_size + + inner = model.module if hasattr(model, "module") else model + + # Embedding scaling — HF applies this inside the embedding module. + # See ``Gemma4TextScaledWordEmbedding``: the scale is stored as an fp32 + # tensor and cast to the embedding weight's dtype at forward time, so + # the scale-as-applied depends on the current weight dtype (bf16 during + # training, fp32 during some eval paths). We match that behavior here. + if pre_process and hasattr(inner, "embedding"): + embed_scale = torch.tensor(hidden_size ** 0.5) # fp32 + + def _embed_hook(module, inp, output): + return output * embed_scale.to(output.dtype) + + inner.embedding.register_forward_hook(_embed_hook) + + # Final logit softcapping — HF applies tanh(logits / cap) * cap. + # Some Megatron output_layer variants (parallel_output paths) return + # ``(logits, bias)``; we pass the non-logit tail through unchanged. + softcap = getattr(hf_text, "final_logit_softcapping", None) + if post_process and softcap and hasattr(inner, "output_layer"): + def _softcap_hook(module, inp, output): + if isinstance(output, tuple): + return (torch.tanh(output[0] / softcap) * softcap,) + output[1:] + return torch.tanh(output / softcap) * softcap + inner.output_layer.register_forward_hook(_softcap_hook) + + # Dual RoPE: replace Megatron's single rotary_pos_emb with a wrapper that + # produces (global, local) RoPE side-by-side. Gemma4 uses partial-rotary + # on global layers (implemented here by zeroing the tail of inv_freq). + if hasattr(inner, "rotary_pos_emb"): + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + + rope_params = getattr(hf_text, "rope_parameters", {}) or {} + full = rope_params.get("full_attention", {}) or {} + sliding = rope_params.get("sliding_attention", {}) or {} + global_theta = full.get("rope_theta", 1_000_000.0) + local_theta = sliding.get("rope_theta", 10_000.0) + global_head_dim = hf_text.global_head_dim + global_partial = full.get("partial_rotary_factor", 0.25) + + local_rope = inner.rotary_pos_emb # already built with args.rotary_base + + global_rope = RotaryEmbedding( + kv_channels=global_head_dim, + rotary_percent=1.0, + rotary_base=global_theta, + ) + # HF "proportional" RoPE: first (partial * head_dim // 2) inv_freq + # entries are live, the rest are zero (no rotation on those dims). + # Writing this to the existing buffer keeps device/dtype correct. 
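+        # Numeric check for the 31B defaults: partial=0.25, head_dim=512 ->
+        # rope_angles = 64 live entries (= 128 rotated dims), half = 256,
+        # nope = 192 zeroed entries.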
+ rope_angles = int(global_partial * global_head_dim // 2) + half = global_head_dim // 2 + # Guard the RoPE geometry: 0 means "no rotation" (nonsensical here); + # > half would produce nope<0 and a shape-mismatched copy_. Both + # should fail loudly rather than silently writing garbage. + assert 0 < rope_angles <= half, ( + f"global_partial_rotary_factor={global_partial} with " + f"global_head_dim={global_head_dim} produced rope_angles=" + f"{rope_angles}; must be in (0, {half}]." + ) + inv_freq_live = 1.0 / ( + global_theta ** ( + torch.arange(0, 2 * rope_angles, 2, dtype=torch.float) / global_head_dim + ) + ) + nope = half - rope_angles + inv_freq = torch.cat([inv_freq_live, torch.zeros(nope)]) if nope > 0 else inv_freq_live + assert inv_freq.shape == global_rope.inv_freq.shape, ( + f"inv_freq shape {tuple(inv_freq.shape)} doesn't match " + f"global_rope.inv_freq shape {tuple(global_rope.inv_freq.shape)}; " + "Megatron RotaryEmbedding layout may have changed." + ) + global_rope.inv_freq.copy_(inv_freq.to(global_rope.inv_freq.device)) + + inner.rotary_pos_emb = DualRotaryEmbedding(local_rope, global_rope, global_head_dim) + # Layers split the concatenated tensor by this dim. + config.dual_rope_global_dim = global_head_dim + if _is_rank_zero(): + logger.info( + "DualRotaryEmbedding: local_theta=%s global_theta=%s " + "global_dim=%s rope_angles=%d (nope=%d)", + local_theta, global_theta, global_head_dim, rope_angles, nope, + ) + + # Load layer scalars from the HF checkpoint. These are buffers, not + # parameters, and are applied once per layer after the MoE/MLP block. + if hasattr(inner, "decoder") and args.hf_checkpoint: + _load_layer_scalars(inner, args.hf_checkpoint, config) + + +def _read_layer_scalars_from_safetensors(hf_checkpoint: str) -> dict[int, float] | None: + """Read all ``layer_scalar`` values from the HF safetensors checkpoint. + + Returns ``{global_layer_idx: scalar}`` or ``None`` if the checkpoint has + no safetensors index (older HF layouts) or no layer_scalar weights. Only + called on rank 0 — results are broadcast to the other ranks. + """ + index_path = os.path.join(hf_checkpoint, "model.safetensors.index.json") + if not os.path.exists(index_path): + logger.warning("No safetensors index at %s; skipping layer scalars", index_path) + return None + + from safetensors import safe_open + + with open(index_path) as f: + index = json.load(f) + + scalars: dict[int, float] = {} + for key, filename in index["weight_map"].items(): + if "layer_scalar" not in key: + continue + layer_idx = int(key.split(".layers.")[1].split(".")[0]) + with safe_open(os.path.join(hf_checkpoint, filename), framework="pt", device="cpu") as sf: + scalars[layer_idx] = sf.get_tensor(key).item() + + if not scalars: + logger.warning("No layer_scalar weights found in checkpoint %s", hf_checkpoint) + return None + return scalars + + +def _broadcast_layer_scalars(scalars: dict[int, float] | None) -> dict[int, float] | None: + """Broadcast the rank-0-read ``scalars`` dict to every rank. + + safetensors reads on every rank cause an O(world_size) fan-out of tiny + reads on the shared filesystem; the dict itself is a few kilobytes. If + ``torch.distributed`` isn't initialized (single-process run), we simply + return the input dict. 
+ """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return scalars + obj = [scalars] if torch.distributed.get_rank() == 0 else [None] + torch.distributed.broadcast_object_list(obj, src=0) + return obj[0] + + +def _load_layer_scalars(inner, hf_checkpoint, config): + # Wrong layer_scalars materially change activations vs HF (they're per- + # layer multiplicative gains on the residual stream, not decorative), so + # by default we fail hard if the load breaks. Set + # GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1 to downgrade to a warning and + # train with the default value of 1.0 — only useful for debug runs + # against a checkpoint that genuinely lacks these buffers. + allow_missing = os.environ.get("GEMMA4_ALLOW_MISSING_LAYER_SCALARS") == "1" + try: + scalars = _read_layer_scalars_from_safetensors(hf_checkpoint) if _is_rank_zero() else None + scalars = _broadcast_layer_scalars(scalars) + if not scalars: + if allow_missing: + return + raise RuntimeError( + "No layer_scalar weights found in checkpoint; set " + "GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1 to proceed with " + "default values (not numerically equivalent to HF)." + ) + + # Under pipeline-parallelism, inner.decoder.layers holds only this + # rank's local subset. Translate the local index back to the global + # (HF 0-indexed) layer index so we apply the right scalar per layer. + from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + pp_offset = get_transformer_layer_offset(config) + + loaded = 0 + for i, layer in enumerate(inner.decoder.layers): + if hasattr(layer, "layer_scalar"): + global_idx = i + pp_offset + if global_idx not in scalars: + if allow_missing: + logger.warning( + "layer_scalar for global layer %d missing; using default 1.0", + global_idx, + ) + else: + raise KeyError( + f"layer_scalar for global layer {global_idx} " + f"missing in checkpoint (have: {sorted(scalars)[:10]}...); " + "checkpoint may be truncated." + ) + layer.layer_scalar.fill_(scalars.get(global_idx, 1.0)) + loaded += 1 + if _is_rank_zero(): + logger.info( + "Applied %d/%d layer scalars (pp_offset=%d, range=%.4f..%.4f)", + loaded, len(inner.decoder.layers), pp_offset, + min(scalars.values()), max(scalars.values()), + ) + except (FileNotFoundError, json.JSONDecodeError) as e: + # These are recoverable: older HF layouts have no safetensors index, + # or json is malformed. Warn and fall back. + if allow_missing: + logger.warning("layer scalars unavailable (%s: %s); using default 1.0", + type(e).__name__, e) + return + raise diff --git a/tests/gemma4/test_gemma4_attention.py b/tests/gemma4/test_gemma4_attention.py new file mode 100644 index 0000000000..0c4119fc76 --- /dev/null +++ b/tests/gemma4/test_gemma4_attention.py @@ -0,0 +1,160 @@ +"""Unit tests for ``Gemma4SelfAttention.get_query_key_value_tensors``. + +These tests target the two code paths independently: + +- **Global K=V**: ``_split_qkv_global_k_eq_v`` must produce + ``key = k_norm(raw_k)`` and ``value = v_norm(raw_k)`` without mutating + ``self.k_layernorm`` and without going through the parent class. We + construct a minimal stub of the parent attribs that the method reads. + +- **Sliding**: delegates to the parent, and then applies ``v_norm`` to V. + Covered indirectly in ``test_gemma4_cp_attention.py``; here we only verify that + the non-global path produces a V tensor that equals ``v_norm(raw_v)``. 
+""" + +from types import SimpleNamespace + +import pytest +import torch + +from slime_plugins.models.gemma4 import Gemma4SelfAttention, VNorm + + +def _stub_attention(num_attention_heads, num_kv_heads, head_dim, hidden_size): + """Build a Gemma4SelfAttention-compatible stub without invoking + ``SelfAttention.__init__`` (which needs TE / process group / etc.). + + We populate only the attributes that ``_split_qkv_global_k_eq_v`` and + ``get_query_key_value_tensors`` read. + """ + attn = object.__new__(Gemma4SelfAttention) + # Gemma4SelfAttention is a nn.Module (through SelfAttention); we must + # initialize the Module machinery before assigning submodules. + torch.nn.Module.__init__(attn) + + # Deterministic linear_qkv: output width = num_kv_heads * (q_per_kv + 2) * head_dim + q_per_kv = num_attention_heads // num_kv_heads + out_width = num_kv_heads * (q_per_kv + 2) * head_dim + linear_qkv = torch.nn.Linear(hidden_size, out_width, bias=False) + torch.nn.init.normal_(linear_qkv.weight, std=0.02) + + def _linear_qkv(h): + # Megatron's ColumnParallelLinear returns (out, bias); mimic that. + return linear_qkv(h), None + + attn.linear_qkv = _linear_qkv + attn.num_attention_heads_per_partition = num_attention_heads + attn.num_query_groups_per_partition = num_kv_heads + attn.hidden_size_per_attention_head = head_dim + # Learnable k_norm / q_norm so their effect is visible. + attn.q_layernorm = torch.nn.LayerNorm(head_dim) + attn.k_layernorm = torch.nn.LayerNorm(head_dim) + attn.v_norm = VNorm(head_dim, eps=1e-6) + attn.config = SimpleNamespace( + layernorm_epsilon=1e-6, + attention_k_eq_v=True, + ) + attn._is_global = False # flipped per-test + return attn, linear_qkv + + +def test_global_k_eq_v_produces_k_norm_and_v_norm_of_raw_k(): + """With _is_global=True, value must equal v_norm(raw_k_proj) — NOT + v_norm(k_norm(raw_k_proj)). This guards the bug where the parent-class + path would pre-normalize K before we extract V from it.""" + torch.manual_seed(0) + num_attention_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 512, 256 + attn, linear_qkv = _stub_attention(num_attention_heads, num_kv_heads, head_dim, hidden_size) + attn._is_global = True + + seq_len, batch = 4, 1 + hidden = torch.randn(seq_len, batch, hidden_size) + + query, key, value = attn.get_query_key_value_tensors(hidden) + + assert query.shape == (seq_len, batch, num_attention_heads, head_dim) + assert key.shape == (seq_len, batch, num_kv_heads, head_dim) + assert value.shape == (seq_len, batch, num_kv_heads, head_dim) + + # Recompute expected tensors independently. We use the same linear_qkv + # weights + norms to derive ground truth. 
+ mixed, _ = attn.linear_qkv(hidden) + q_per_kv = num_attention_heads // num_kv_heads + mixed = mixed.view(seq_len, batch, num_kv_heads, (q_per_kv + 2) * head_dim) + q_width = q_per_kv * head_dim + raw_q, raw_k, _raw_v = torch.split(mixed, [q_width, head_dim, head_dim], dim=3) + raw_q = raw_q.reshape(seq_len, batch, -1, head_dim) + + expected_query = attn.q_layernorm(raw_q) + expected_key = attn.k_layernorm(raw_k) + expected_value = attn.v_norm(raw_k) # v_norm applied to RAW k, not k_norm(raw_k) + + assert torch.allclose(query, expected_query), "query mismatch" + assert torch.allclose(key, expected_key), "key must be k_norm(raw_k)" + assert torch.allclose(value, expected_value), ( + "value must be v_norm(raw_k) — if this fails, v is being derived from " + "k_norm(raw_k) instead of raw_k" + ) + + +def test_global_k_eq_v_does_not_mutate_k_layernorm(): + """Before the refactor, the implementation set ``self.k_layernorm = None`` + around the parent call. Any exception in the parent would leak that + mutation permanently, and concurrent construction of sibling layers + would race. Confirm the attribute is untouched across a forward.""" + torch.manual_seed(1) + attn, _ = _stub_attention(8, 2, 512, 256) + attn._is_global = True + + k_layernorm_before = attn.k_layernorm + hidden = torch.randn(3, 1, 256) + _ = attn.get_query_key_value_tensors(hidden) + assert attn.k_layernorm is k_layernorm_before + + +def test_global_k_eq_v_rejects_output_gate(): + """output_gate is incompatible with the K=V split (the parent class uses + a different tensor shape for gated attention). Ensure we fail loudly.""" + attn, _ = _stub_attention(8, 2, 512, 256) + attn._is_global = True + with pytest.raises(NotImplementedError): + attn.get_query_key_value_tensors(torch.randn(3, 1, 256), output_gate=True) + + +def test_sliding_layer_applies_v_norm_to_value(): + """For non-global (sliding) layers, value comes from the standard QKV + split and then gets v_norm applied. Verify the non-K=V path.""" + torch.manual_seed(2) + num_attention_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 256, 256 + attn, linear_qkv = _stub_attention(num_attention_heads, num_kv_heads, head_dim, hidden_size) + attn._is_global = False # sliding — parent path + v_norm + + # Patch the parent's call to return a known (q, k, v) triple. We avoid + # calling the real SelfAttention.get_query_key_value_tensors (needs full + # Megatron init) by monkey-patching just the bound method lookup. + seq_len, batch = 3, 1 + raw_q = torch.randn(seq_len, batch, num_attention_heads, head_dim) + raw_k = torch.randn(seq_len, batch, num_kv_heads, head_dim) + raw_v = torch.randn(seq_len, batch, num_kv_heads, head_dim) + + def _fake_parent(*_a, **_k): + return raw_q, raw_k, raw_v + + # Call the method directly via the base-class path we want to test. We + # cannot use super() without a real SelfAttention, so re-execute the + # split logic as written in get_query_key_value_tensors by calling with + # monkey-patched super(). 
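+    # Patching the attribute on the base class works because the super() call
+    # inside get_query_key_value_tensors resolves the method through the MRO
+    # at call time, so it picks up the mock.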
+ import unittest.mock as mock + from megatron.core.transformer.attention import SelfAttention as _Base + with mock.patch.object(_Base, "get_query_key_value_tensors", _fake_parent): + query, key, value = attn.get_query_key_value_tensors( + torch.randn(seq_len, batch, hidden_size) + ) + + assert torch.equal(query, raw_q) + assert torch.equal(key, raw_k) + assert torch.allclose(value, attn.v_norm(raw_v)) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_bridge.py b/tests/gemma4/test_gemma4_bridge.py new file mode 100644 index 0000000000..145fe46967 --- /dev/null +++ b/tests/gemma4/test_gemma4_bridge.py @@ -0,0 +1,275 @@ +"""Parity tests for slime_plugins.mbridge.gemma4 and +slime/backends/megatron_utils/megatron_to_hf/gemma4.py. + +These exercise the ACTUAL production functions (Gemma4Bridge and +convert_gemma4_to_hf) rather than re-implementing the pack/unpack. +""" + +import importlib +import importlib.util +import pathlib +from types import SimpleNamespace + +import pytest +import torch + + +def _load_convert_module(): + """Import the weight-conversion module either from the installed slime + package or from the repo's working copy relative to this test file.""" + try: + return importlib.import_module("slime.backends.megatron_utils.megatron_to_hf.gemma4") + except ImportError: + pass + repo_path = pathlib.Path(__file__).resolve().parents[2] / ( + "slime/backends/megatron_utils/megatron_to_hf/gemma4.py" + ) + if not repo_path.exists(): + pytest.skip(f"convert_gemma4_to_hf source not found at {repo_path}") + spec = importlib.util.spec_from_file_location("_gemma4_conv_under_test", repo_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# Gemma4-31B canonical config values. +CFG_31B = SimpleNamespace( + hidden_size=5376, + num_attention_heads=32, + head_dim=256, + num_key_value_heads=16, + global_head_dim=512, + num_global_key_value_heads=4, + num_hidden_layers=60, + attention_k_eq_v=True, + layer_types=(["sliding_attention"] * 5 + ["full_attention"]) * 10, +) + + +def _pack_local_qkv(q, k, v): + """Megatron packs as [num_kv_heads, (q_per_kv + 2) * head_dim, hidden].""" + num_kv = CFG_31B.num_key_value_heads + head_dim = CFG_31B.head_dim + q_per_kv = CFG_31B.num_attention_heads // num_kv + q = q.view(num_kv, q_per_kv * head_dim, CFG_31B.hidden_size) + k = k.view(num_kv, head_dim, CFG_31B.hidden_size) + v = v.view(num_kv, head_dim, CFG_31B.hidden_size) + return torch.cat([q, k, v], dim=1).reshape(-1, CFG_31B.hidden_size).contiguous() + + +def _pack_global_qkv(q, k): + """K=V global layers: stored as [q, k, k].""" + num_kv = CFG_31B.num_global_key_value_heads + head_dim = CFG_31B.global_head_dim + q_per_kv = CFG_31B.num_attention_heads // num_kv + q = q.view(num_kv, q_per_kv * head_dim, CFG_31B.hidden_size) + k = k.view(num_kv, head_dim, CFG_31B.hidden_size) + return torch.cat([q, k, k], dim=1).reshape(-1, CFG_31B.hidden_size).contiguous() + + +def test_convert_gemma4_to_hf_local_layer_roundtrip(monkeypatch): + """Load convert_gemma4_to_hf and verify roundtrip for a local layer.""" + conv = _load_convert_module() + + # Prime the config cache so we don't need a real HF checkpoint on disk. 
+ conv._config_cache["config"] = { + "global_attn_layers": {i for i, t in enumerate(CFG_31B.layer_types) if t == "full_attention"}, + "local_head_dim": CFG_31B.head_dim, + "global_head_dim": CFG_31B.global_head_dim, + "num_attention_heads": CFG_31B.num_attention_heads, + "local_num_kv_heads": CFG_31B.num_key_value_heads, + "global_num_kv_heads": CFG_31B.num_global_key_value_heads, + "hidden_size": CFG_31B.hidden_size, + } + + # Build a random local qkv and convert it. + q = torch.randn(CFG_31B.num_attention_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + k = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + v = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + packed = _pack_local_qkv(q, k, v) + + # Layer 0 is sliding (local). + args = SimpleNamespace(hf_checkpoint="/nonexistent") + emitted = conv.convert_gemma4_to_hf( + args, "module.module.decoder.layers.0.self_attention.linear_qkv.weight", packed, + ) + names = {n for n, _ in emitted} + assert names == { + "model.language_model.layers.0.self_attn.q_proj.weight", + "model.language_model.layers.0.self_attn.k_proj.weight", + "model.language_model.layers.0.self_attn.v_proj.weight", + } + out = dict(emitted) + assert torch.allclose(out["model.language_model.layers.0.self_attn.q_proj.weight"], q) + assert torch.allclose(out["model.language_model.layers.0.self_attn.k_proj.weight"], k) + assert torch.allclose(out["model.language_model.layers.0.self_attn.v_proj.weight"], v) + + +def test_convert_gemma4_to_hf_global_layer_emits_no_v_proj(): + conv = _load_convert_module() + + conv._config_cache["config"] = { + "global_attn_layers": {5, 11, 17, 23, 29, 35, 41, 47, 53, 59}, + "local_head_dim": CFG_31B.head_dim, + "global_head_dim": CFG_31B.global_head_dim, + "num_attention_heads": CFG_31B.num_attention_heads, + "local_num_kv_heads": CFG_31B.num_key_value_heads, + "global_num_kv_heads": CFG_31B.num_global_key_value_heads, + "hidden_size": CFG_31B.hidden_size, + } + + q = torch.randn(CFG_31B.num_attention_heads * CFG_31B.global_head_dim, CFG_31B.hidden_size) + k = torch.randn(CFG_31B.num_global_key_value_heads * CFG_31B.global_head_dim, CFG_31B.hidden_size) + packed = _pack_global_qkv(q, k) + + args = SimpleNamespace(hf_checkpoint="/nonexistent") + emitted = conv.convert_gemma4_to_hf( + args, "module.module.decoder.layers.5.self_attention.linear_qkv.weight", packed, + ) + names = {n for n, _ in emitted} + # Global K=V layers: only q and k are emitted, v is absent. + assert names == { + "model.language_model.layers.5.self_attn.q_proj.weight", + "model.language_model.layers.5.self_attn.k_proj.weight", + } + + +def test_global_layer_index_matches_layer_types(): + """The 1-indexed layer_number mod 6 == 0 heuristic must agree with HF's + authoritative `layer_types`. This test guards against future Gemma4 variants + shipping with a non-regular pattern. + """ + # Build layer_types that mirror the production 31B config. + lt = [] + for i in range(CFG_31B.num_hidden_layers): + # 0-indexed: global layers are at 5, 11, 17, ... (every 6th) + lt.append("full_attention" if (i + 1) % 6 == 0 else "sliding_attention") + # Ensure our heuristic picks the same indices. 
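+    # For 60 layers both should yield {5, 11, 17, 23, 29, 35, 41, 47, 53, 59}.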
+ heuristic = {i for i in range(CFG_31B.num_hidden_layers) if (i + 1) % 6 == 0} + truth = {i for i, t in enumerate(lt) if t == "full_attention"} + assert heuristic == truth + + +def test_mlp_gate_up_roundtrip(): + """linear_fc1 in Megatron packs [gate, up] along dim 0.""" + gate = torch.randn(21504, CFG_31B.hidden_size) + up = torch.randn(21504, CFG_31B.hidden_size) + fused = torch.cat([gate, up], dim=0) + gate2, up2 = fused.chunk(2, dim=0) + assert torch.equal(gate, gate2) and torch.equal(up, up2) + + +def test_convert_gemma4_to_hf_moe_expert_weights_stacked(): + """Per-expert fc1/fc2 weights (from TEGroupedLinear) stream in one at a time + and the converter buffers them until all experts arrive, then emits a single + stacked 3D tensor matching sglang's Gemma4 loader expectation. + """ + conv = _load_convert_module() + num_experts = 4 # keep test fast + conv._config_cache["config"] = { + "global_attn_layers": {5}, + "local_head_dim": 256, "global_head_dim": 512, + "num_attention_heads": 16, + "local_num_kv_heads": 8, "global_num_kv_heads": 2, + "hidden_size": 2816, + "num_experts": num_experts, + } + conv._expert_buffers.clear() + args = SimpleNamespace(hf_checkpoint="/nonexistent") + + # Stream all 4 experts' linear_fc1 tensors. Only the last should emit. + fc1_tensors = [torch.randn(2 * 704, 2816) for _ in range(num_experts)] + emitted_total = [] + for e, t in enumerate(fc1_tensors): + out = conv.convert_gemma4_to_hf( + args, f"module.module.decoder.layers.3.mlp.experts.linear_fc1.weight{e}", t, + ) + emitted_total.append(out) + # Only the last flush produces output. + assert all(len(out) == 0 for out in emitted_total[:-1]) + last = emitted_total[-1] + assert len(last) == 1 + name, stacked = last[0] + assert name == "model.language_model.layers.3.experts.gate_up_proj" + assert stacked.shape == (num_experts, 2 * 704, 2816) + for e, t in enumerate(fc1_tensors): + assert torch.equal(stacked[e], t) + + # Same for linear_fc2. 
+ fc2_tensors = [torch.randn(2816, 704) for _ in range(num_experts)] + emitted_total = [] + for e, t in enumerate(fc2_tensors): + out = conv.convert_gemma4_to_hf( + args, f"module.module.decoder.layers.3.mlp.experts.linear_fc2.weight{e}", t, + ) + emitted_total.append(out) + assert all(len(out) == 0 for out in emitted_total[:-1]) + last = emitted_total[-1] + assert len(last) == 1 + name, stacked = last[0] + assert name == "model.language_model.layers.3.experts.down_proj" + assert stacked.shape == (num_experts, 2816, 704) + for e, t in enumerate(fc2_tensors): + assert torch.equal(stacked[e], t) + + +def test_convert_gemma4_to_hf_moe_router_weights(): + """Router lives at `.mlp.router.*` under the MoE variant since + `self.mlp = Gemma4MoELayer` for MoE-enabled layers.""" + conv = _load_convert_module() + conv._config_cache["config"] = { + "global_attn_layers": {5}, + "local_head_dim": 256, "global_head_dim": 512, + "num_attention_heads": 16, + "local_num_kv_heads": 8, "global_num_kv_heads": 2, + "hidden_size": 2816, + } + args = SimpleNamespace(hf_checkpoint="/nonexistent") + for mcore_rest, hf_tail in [ + ("mlp.router.proj.weight", "router.proj.weight"), + ("mlp.router.scale", "router.scale"), + ("mlp.router.per_expert_scale", "router.per_expert_scale"), + ]: + param = torch.randn(4) + emitted = conv.convert_gemma4_to_hf( + args, f"module.module.decoder.layers.3.{mcore_rest}", param, + ) + assert len(emitted) == 1 + assert emitted[0][0] == f"model.language_model.layers.3.{hf_tail}" + + +def test_convert_gemma4_to_hf_dense_mlp_sibling(): + """Under MoE variant the parallel dense MLP lives at `.dense_mlp.*`; it + must still map to HF's `.mlp.*` since HF uses a single MLP naming.""" + conv = _load_convert_module() + conv._config_cache["config"] = { + "global_attn_layers": set(), + "local_head_dim": 256, "global_head_dim": 512, + "num_attention_heads": 16, + "local_num_kv_heads": 8, "global_num_kv_heads": 2, + "hidden_size": 2816, + } + args = SimpleNamespace(hf_checkpoint="/nonexistent") + + gate = torch.randn(2112, 2816) + up = torch.randn(2112, 2816) + fused = torch.cat([gate, up], dim=0) + + emitted = conv.convert_gemma4_to_hf( + args, "module.module.decoder.layers.0.dense_mlp.linear_fc1.weight", fused, + ) + names = {n for n, _ in emitted} + assert names == { + "model.language_model.layers.0.mlp.gate_proj.weight", + "model.language_model.layers.0.mlp.up_proj.weight", + } + + down = torch.randn(2816, 2112) + emitted = conv.convert_gemma4_to_hf( + args, "module.module.decoder.layers.0.dense_mlp.linear_fc2.weight", down, + ) + assert emitted == [("model.language_model.layers.0.mlp.down_proj.weight", down)] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_cp_attention.py b/tests/gemma4/test_gemma4_cp_attention.py new file mode 100644 index 0000000000..4c12c91166 --- /dev/null +++ b/tests/gemma4/test_gemma4_cp_attention.py @@ -0,0 +1,251 @@ +"""Parity tests for SDPACoreAttention. + +These tests exercise the PRODUCTION code paths in slime_plugins.models.gemma4 +— not re-implementations — so divergence between the dispatch logic in +`forward` and the hand-written CP math would be caught here. + +Runs on a single GPU (or CPU, but flash-attn is skipped). A distributed +group is faked via `torch.distributed` with world_size=1, which turns the +CP all-gather into a no-op and lets us re-use the CP code on a single +device while still hitting the differentiable-gather autograd path. 
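+
+A minimal sketch of the single-rank trick (gloo shown; the fixture below
+picks nccl when CUDA is available):
+
+    dist.init_process_group("gloo", rank=0, world_size=1)
+    # world_size == 1 turns the gather into an identity, so the CP
+    # branches run end-to-end inside one process.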
+""" + +import os + +import pytest +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +@pytest.fixture(scope="module", autouse=True) +def _init_dist(): + """Initialize a single-rank process group so CP code can call into it.""" + if dist.is_initialized(): + yield + return + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29555") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, rank=0, world_size=1) + try: + # Megatron's parallel_state needs to be initialized for CP helpers to work. + try: + from megatron.core import parallel_state as mpu + mpu.initialize_model_parallel(context_parallel_size=1) + except Exception: + pass + yield + finally: + dist.destroy_process_group() + + +def _ref_attention(query, key, value, cu_seqlens, scale, sliding_window=None): + """Ground-truth varlen causal (optionally sliding) attention via SDPA, + with manual GQA expansion. query/key/value in [T, n, h] format.""" + t = query.shape[0] + nq, nk = query.shape[1], key.shape[1] + q = query.unsqueeze(0).transpose(1, 2).float() # [1, n, T, h] + k = key.unsqueeze(0).transpose(1, 2).float() + v = value.unsqueeze(0).transpose(1, 2).float() + if nq != nk: + k = k.repeat_interleave(nq // nk, dim=1) + v = v.repeat_interleave(nq // nk, dim=1) + + mask = torch.full((t, t), float("-inf"), device=query.device, dtype=torch.float32) + for i in range(len(cu_seqlens) - 1): + s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1]) + for qi in range(s, e): + lo = s if sliding_window is None else max(s, qi - sliding_window + 1) + mask[qi, lo:qi + 1] = 0.0 + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[None, None, :, :], scale=scale) + return out.transpose(1, 2).reshape(t, -1).to(query.dtype) + + +def _make_core_attention(sliding_window: int | None, softmax_scale: float): + """Build an SDPACoreAttention without going through Megatron's spec system.""" + from types import SimpleNamespace + from slime_plugins.models.gemma4 import SDPACoreAttention + + config = SimpleNamespace( + attention_dropout=0.0, + sliding_window=sliding_window or 1024, + context_parallel_size=1, + ) + core = SDPACoreAttention( + config=config, layer_number=1, attn_mask_type=None, softmax_scale=softmax_scale, + ) + core._is_sliding = sliding_window is not None + return core + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") +def test_global_thd_sdpa_per_subseq_matches_reference(): + """Global layer (head_dim=512), no CP, packed thd — uses the SDPA per-sub-seq path.""" + torch.manual_seed(0) + device = "cuda" + dtype = torch.float32 # fp32 for exact parity + + nq, nk, hn = 8, 2, 512 + scale = 1.0 / (hn ** 0.5) + lens = [13, 20, 7] + cu = torch.tensor([0] + list(__import__("itertools").accumulate(lens)), dtype=torch.int32, device=device) + t = int(cu[-1]) + q = torch.randn(t, nq, hn, device=device, dtype=dtype) + k = torch.randn(t, nk, hn, device=device, dtype=dtype) + v = torch.randn(t, nk, hn, device=device, dtype=dtype) + + ref = _ref_attention(q, k, v, cu, scale=scale) + + core = _make_core_attention(sliding_window=None, softmax_scale=scale) + out = core._forward_thd_sdpa_per_subseq(q, k, v, cu) + assert out.shape == (t, nq * hn) + + cos = F.cosine_similarity(ref.flatten().unsqueeze(0), out.flatten().unsqueeze(0)).item() + assert cos > 0.9999, f"global SDPA per-sub-seq mismatch, cosine={cos}" + + +@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="needs CUDA") +def test_flash_thd_with_sliding_window(): + """Sliding-window layer, head_dim=256, thd — must apply sliding window.""" + try: + import flash_attn # noqa + except ImportError: + pytest.skip("flash_attn not installed") + + torch.manual_seed(1) + device = "cuda" + dtype = torch.bfloat16 + + nq, nk, hn = 16, 8, 256 + scale = 1.0 / (hn ** 0.5) + lens = [1200, 800] # > sliding_window on the first sequence + cu = torch.tensor([0] + list(__import__("itertools").accumulate(lens)), dtype=torch.int32, device=device) + t = int(cu[-1]) + q = torch.randn(t, nq, hn, device=device, dtype=dtype) + k = torch.randn(t, nk, hn, device=device, dtype=dtype) + v = torch.randn(t, nk, hn, device=device, dtype=dtype) + + core = _make_core_attention(sliding_window=1024, softmax_scale=scale) + out = core._forward_thd_flash(q, k, v, cu) + assert out.shape == (t, nq * hn) + assert not torch.isnan(out).any() + + # Reference: apply sliding window in the mask. + ref = _ref_attention(q.float(), k.float(), v.float(), cu, scale=scale, sliding_window=1024) + cos = F.cosine_similarity(ref.flatten().unsqueeze(0), out.float().flatten().unsqueeze(0)).item() + assert cos > 0.999, f"flash+sliding mismatch, cosine={cos}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") +def test_forward_dispatches_correctly_by_layer_type_and_headdim(): + """Smoke-test: `forward()` picks the right internal path without crashing.""" + torch.manual_seed(2) + device = "cuda" + dtype = torch.bfloat16 + + # Construct a packed_seq_params-like object. + from types import SimpleNamespace + cu = torch.tensor([0, 64, 192], dtype=torch.int32, device=device) + packed = SimpleNamespace(cu_seqlens_q=cu) + + # Case 1: sliding layer, head_dim=256 → flash path. + core = _make_core_attention(sliding_window=1024, softmax_scale=1.0 / (256 ** 0.5)) + q = torch.randn(192, 8, 256, device=device, dtype=dtype) + k = torch.randn(192, 4, 256, device=device, dtype=dtype) + v = torch.randn(192, 4, 256, device=device, dtype=dtype) + out = core.forward(q, k, v, packed_seq_params=packed) + assert out.shape == (192, 8 * 256) + assert not torch.isnan(out).any() + + # Case 2: global layer, head_dim=512, CP=1 → SDPA per-sub-seq. + core_g = _make_core_attention(sliding_window=None, softmax_scale=1.0 / (512 ** 0.5)) + qg = torch.randn(192, 8, 512, device=device, dtype=dtype) + kg = torch.randn(192, 2, 512, device=device, dtype=dtype) + vg = torch.randn(192, 2, 512, device=device, dtype=dtype) + out = core_g.forward(qg, kg, vg, packed_seq_params=packed) + assert out.shape == (192, 8 * 512) + assert not torch.isnan(out).any() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") +def test_cp_global_gradient_flow_end_to_end(): + """With world_size=1, `_forward_cp_subseq_mask` becomes a pass-through but + must still let gradients flow through K, V, Q without the + deprecated-autograd warning or NaNs. + """ + torch.manual_seed(3) + device = "cuda" + dtype = torch.float32 + + nq, nk, hn = 8, 2, 512 + scale = 1.0 / (hn ** 0.5) + # Each sub-seq length must be divisible by 2*cp_size; cp_size=2 -> div by 4. 
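+    # Here lens = [32, 64]; both divide by 4, so the zig-zag chunking below
+    # splits each sub-sequence exactly.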
+    cu = torch.tensor([0, 32, 96], dtype=torch.int32, device=device)
+    t = int(cu[-1])
+    from types import SimpleNamespace
+    packed = SimpleNamespace(cu_seqlens_q=cu)
+    q = torch.randn(t, nq, hn, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(t, nk, hn, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(t, nk, hn, device=device, dtype=dtype, requires_grad=True)
+
+    core = _make_core_attention(sliding_window=None, softmax_scale=scale)
+    # Force the CP-global path by raising cp_size on the config while
+    # Megatron's actual CP world size stays 1. With world_size=1, this
+    # exercises the differentiable gather_from_sequence_parallel_region path.
+    core.config.context_parallel_size = 2  # steers the dispatch into the CP path
+    try:
+        out = core._forward_cp_subseq_mask(q, k, v, packed, sliding_window=None)
+    except Exception as exc:
+        pytest.skip(f"Megatron parallel_state not initialized ({exc}); skipping CP path smoke test")
+
+    assert out.shape == (t, nq * hn)
+    assert not torch.isnan(out).any()
+    out.sum().backward()
+    assert q.grad is not None and not torch.isnan(q.grad).any()
+    assert k.grad is not None and not torch.isnan(k.grad).any()
+    assert v.grad is not None and not torch.isnan(v.grad).any()
+    # Critical: K and V grads must be non-zero. The old raw
+    # `dist.all_gather_into_tensor` path dropped these grads on non-owning
+    # ranks; with world_size=1 it would still work, but on the differentiable
+    # path we check the signal is propagating at all.
+    assert (k.grad.abs() > 0).any()
+    assert (v.grad.abs() > 0).any()
+
+
+def test_zigzag_global_indices_cp1_is_identity():
+    """With cp_size=1, the zig-zag permutation should be identity (rank 0
+    holds the whole sequence)."""
+    from slime_plugins.models.gemma4 import SDPACoreAttention
+    device = torch.device("cpu")
+    # local_len = total local tokens = 8, cp_rank=0, cp_size=1 -> cs=4
+    idx = SDPACoreAttention._zigzag_global_indices(
+        local_len=8, cp_rank=0, cp_size=1, device=device,
+    )
+    # With cp=1: rank 0 owns global [0, cs) ++ [cs, 2*cs) = [0, 8)
+    assert idx.tolist() == list(range(8))
+
+
+def test_zigzag_global_indices_cp2_matches_slime_slice():
+    """Verify the zig-zag map matches slime's slice_with_cp convention for
+    cp_size=2: rank 0 owns chunks (0, 3), rank 1 owns chunks (1, 2)."""
+    from slime_plugins.models.gemma4 import SDPACoreAttention
+    device = torch.device("cpu")
+    # 4 chunks of size 4 -> total 16 tokens. cp=2, so chunk_size = 4.
+    # Rank 0 local tokens = 8 = 2*cs -> global [0,4) ++ [12,16)
+    # Rank 1 local tokens = 8 = 2*cs -> global [4,8) ++ [8,12)
+    idx_r0 = SDPACoreAttention._zigzag_global_indices(
+        local_len=8, cp_rank=0, cp_size=2, device=device,
+    )
+    idx_r1 = SDPACoreAttention._zigzag_global_indices(
+        local_len=8, cp_rank=1, cp_size=2, device=device,
+    )
+    assert idx_r0.tolist() == [0, 1, 2, 3, 12, 13, 14, 15]
+    assert idx_r1.tolist() == [4, 5, 6, 7, 8, 9, 10, 11]
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tests/gemma4/test_gemma4_dual_rope.py b/tests/gemma4/test_gemma4_dual_rope.py
new file mode 100644
index 0000000000..5ddf65eae5
--- /dev/null
+++ b/tests/gemma4/test_gemma4_dual_rope.py
@@ -0,0 +1,159 @@
+"""Unit tests for ``DualRotaryEmbedding``.
+
+``DualRotaryEmbedding`` wraps a pair of Megatron ``RotaryEmbedding`` modules
+— one for sliding layers (local), one for global layers — and produces a
+single concatenated tensor so downstream code that expects a single rope
+output (distributed checkpointing, CP sharding) continues to work.
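+
+Concatenation layout (global first — pinned by the tests below):
+
+    combined[..., :global_dim]   # global rope, dim = global_head_dim
+    combined[..., global_dim:]   # local rope,  dim = local head_dim
+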
+``Gemma4TransformerLayer.forward`` splits it back per-layer based on whether +the layer is sliding or global. + +We test two things: + +1. Concat/split semantics on synthetic tensors (CPU-only). This is the + novel logic ``DualRotaryEmbedding`` actually adds beyond Megatron's own + ``RotaryEmbedding``. +2. End-to-end forward with real Megatron ``RotaryEmbedding`` modules — the + only CUDA-gated part because Megatron forces ``inv_freq`` to CUDA inside + ``get_emb``. +""" + +import importlib.util +import pathlib +import sys + +import pytest +import torch + + +def _load_dual_rotary_embedding(): + """Import ``DualRotaryEmbedding`` without triggering the module-level + ``from megatron.training import get_args`` that ``gemma4_provider`` does + at import time (unavailable in minimal test containers). We exec the + module with the unavailable import stubbed out.""" + import types + + if "megatron.training" not in sys.modules: + stub = types.ModuleType("megatron.training") + stub.get_args = lambda: None + sys.modules["megatron.training"] = stub + if "megatron.training.arguments" not in sys.modules: + stub2 = types.ModuleType("megatron.training.arguments") + stub2.core_transformer_config_from_args = lambda *a, **k: None + sys.modules["megatron.training.arguments"] = stub2 + + repo_path = pathlib.Path(__file__).resolve().parents[2] / ( + "slime_plugins/models/gemma4_provider.py" + ) + spec = importlib.util.spec_from_file_location( + "_gemma4_provider_under_test", repo_path + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.DualRotaryEmbedding + + +DualRotaryEmbedding = _load_dual_rotary_embedding() + + +class _FakeRope: + """Minimal RotaryEmbedding stand-in that returns a deterministic tensor + encoding its own identity — so we can verify that the global slice + really came from the global rope (and vice versa).""" + + def __init__(self, dim: int, tag: float): + self.dim = dim + self.tag = tag + + def __call__(self, seq_len, **kwargs): + # Shape [seq_len, 1, 1, dim] — same layout Megatron produces. + # Value: seq-index * 100 + dim-index, plus a per-rope tag so we can + # tell global from local apart slice-by-slice. + s = torch.arange(seq_len, dtype=torch.float).view(seq_len, 1, 1, 1) + d = torch.arange(self.dim, dtype=torch.float).view(1, 1, 1, self.dim) + return s * 100.0 + d + self.tag + + def get_rotary_seq_len(self, *args, **kwargs): + # Sentinel used in the delegation test. + return ("fake_seq_len_result", args, kwargs) + + +def test_dual_rope_concat_shape_global_first(): + local = _FakeRope(dim=256, tag=0.1) + glob = _FakeRope(dim=512, tag=0.9) + dual = DualRotaryEmbedding(local, glob, global_dim=512) + + seq_len = 16 + combined = dual(seq_len) + assert combined.shape == (seq_len, 1, 1, 512 + 256) + + # Verify the GLOBAL slice came from the global rope (tag=0.9, dim=512) + # and the local slice from the local rope (tag=0.1, dim=256). + # Global first. + global_slice = combined[..., :512] + local_slice = combined[..., 512:] + assert torch.equal(global_slice, glob(seq_len)) + assert torch.equal(local_slice, local(seq_len)) + + +def test_dual_rope_split_matches_layer_convention(): + """The concat format must round-trip through the layer's split + convention: + + rotary_pos_emb[..., :global_dim] -> global layers + rotary_pos_emb[..., global_dim:] -> sliding layers + + This is a regression guard against any reshuffle (e.g. 
swapping the + concat order) that would silently feed the wrong RoPE to each layer.""" + global_dim, local_dim = 384, 192 + local = _FakeRope(dim=local_dim, tag=11.0) + glob = _FakeRope(dim=global_dim, tag=22.0) + dual = DualRotaryEmbedding(local, glob, global_dim=global_dim) + + seq_len = 8 + combined = dual(seq_len) + + # Mimic Gemma4TransformerLayer.forward's split logic. + for is_sliding, expected_rope in [(False, glob), (True, local)]: + if is_sliding: + sliced = combined[..., global_dim:] + else: + sliced = combined[..., :global_dim] + assert torch.equal(sliced, expected_rope(seq_len)), ( + f"split for is_sliding={is_sliding} did not recover the right rope" + ) + + +def test_dual_rope_delegates_get_rotary_seq_len_to_local(): + """``get_rotary_seq_len`` must delegate to local_rope — both ropes share + seq-length logic (they only differ in theta and partial-rotary), so the + answer is the same, and delegation guarantees that.""" + local = _FakeRope(dim=256, tag=0.0) + glob = _FakeRope(dim=512, tag=0.0) + dual = DualRotaryEmbedding(local, glob, global_dim=512) + + result = dual.get_rotary_seq_len("a", b=2) + # _FakeRope.get_rotary_seq_len returns a sentinel tuple; confirm it + # really came from the LOCAL one. + assert result[0] == "fake_seq_len_result" + assert result[1] == ("a",) + assert result[2] == {"b": 2} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Megatron RotaryEmbedding.forward requires CUDA") +def test_dual_rope_end_to_end_with_real_megatron_rope(): + """Integration smoke test: wire real Megatron ``RotaryEmbedding`` into + DualRotaryEmbedding and sanity-check shape + split correctness.""" + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + + local = RotaryEmbedding(kv_channels=256, rotary_percent=1.0, rotary_base=10_000.0) + glob = RotaryEmbedding(kv_channels=512, rotary_percent=1.0, rotary_base=1_000_000.0) + dual = DualRotaryEmbedding(local, glob, global_dim=512) + + combined = dual(64) + assert combined.shape[-1] == 512 + 256 + assert torch.equal(combined[..., :512], glob(64)) + assert torch.equal(combined[..., 512:], local(64)) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_hf_key_contract.py b/tests/gemma4/test_gemma4_hf_key_contract.py new file mode 100644 index 0000000000..40c1f5ce9c --- /dev/null +++ b/tests/gemma4/test_gemma4_hf_key_contract.py @@ -0,0 +1,208 @@ +"""Contract test: every HF Gemma4 state-dict key must be producible by our +megatron→HF converter. + +Why this exists: ``convert_gemma4_to_hf`` emits HF-format tensors that are +then consumed by sglang (during rollout weight-update) and by +``convert_torch_dist_to_hf.py`` (for offline eval export). Both downstream +consumers walk HF's *expected* state-dict keys and look them up in the +converter output. If a future change to the Megatron side drops a key that +HF still wants, the tests in ``test_gemma4_bridge.py`` — which only exercise +specific mcore→HF mappings in isolation — won't notice; the first symptom is +a tensor-shape or missing-key crash at weight load, hours into a training +run. + +This test pins the contract by: + 1. Instantiating a tiny 2-layer (1 sliding + 1 global) MoE Gemma4 via HF, + snapshotting its state-dict key set (the "HF expected" set). + 2. Running a synthetic Megatron-side key list through + ``convert_gemma4_to_hf``, gathering the emitted HF key set. + 3. 
Asserting the emitted set covers every HF key (modulo the known + ``v_proj`` omission on K=V global layers, where HF itself sets + ``v_proj = None`` and omits the parameter from its state_dict). + +Anything that fails this contract would silently break a training run. +""" +from types import SimpleNamespace + +import pytest +import torch + + +# Synthesized Megatron state-dict keys for the tiny 2-layer MoE config below. +# Matches what ``get_model(model_provider, ...)`` produces at load time — +# enumerated here rather than instantiated to keep this test CPU-only and +# independent of Megatron init. +def _mcore_keys_tiny_moe(num_experts: int = 2) -> list[str]: + base = [ + "module.module.embedding.word_embeddings.weight", + "module.module.decoder.final_layernorm.weight", + ] + # Output layer is tied to embedding and shares key in HF; Megatron saves + # it separately. + base.append("module.module.output_layer.weight") + for layer_idx in (0, 1): + prefix = f"module.module.decoder.layers.{layer_idx}" + base.extend([ + # Attention + f"{prefix}.self_attention.linear_qkv.weight", + f"{prefix}.self_attention.linear_qkv.layer_norm_weight", + f"{prefix}.self_attention.linear_proj.weight", + f"{prefix}.self_attention.q_layernorm.weight", + f"{prefix}.self_attention.k_layernorm.weight", + # post_attention_layernorm lives outside TE-fused norm paths. + f"{prefix}.post_attention_layernorm.weight", + # Per-layer scalar buffer (loaded from HF ckpt via provider + # hook; saved by Megatron alongside trainable params). + f"{prefix}.layer_scalar", + # Dense-MLP sibling (parallel to the MoE block in 26B-A4B). + f"{prefix}.dense_mlp.linear_fc1.weight", + f"{prefix}.dense_mlp.linear_fc1.layer_norm_weight", + f"{prefix}.dense_mlp.linear_fc2.weight", + # Fused pre/post FFN layernorms around the dense+MoE add. + f"{prefix}.pre_mlp_layernorm.weight", + f"{prefix}.post_feedforward_layernorm.weight", + f"{prefix}.post_feedforward_layernorm_1.weight", + f"{prefix}.post_feedforward_layernorm_2.weight", + # MoE block internals — pre_feedforward_layernorm_2 moved inside + # Gemma4MoELayer in the current code, so it lives under .mlp.*. + f"{prefix}.mlp.pre_feedforward_layernorm_2.weight", + # Router + f"{prefix}.mlp.router.proj.weight", + f"{prefix}.mlp.router.scale", + f"{prefix}.mlp.router.per_expert_scale", + ]) + # Per-expert weights (names use global expert indices; our converter + # buffers + flushes to stacked 3D tensors). + for e in range(num_experts): + base.extend([ + f"{prefix}.mlp.experts.linear_fc1.weight{e}", + f"{prefix}.mlp.experts.linear_fc2.weight{e}", + ]) + return base + + +def _build_tiny_hf_model(): + """Build a 2-layer (1 sliding + 1 global) MoE Gemma4 via HF; return its + ``model.language_model`` state-dict keys. + + Intentionally ``hidden_size_per_layer_input=0`` to disable the + per-layer-input gate (which our plugin doesn't implement and our + converter doesn't map); and ``attention_k_eq_v=True`` to match the + real 26B/31B configs (so HF's global-layer v_proj is ``None`` and + absent from state_dict). 
+ """ + from transformers.models.gemma4 import ( + configuration_gemma4 as C, modeling_gemma4 as M, + ) + + text_cfg = C.Gemma4TextConfig( + vocab_size=64, hidden_size=32, intermediate_size=64, + num_hidden_layers=2, + num_attention_heads=4, num_key_value_heads=2, + num_global_key_value_heads=2, + head_dim=16, global_head_dim=32, + sliding_window=64, rope_theta=10000.0, + layer_types=["sliding_attention", "full_attention"], + enable_moe_block=True, num_experts=2, moe_intermediate_size=48, + top_k_experts=2, + hidden_size_per_layer_input=0, + attention_k_eq_v=True, + ) + full_cfg = C.Gemma4Config( + text_config=text_cfg.to_dict(), vision_config=None, audio_config=None, + ) + hf_model = M.Gemma4ForConditionalGeneration(full_cfg) + return set( + k for k in hf_model.state_dict().keys() if "language_model" in k + ) + + +def test_converter_emits_every_hf_key(): + """Run every synthesized Megatron key through ``convert_gemma4_to_hf``; + assert the union of emitted HF keys covers HF's expected state_dict.""" + transformers_gemma4 = pytest.importorskip("transformers.models.gemma4") + del transformers_gemma4 # only needed to gate + + from slime.backends.megatron_utils.megatron_to_hf import gemma4 as conv + + # Seed the converter's module-global config cache with the tiny config + # directly (avoids instantiating AutoConfig, which needs a ckpt dir). + # _get_config is the only reader; it returns early when "config" in + # _config_cache. + conv._config_cache["config"] = { + "global_attn_layers": {1}, # layer 1 is full_attention + "local_head_dim": 16, + "global_head_dim": 32, + "num_attention_heads": 4, + "local_num_kv_heads": 2, + "global_num_kv_heads": 2, + "hidden_size": 32, + "num_experts": 2, + } + # Clear expert-flush buffers so a prior test run doesn't leak state. + conv.reset_expert_buffers() + + args = SimpleNamespace() + # Dummy tensors sized to match the converter's view-reshape math. The + # converter only cares about the *names* for this test, but its QKV path + # calls .view() with real shapes, so we need plausible tensors. + def _fake_tensor_for(name: str) -> torch.Tensor: + if name.endswith("self_attention.linear_qkv.weight"): + # Packed [num_kv_heads * (q_per_kv + 2) * head_dim, hidden_size] + # For sliding: 2 kv_heads * (2 + 2) * 16 = 128, hidden=32. + # For global: 2 kv_heads * (2 + 2) * 32 = 256, hidden=32. + if "layers.1" in name: + return torch.zeros(256, 32) + return torch.zeros(128, 32) + if name.endswith("self_attention.linear_proj.weight"): + return torch.zeros(32, 64) # o_proj [hidden, num_q_heads*head_dim] + if "dense_mlp.linear_fc1.weight" in name: + return torch.zeros(128, 32) # gate||up packed: [2*inter, hidden] + if "dense_mlp.linear_fc2.weight" in name: + return torch.zeros(32, 64) + if "mlp.router.proj.weight" in name: + return torch.zeros(2, 32) + if "mlp.router.scale" in name or "mlp.router.per_expert_scale" in name: + return torch.zeros(2) + if "experts.linear_fc1.weight" in name: + return torch.zeros(96, 32) # packed 2*moe_inter = 96 + if "experts.linear_fc2.weight" in name: + return torch.zeros(32, 48) + if "embedding.word_embeddings" in name or "output_layer" in name: + return torch.zeros(64, 32) + if "layer_scalar" in name: + return torch.tensor([1.0]) + # Layernorm-family defaults. 
+ return torch.zeros(32) + + emitted: set[str] = set() + for mcore_name in _mcore_keys_tiny_moe(num_experts=2): + t = _fake_tensor_for(mcore_name) + out = conv.convert_gemma4_to_hf(args, mcore_name, t) + for hf_name, _hf_param in out: + emitted.add(hf_name) + + expected = _build_tiny_hf_model() + + missing = expected - emitted + assert not missing, ( + f"HF expects {len(missing)} key(s) the converter never emits; this " + f"would surface as a weight-load crash or silently-random weights in " + f"sglang. Missing:\n " + "\n ".join(sorted(missing)) + ) + + # Also surface extras (emitted but HF doesn't want) as a warning — not a + # hard failure, since the converter may legitimately emit aliases (e.g. + # tied embeddings maps output_layer → embed_tokens). Print them so a + # reviewer can sanity-check. + extras = emitted - expected + if extras: + print( + f"[info] converter emits {len(extras)} key(s) HF doesn't have in " + f"its state_dict (tied embeddings / aliases):\n " + + "\n ".join(sorted(extras)) + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_layer_integration.py b/tests/gemma4/test_gemma4_layer_integration.py new file mode 100644 index 0000000000..222719652c --- /dev/null +++ b/tests/gemma4/test_gemma4_layer_integration.py @@ -0,0 +1,245 @@ +"""CUDA-gated integration tests for ``Gemma4TransformerLayer``. + +Builds a real layer through ``get_gemma4_layer_spec_te``, runs a small +forward, and checks output shape, finiteness, and that ``is_sliding`` +matches the layer index. Numerical HF parity is covered outside slime. +""" + +import os + +import pytest +import torch + + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="Gemma4TransformerLayer requires CUDA + TE kernels", +) + + +def _init_single_rank_dist(): + """Initialize a world_size=1 process group (NCCL if CUDA else gloo) so + Megatron's parallel_state helpers work without a real multi-GPU job. + + Megatron's ``parallel_state`` is module-global: if a prior test module + left it initialized (e.g. test_gemma4_cp_attention) its + ``initialize_model_parallel`` asserts on re-init. Clear it first. + """ + import torch.distributed as dist + from megatron.core import parallel_state as mpu + + if mpu.model_parallel_is_initialized(): + mpu.destroy_model_parallel() + if not dist.is_initialized(): + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29566") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, rank=0, world_size=1) + mpu.initialize_model_parallel() + + +@pytest.fixture(scope="module", autouse=True) +def _dist(): + _init_single_rank_dist() + yield + # Leave the PG initialized for the module; tearing down between modules + # interacts badly with shared Megatron state. 
+ + +def _build_layer_config( + num_layers=6, hidden_size=128, ffn_hidden_size=256, num_heads=8, + num_kv_heads=4, head_dim=128, global_head_dim=256, num_global_kv_heads=2, + sliding_window=64, +): + """Build a minimal Gemma4TransformerConfig suitable for a smoke test.""" + from slime_plugins.models.gemma4 import Gemma4TransformerConfig + + cfg = Gemma4TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_heads, + num_query_groups=num_kv_heads, + kv_channels=head_dim, + hidden_dropout=0.0, + attention_dropout=0.0, + bf16=True, + pipeline_dtype=torch.bfloat16, + params_dtype=torch.bfloat16, + add_bias_linear=False, + add_qkv_bias=False, + gated_linear_unit=True, + activation_func=torch.nn.functional.gelu, # placeholder + normalization="RMSNorm", + layernorm_epsilon=1e-6, + attention_softmax_in_fp32=True, + persist_layer_norm=True, + bias_activation_fusion=False, + bias_dropout_fusion=True, + apply_rope_fusion=False, + qk_layernorm=True, + sequence_parallel=False, + tensor_model_parallel_size=1, + ) + # Gemma4-specific extensions. + cfg.global_kv_channels = global_head_dim + cfg.global_num_query_groups = num_global_kv_heads + cfg.global_partial_rotary_factor = 0.25 + cfg.attention_k_eq_v = True + cfg.final_logit_softcapping = 30.0 + cfg.enable_moe_block = False + cfg.sliding_window = sliding_window + cfg.sliding_window_pattern = 6 # every 6th layer is global + cfg.softmax_scale = 1.0 + return cfg + + +@requires_cuda +def test_layer_builds_and_forwards_sliding(): + """Layer 1 (1-indexed mod 6 != 0) is a sliding layer. Build it, run + forward, verify shape + finiteness.""" + from functools import partial + import torch.nn.functional as F + from megatron.core.transformer.spec_utils import build_module + from slime_plugins.models.gemma4 import get_gemma4_layer_spec_te + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + spec = get_gemma4_layer_spec_te(cfg) + + layer = build_module(spec, config=cfg, layer_number=1) + layer = layer.cuda().to(torch.bfloat16) + assert layer.is_sliding is True + assert layer._is_global is False + + # Input: [seq, batch, hidden] + seq, batch = 16, 1 + h = torch.randn(seq, batch, cfg.hidden_size, device="cuda", dtype=torch.bfloat16) + + # Minimal rotary_pos_emb — just provide a local-dim tensor. + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + rope = RotaryEmbedding(kv_channels=cfg.kv_channels, rotary_percent=1.0) + rotary = rope(seq).cuda() + + out, _ctx = layer(h, rotary_pos_emb=rotary, attention_mask=None) + assert out.shape == h.shape + assert torch.isfinite(out).all() + + +@requires_cuda +def test_layer_global_path_builds_and_forwards(): + """Layer 6 (1-indexed mod 6 == 0) is a global layer with head_dim=256, + num_kv_heads=2. 
Build it, run forward, verify shape + finiteness.""" + from functools import partial + import torch.nn.functional as F + from megatron.core.transformer.spec_utils import build_module + from slime_plugins.models.gemma4 import get_gemma4_layer_spec_te + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + spec = get_gemma4_layer_spec_te(cfg) + + # layer_number=6 → 1-indexed → global (mod 6 == 0) + layer = build_module(spec, config=cfg, layer_number=6) + layer = layer.cuda().to(torch.bfloat16) + assert layer.is_sliding is False + assert layer._is_global is True + + seq, batch = 16, 1 + h = torch.randn(seq, batch, cfg.hidden_size, device="cuda", dtype=torch.bfloat16) + + # Global layers use global_head_dim rotary. + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + rope = RotaryEmbedding(kv_channels=cfg.global_kv_channels, rotary_percent=1.0) + rotary = rope(seq).cuda() + + out, _ctx = layer(h, rotary_pos_emb=rotary, attention_mask=None) + assert out.shape == h.shape + assert torch.isfinite(out).all() + + +@requires_cuda +def test_layer_does_not_mutate_shared_config(): + """Building a global layer must NOT leave the shared `config` with + global-layer kv_channels/num_query_groups — that would break a + subsequently-built sliding layer from the same spec.""" + from functools import partial + import torch.nn.functional as F + from megatron.core.transformer.spec_utils import build_module + from slime_plugins.models.gemma4 import get_gemma4_layer_spec_te + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + orig_kv = cfg.kv_channels + orig_nqg = cfg.num_query_groups + + spec = get_gemma4_layer_spec_te(cfg) + build_module(spec, config=cfg, layer_number=6).cuda() + # Now shared config must still have the sliding-layer values. + assert cfg.kv_channels == orig_kv, ( + f"building a global layer mutated shared config.kv_channels: " + f"{orig_kv} -> {cfg.kv_channels}" + ) + assert cfg.num_query_groups == orig_nqg, ( + f"building a global layer mutated shared config.num_query_groups: " + f"{orig_nqg} -> {cfg.num_query_groups}" + ) + + +def test_layer_spec_builds_without_cuda(): + """Constructing the layer spec (no instantiation) must work on CPU — + catches regressions like a typo'd submodule or a stray hard CUDA dep + inside the spec factory itself.""" + from functools import partial + import torch.nn.functional as F + from slime_plugins.models.gemma4 import ( + Gemma4SelfAttention, Gemma4TransformerLayer, get_gemma4_layer_spec_te, + ) + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + spec = get_gemma4_layer_spec_te(cfg) + + assert spec.module is Gemma4TransformerLayer + assert spec.submodules.self_attention.module is Gemma4SelfAttention + # post_attention_layernorm + post_feedforward_layernorm must be real + # LayerNorms (not IdentityOp), since Gemma uses them. 
+ from megatron.core.transformer.identity_op import IdentityOp + assert spec.submodules.post_attention_layernorm is not IdentityOp + assert spec.submodules.post_feedforward_layernorm is not IdentityOp + + +def test_layer_spec_moe_variant_includes_dense_mlp_spec(): + """With enable_moe_block=True, the spec's `mlp` submodule must be a + Gemma4MoELayer and `dense_mlp` must be set (not IdentityOp).""" + from functools import partial + import torch.nn.functional as F + from megatron.core.transformer.identity_op import IdentityOp + from slime_plugins.models.gemma4 import ( + Gemma4MoELayer, get_gemma4_layer_spec_te, + ) + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + cfg.enable_moe_block = True + cfg.num_moe_experts = 8 + cfg.moe_router_topk = 2 + cfg.moe_ffn_hidden_size = 128 + cfg.moe_token_dispatcher_type = "alltoall" + cfg.moe_grouped_gemm = True + cfg.moe_aux_loss_coeff = 0.0 + cfg.moe_router_load_balancing_type = "none" + cfg.moe_router_score_function = "softmax" + cfg.moe_router_topk_scaling_factor = 1.0 + cfg.moe_router_pre_softmax = False + + spec = get_gemma4_layer_spec_te(cfg) + assert spec.submodules.mlp.module is Gemma4MoELayer + assert spec.submodules.dense_mlp is not IdentityOp, ( + "dense_mlp must be a concrete spec when enable_moe_block=True" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_layer_scalar_broadcast.py b/tests/gemma4/test_gemma4_layer_scalar_broadcast.py new file mode 100644 index 0000000000..d1e870ba95 --- /dev/null +++ b/tests/gemma4/test_gemma4_layer_scalar_broadcast.py @@ -0,0 +1,127 @@ +"""Multi-rank test for Gemma4's layer_scalar rank-0-read + broadcast path. + +The single-process tests in ``test_gemma4_provider.py`` confirm the +safetensors-read / PP-offset logic, but they never exercise the real +distributed path: ``_read_layer_scalars_from_safetensors`` runs only on +rank 0, then ``_broadcast_layer_scalars`` fans the dict out to the rest +via ``torch.distributed.broadcast_object_list``. A regression where +rank > 0 ends up with ``None`` or the default 1.0 would silently drift +activations on every forward pass — caught only much later by parity +tests. + +This test spawns 2 gloo ranks (no CUDA required), has rank 0 fabricate a +safetensors ckpt in a tmpdir, then both ranks call ``_load_layer_scalars`` +on a minimal inner model. We then assert the loaded scalars match on both +ranks. 
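+
+The pattern under test, roughly (helper names as in gemma4_provider):
+
+    payload = [_read_layer_scalars_from_safetensors(ckpt)] if rank == 0 else [None]
+    dist.broadcast_object_list(payload, src=0)
+    scalars = payload[0]  # identical dict on every rank afterwards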
+""" + +import json +import os +import tempfile + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def _worker(rank: int, world_size: int, master_port: int, ckpt_dir: str, out_dir: str): + """Run on each spawned rank: init PG, build a fake inner model, + call _load_layer_scalars, write the resulting scalars to a per-rank + file so the parent process can diff them.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + from slime_plugins.models import gemma4_provider as _provider + import megatron.core.transformer.transformer_layer as tl + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(3): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + _provider._load_layer_scalars( + inner, ckpt_dir, config=type("C", (), {})() + ) + finally: + tl.get_transformer_layer_offset = orig_offset + + loaded = [layer.layer_scalar.item() for layer in inner.decoder.layers] + out_path = os.path.join(out_dir, f"rank{rank}.json") + with open(out_path, "w") as fp: + json.dump({"rank": rank, "scalars": loaded}, fp) + finally: + dist.destroy_process_group() + + +def _write_fake_checkpoint(ckpt_dir: str, scalars: dict[int, float]) -> None: + """Produce a safetensors file per layer + a matching + model.safetensors.index.json. Matches + _read_layer_scalars_from_safetensors' expectations. + """ + from safetensors.torch import save_file + + weight_map = {} + for layer_idx, value in scalars.items(): + tensor_name = f"model.language_model.layers.{layer_idx}.layer_scalar" + fname = f"layer_{layer_idx}.safetensors" + save_file( + {tensor_name: torch.tensor([value], dtype=torch.float32)}, + os.path.join(ckpt_dir, fname), + ) + weight_map[tensor_name] = fname + + with open(os.path.join(ckpt_dir, "model.safetensors.index.json"), "w") as fp: + json.dump({"metadata": {}, "weight_map": weight_map}, fp) + + +def test_layer_scalars_broadcast_to_all_ranks(): + """Rank 0 reads scalars from disk, broadcasts to rank 1 via + broadcast_object_list — both ranks must end up with identical + layer_scalar values.""" + expected = {0: 0.5, 1: 1.25, 2: 2.0} + + with tempfile.TemporaryDirectory() as tmp: + ckpt_dir = os.path.join(tmp, "ckpt") + os.makedirs(ckpt_dir) + _write_fake_checkpoint(ckpt_dir, expected) + + out_dir = os.path.join(tmp, "out") + os.makedirs(out_dir) + # Deterministic port; tests serialized per-module so no race. 
+ master_port = 29577 + + mp.spawn( + _worker, + args=(2, master_port, ckpt_dir, out_dir), + nprocs=2, + join=True, + ) + + with open(os.path.join(out_dir, "rank0.json")) as fp: + r0 = json.load(fp) + with open(os.path.join(out_dir, "rank1.json")) as fp: + r1 = json.load(fp) + + assert r0["rank"] == 0 + assert r1["rank"] == 1 + assert r0["scalars"] == pytest.approx([0.5, 1.25, 2.0]) + assert r1["scalars"] == pytest.approx([0.5, 1.25, 2.0]), ( + "rank 1 did not receive the broadcast scalars — check " + "_broadcast_layer_scalars" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_provider.py b/tests/gemma4/test_gemma4_provider.py new file mode 100644 index 0000000000..3150851bae --- /dev/null +++ b/tests/gemma4/test_gemma4_provider.py @@ -0,0 +1,375 @@ +"""Unit tests for ``gemma4_provider.py`` hooks and helpers. + +Covers: +- ``_install_hooks``: embedding-scale, softcap, dual-RoPE wiring (integration + bits are in ``test_gemma4_dual_rope.py``). +- ``_load_layer_scalars``: reading per-layer scalar buffers from an HF + safetensors checkpoint, with the PP offset translation. + +These are pure-Python/CPU tests. We import the provider module by hand to +avoid Megatron-wide imports that require CUDA / a process group.""" + +import importlib.util +import json +import pathlib +import sys +from types import SimpleNamespace + +import pytest +import torch + + +def _load_provider_module(): + """Import ``gemma4_provider`` without triggering the module-level + ``from megatron.training import get_args``.""" + import types + + if "megatron.training" not in sys.modules: + stub = types.ModuleType("megatron.training") + stub.get_args = lambda: None + sys.modules["megatron.training"] = stub + if "megatron.training.arguments" not in sys.modules: + stub2 = types.ModuleType("megatron.training.arguments") + stub2.core_transformer_config_from_args = lambda *a, **k: None + sys.modules["megatron.training.arguments"] = stub2 + + repo_path = pathlib.Path(__file__).resolve().parents[2] / ( + "slime_plugins/models/gemma4_provider.py" + ) + spec = importlib.util.spec_from_file_location( + "_gemma4_provider_under_test", repo_path + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +_provider = _load_provider_module() + + +# ============================================================================ +# Softcap hook (part of _install_hooks) +# ============================================================================ + +def test_install_hooks_softcap_wraps_tensor_output(): + """With final_logit_softcapping=30, the output_layer forward hook must + transform tensor output `x` → tanh(x / 30) * 30.""" + # Build a minimal `inner` and `args` that _install_hooks recognises. + inner = torch.nn.Module() + inner.output_layer = torch.nn.Linear(4, 8, bias=False) + + hf_text = SimpleNamespace(final_logit_softcapping=30.0) + # Monkey-patch the helper in the provider module so we don't need a real + # HF checkpoint on disk. + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _path: hf_text + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=4) + _provider._install_hooks( + model=inner, args=args, config=config, + pre_process=False, post_process=True, + ) + finally: + _provider._load_hf_text_config = orig + + # Run output_layer forward — the hook should softcap the result. 
We + # compute `raw` manually (the forward hook wraps __call__, so calling + # output_layer(x) returns the softcapped version, not the raw). + x = torch.randn(2, 4) + raw = x @ inner.output_layer.weight.T # same math as Linear but no hook + hooked = inner.output_layer(x) # goes through the hook + expected = torch.tanh(raw / 30.0) * 30.0 + assert torch.allclose(hooked, expected, atol=1e-6), ( + "softcap hook did not apply tanh(x/cap)*cap to the tensor output" + ) + # And the hooked output is in the softcap range. + assert hooked.abs().max().item() <= 30.0 + + +def test_install_hooks_softcap_wraps_tuple_output(): + """Some Megatron layers return (output, bias) tuples. The softcap hook + must only transform the first element and leave the rest untouched.""" + inner = torch.nn.Module() + # Minimal stand-in for Megatron ColumnParallelLinear-style output. + class _TupleOutLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(torch.randn(8, 4)) + + def forward(self, x): + return x @ self.w.T, None # (output, bias) + + inner.output_layer = _TupleOutLayer() + hf_text = SimpleNamespace(final_logit_softcapping=30.0) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _path: hf_text + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=4) + _provider._install_hooks( + model=inner, args=args, config=config, + pre_process=False, post_process=True, + ) + finally: + _provider._load_hf_text_config = orig + + x = torch.randn(3, 4) + hooked, bias = inner.output_layer(x) + raw = x @ inner.output_layer.w.T + expected = torch.tanh(raw / 30.0) * 30.0 + assert torch.allclose(hooked, expected, atol=1e-6) + assert bias is None # tuple tail preserved + + +def test_install_hooks_no_softcap_when_disabled(): + """When final_logit_softcapping is None / 0, no hook is registered.""" + inner = torch.nn.Module() + inner.output_layer = torch.nn.Linear(4, 8, bias=False) + + for cap_value in (None, 0, 0.0): + # Clear any previous hook + for h in list(inner.output_layer._forward_hooks.keys()): + inner.output_layer._forward_hooks.pop(h) + + hf_text = SimpleNamespace(final_logit_softcapping=cap_value) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _p, _t=hf_text: _t + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=4) + _provider._install_hooks( + model=inner, args=args, config=config, + pre_process=False, post_process=True, + ) + finally: + _provider._load_hf_text_config = orig + assert len(inner.output_layer._forward_hooks) == 0, ( + f"softcap hook should not register when cap={cap_value!r}" + ) + + +# ============================================================================ +# Embedding scale hook (part of _install_hooks) +# ============================================================================ + +def _install_embed_hook(inner, hidden): + hf_text = SimpleNamespace(final_logit_softcapping=None) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _path: hf_text + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=hidden) + _provider._install_hooks( + model=inner, args=args, config=config, + pre_process=True, post_process=False, + ) + finally: + _provider._load_hf_text_config = orig + + +def test_install_hooks_embedding_scale_fp32_weight(): + """With fp32 embedding weights, the scale is applied in fp32 — matches + 
``Gemma4TextScaledWordEmbedding.forward = emb * embed_scale.to(weight.dtype)``.""" + hidden = 1024 + inner = torch.nn.Module() + inner.embedding = torch.nn.Embedding(100, hidden) # fp32 by default + _install_embed_hook(inner, hidden) + + ids = torch.tensor([[1, 2, 3]]) + hooked = inner.embedding(ids) + raw = inner.embedding.weight[ids] + expected_scale = torch.tensor(hidden ** 0.5) # fp32 full precision + assert torch.allclose(hooked, raw * expected_scale, atol=1e-6), ( + "embed scale must be applied in fp32 when weight is fp32" + ) + + +def test_install_hooks_embedding_scale_bf16_weight(): + """With bf16 embedding weights, the scale is cast to bf16 before + multiplying — matching HF's ``embed_scale.to(weight.dtype)`` semantics. + This guards against a previous impl that pre-cast the scale to bf16 + regardless of weight dtype.""" + hidden = 1024 + inner = torch.nn.Module() + inner.embedding = torch.nn.Embedding(100, hidden).to(torch.bfloat16) + _install_embed_hook(inner, hidden) + + ids = torch.tensor([[1, 2, 3]]) + hooked = inner.embedding(ids) + raw = inner.embedding.weight[ids] + # Expected: scale cast to bf16 at forward time. + expected_scale = torch.tensor(hidden ** 0.5).to(torch.bfloat16) + assert torch.allclose(hooked, raw * expected_scale, atol=1e-2), ( + "embed scale must be cast to bf16 when weight is bf16" + ) + + +# ============================================================================ +# _load_layer_scalars +# ============================================================================ + +def _write_fake_safetensors_layer_scalars(ckpt_dir, scalars): + """Write a minimal safetensors checkpoint containing only layer_scalar + tensors, plus an index.json so _load_layer_scalars can find them.""" + from safetensors.torch import save_file + weight_map = {} + for layer_idx, value in scalars.items(): + tensor_name = f"model.language_model.layers.{layer_idx}.layer_scalar" + fname = f"layer_{layer_idx}.safetensors" + save_file({tensor_name: torch.tensor(value)}, str(ckpt_dir / fname)) + weight_map[tensor_name] = fname + index = {"metadata": {}, "weight_map": weight_map} + (ckpt_dir / "model.safetensors.index.json").write_text(json.dumps(index)) + + +def test_load_layer_scalars_applies_values_to_layers(tmp_path): + """Confirm scalars from safetensors are copied into layer.layer_scalar.""" + scalars = {0: 0.5, 1: 1.5, 2: 2.5} + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + # Build a minimal "inner" with layer_scalar buffers on 3 layers. + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(3): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + # Stub get_transformer_layer_offset -> 0 (no PP). 
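+    # _load_layer_scalars maps local layer i -> HF layer i + offset, so an
+    # offset of 0 keeps local and checkpoint indices aligned here; the PP
+    # test below exercises offset 10.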
+ import megatron.core.transformer.transformer_layer as tl + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + for i, expected in scalars.items(): + assert inner.decoder.layers[i].layer_scalar.item() == pytest.approx(expected), ( + f"layer {i}: expected {expected}, got {inner.decoder.layers[i].layer_scalar.item()}" + ) + + +def test_load_layer_scalars_respects_pp_offset(tmp_path): + """Under PP, inner.decoder.layers holds only this rank's local subset; + local index i must translate to global HF index i + pp_offset.""" + # HF checkpoint has scalars for layers 10, 11, 12 (e.g., PP rank 1 of 2 + # on a 20-layer model — local layers 0..9 map to global 10..19). + scalars = {10: 0.7, 11: 0.8, 12: 0.9} + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + # Local `inner` has 3 layers representing global 10, 11, 12. + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(3): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + import megatron.core.transformer.transformer_layer as tl + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 10 # PP offset + try: + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + assert inner.decoder.layers[0].layer_scalar.item() == pytest.approx(0.7) + assert inner.decoder.layers[1].layer_scalar.item() == pytest.approx(0.8) + assert inner.decoder.layers[2].layer_scalar.item() == pytest.approx(0.9) + + +def test_load_layer_scalars_raises_by_default_when_missing(tmp_path, monkeypatch): + """By default, a missing layer_scalar for any local layer fails loudly + (wrong scalars materially change activations vs HF). 
This mirrors the + provider's fail-loud posture for checkpoint drift.""" + monkeypatch.delenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", raising=False) + scalars = {0: 0.5} # only layer 0 has a scalar + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(2): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + import megatron.core.transformer.transformer_layer as tl + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + with pytest.raises(KeyError, match="missing in checkpoint"): + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + +def test_load_layer_scalars_defaults_to_one_when_missing_with_opt_in(tmp_path, monkeypatch): + """With GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1, a missing scalar logs a + warning and falls back to the default 1.0.""" + monkeypatch.setenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", "1") + scalars = {0: 0.5} # only layer 0 has a scalar + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(2): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + import megatron.core.transformer.transformer_layer as tl + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + assert inner.decoder.layers[0].layer_scalar.item() == pytest.approx(0.5) + assert inner.decoder.layers[1].layer_scalar.item() == pytest.approx(1.0) + + +def test_load_layer_scalars_raises_when_no_index_file(tmp_path, monkeypatch): + """Missing index.json is fail-loud by default (checkpoint lacks the + layer_scalar tensors we need). The legacy skip-and-warn behavior is + available via GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1.""" + monkeypatch.delenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", raising=False) + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + inner.decoder.layers = torch.nn.ModuleList([torch.nn.Module()]) + inner.decoder.layers[0].register_buffer("layer_scalar", torch.ones(1)) + + # No index.json in tmp_path — read returns None, provider should raise. + with pytest.raises(RuntimeError, match="No layer_scalar weights found"): + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + + +def test_load_layer_scalars_skips_when_no_index_file_with_opt_in(tmp_path, monkeypatch, caplog): + """With opt-in flag, missing index.json degrades to a warning + default.""" + import logging + monkeypatch.setenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", "1") + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + inner.decoder.layers = torch.nn.ModuleList([torch.nn.Module()]) + inner.decoder.layers[0].register_buffer("layer_scalar", torch.ones(1)) + + # No index.json in tmp_path. + with caplog.at_level(logging.WARNING, logger=_provider.__name__): + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + # Scalar unchanged. 
+ assert inner.decoder.layers[0].layer_scalar.item() == 1.0 + assert any("No safetensors index" in r.message for r in caplog.records) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_qkv_roundtrip.py b/tests/gemma4/test_gemma4_qkv_roundtrip.py new file mode 100644 index 0000000000..b0bcb6a002 --- /dev/null +++ b/tests/gemma4/test_gemma4_qkv_roundtrip.py @@ -0,0 +1,210 @@ +"""HF ↔ Mcore QKV weight conversion roundtrip tests. + +We already have forward-only tests for ``convert_gemma4_to_hf`` in +``test_gemma4_bridge.py``. This file closes the loop: for every attention +layer type (sliding local, global K=V), start from HF ``[q, k, v]`` tensors, +pack into the Mcore ``linear_qkv.weight`` layout via ``Gemma4Bridge. +_weight_to_mcore_format``, unpack via ``convert_gemma4_to_hf``, and confirm +bit-identity with the originals. + +Guards against silent drift in either direction of the converter as we +evolve Gemma4 support. +""" + +import importlib +import importlib.util +import pathlib +from types import SimpleNamespace + +import pytest +import torch + +from slime_plugins.mbridge.gemma4 import Gemma4Bridge + + +def _load_convert_module(): + try: + return importlib.import_module("slime.backends.megatron_utils.megatron_to_hf.gemma4") + except ImportError: + pass + repo_path = pathlib.Path(__file__).resolve().parents[2] / ( + "slime/backends/megatron_utils/megatron_to_hf/gemma4.py" + ) + if not repo_path.exists(): + pytest.skip(f"convert module not found at {repo_path}") + spec = importlib.util.spec_from_file_location("_gemma4_conv_rt", repo_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# Gemma4-31B canonical config. +CFG_31B = SimpleNamespace( + hidden_size=5376, + num_attention_heads=32, + head_dim=256, + num_key_value_heads=16, + global_head_dim=512, + num_global_key_value_heads=4, + num_hidden_layers=60, + attention_k_eq_v=True, + # layer_types=["sliding_attention"] * 5 + ["full_attention"], repeated. + layer_types=(["sliding_attention"] * 5 + ["full_attention"]) * 10, +) +_GLOBAL_LAYERS_31B = {i for i, t in enumerate(CFG_31B.layer_types) if t == "full_attention"} + + +def _build_bridge_stub(cfg): + """Make a Gemma4Bridge instance without going through its real __init__ + (which expects an HF checkpoint on disk). We set just enough attributes + for ``_weight_to_mcore_format`` to run.""" + b = object.__new__(Gemma4Bridge) + b._GLOBAL_ATTN_LAYERS = { + i for i, t in enumerate(cfg.layer_types) if t == "full_attention" + } + b.hf_config = SimpleNamespace(text_config=cfg) + return b + + +def _prime_convert_config(conv): + conv._config_cache["config"] = { + "global_attn_layers": _GLOBAL_LAYERS_31B, + "local_head_dim": CFG_31B.head_dim, + "global_head_dim": CFG_31B.global_head_dim, + "num_attention_heads": CFG_31B.num_attention_heads, + "local_num_kv_heads": CFG_31B.num_key_value_heads, + "global_num_kv_heads": CFG_31B.num_global_key_value_heads, + "hidden_size": CFG_31B.hidden_size, + } + + +def test_sliding_layer_qkv_roundtrip(): + """Start from HF [q, k, v] → pack into Mcore layout → unpack → same tensors.""" + torch.manual_seed(0) + conv = _load_convert_module() + _prime_convert_config(conv) + bridge = _build_bridge_stub(CFG_31B) + + # Layer 0 is sliding in our config. 
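+    # Per Megatron's GQA packing convention, each KV group contributes
+    # q_per_kv query rows plus one K and one V row to linear_qkv, all
+    # concatenated along dim 0 — which the packed-shape assert below reflects.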
+ layer_idx = 0 + q = torch.randn(CFG_31B.num_attention_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + k = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + v = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + + mcore_name = f"decoder.layers.{layer_idx}.self_attention.linear_qkv.weight" + packed = bridge._weight_to_mcore_format(mcore_name, [q, k, v]) + assert packed.shape == ( + CFG_31B.num_attention_heads * CFG_31B.head_dim + + 2 * CFG_31B.num_key_value_heads * CFG_31B.head_dim, + CFG_31B.hidden_size, + ) + + args = SimpleNamespace(hf_checkpoint="/nonexistent") + emitted = conv.convert_gemma4_to_hf( + args, f"module.module.{mcore_name}", packed, + ) + out = dict(emitted) + assert set(out) == { + f"model.language_model.layers.{layer_idx}.self_attn.q_proj.weight", + f"model.language_model.layers.{layer_idx}.self_attn.k_proj.weight", + f"model.language_model.layers.{layer_idx}.self_attn.v_proj.weight", + } + assert torch.allclose( + out[f"model.language_model.layers.{layer_idx}.self_attn.q_proj.weight"], q + ) + assert torch.allclose( + out[f"model.language_model.layers.{layer_idx}.self_attn.k_proj.weight"], k + ) + assert torch.allclose( + out[f"model.language_model.layers.{layer_idx}.self_attn.v_proj.weight"], v + ) + + +def test_global_k_eq_v_layer_qkv_roundtrip(): + """K=V global layer: HF ships [q, k] only. The bridge reconstructs V + from K during pack; convert_gemma4_to_hf unpacks and emits [q, k] only + (no v_proj). Roundtrip must still recover q and k bit-identically.""" + torch.manual_seed(1) + conv = _load_convert_module() + _prime_convert_config(conv) + bridge = _build_bridge_stub(CFG_31B) + + layer_idx = 5 # global + assert layer_idx in _GLOBAL_LAYERS_31B + + q = torch.randn(CFG_31B.num_attention_heads * CFG_31B.global_head_dim, CFG_31B.hidden_size) + k = torch.randn(CFG_31B.num_global_key_value_heads * CFG_31B.global_head_dim, CFG_31B.hidden_size) + + mcore_name = f"decoder.layers.{layer_idx}.self_attention.linear_qkv.weight" + packed = bridge._weight_to_mcore_format(mcore_name, [q, k]) + # Packed shape: num_kv * (q_per_kv + 2) * head_dim rows, with V=K slot. + q_per_kv = CFG_31B.num_attention_heads // CFG_31B.num_global_key_value_heads + expected_rows = ( + CFG_31B.num_global_key_value_heads + * (q_per_kv + 2) + * CFG_31B.global_head_dim + ) + assert packed.shape == (expected_rows, CFG_31B.hidden_size) + + args = SimpleNamespace(hf_checkpoint="/nonexistent") + emitted = conv.convert_gemma4_to_hf( + args, f"module.module.{mcore_name}", packed, + ) + out = dict(emitted) + # Global K=V: only q and k are emitted. + assert set(out) == { + f"model.language_model.layers.{layer_idx}.self_attn.q_proj.weight", + f"model.language_model.layers.{layer_idx}.self_attn.k_proj.weight", + } + assert torch.allclose( + out[f"model.language_model.layers.{layer_idx}.self_attn.q_proj.weight"], q + ) + assert torch.allclose( + out[f"model.language_model.layers.{layer_idx}.self_attn.k_proj.weight"], k + ) + + +def test_sliding_layer_roundtrip_rejects_wrong_shape(): + """The new shape assertions (B6/B8) fire when q_proj rows don't match + num_kv_heads * group_dim — a common symptom of a miscounted config.""" + bridge = _build_bridge_stub(CFG_31B) + + # Wrong: use GLOBAL head_dim on a SLIDING layer → q rows won't match. 
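+    # Concretely: a sliding layer expects q rows = 32 heads * 256 head_dim
+    # = 8192; building q with global_head_dim instead gives 32 * 512 =
+    # 16384 rows, so the packer's q_proj row-count assert must fire.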
+ q_bad = torch.randn(CFG_31B.num_attention_heads * CFG_31B.global_head_dim, CFG_31B.hidden_size) + k_bad = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + v_bad = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size) + + with pytest.raises(AssertionError, match="q_proj rows"): + bridge._weight_to_mcore_format( + "decoder.layers.0.self_attention.linear_qkv.weight", + [q_bad, k_bad, v_bad], + ) + + +def test_mlp_fc1_asserts_wrong_count(): + """linear_fc1.weight expects exactly [gate_proj, up_proj] from HF. Passing + 3 tensors hits the assert (passing 1 is short-circuited to a pass-through + earlier in _weight_to_mcore_format for unrelated single-tensor cases).""" + bridge = _build_bridge_stub(CFG_31B) + with pytest.raises(AssertionError, match="linear_fc1.weight expects"): + bridge._weight_to_mcore_format( + "decoder.layers.0.mlp.linear_fc1.weight", + [torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4)], + ) + + +def test_mlp_fc1_pack_concatenates_gate_up(): + """linear_fc1.weight packs HF [gate_proj, up_proj] → [gate; up] along dim 0.""" + bridge = _build_bridge_stub(CFG_31B) + gate = torch.randn(CFG_31B.hidden_size, CFG_31B.hidden_size) + up = torch.randn(CFG_31B.hidden_size, CFG_31B.hidden_size) + packed = bridge._weight_to_mcore_format( + "decoder.layers.0.mlp.linear_fc1.weight", [gate, up], + ) + assert packed.shape == (2 * CFG_31B.hidden_size, CFG_31B.hidden_size) + assert torch.equal(packed[:CFG_31B.hidden_size], gate) + assert torch.equal(packed[CFG_31B.hidden_size:], up) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gemma4/test_gemma4_router.py b/tests/gemma4/test_gemma4_router.py new file mode 100644 index 0000000000..3f89b40e05 --- /dev/null +++ b/tests/gemma4/test_gemma4_router.py @@ -0,0 +1,168 @@ +"""Unit tests for ``Gemma4Router`` and the ``Gemma4MoELayer.route`` adapter. + +These are pure-Python/CPU tests that exercise the routing arithmetic without +going through Megatron's MoE dispatch infrastructure. +""" + +from types import SimpleNamespace + +import pytest +import torch + +from slime_plugins.models.gemma4 import Gemma4MoELayer, Gemma4Router + + +def _make_router_config(hidden_size=16, num_experts=8, top_k=2, eps=1e-6): + return SimpleNamespace( + hidden_size=hidden_size, + num_moe_experts=num_experts, + moe_router_topk=top_k, + layernorm_epsilon=eps, + ) + + +def test_router_outputs_have_correct_shapes(): + torch.manual_seed(0) + cfg = _make_router_config(num_experts=8, top_k=2) + router = Gemma4Router(cfg) + h = torch.randn(5, cfg.hidden_size) + weights, idx = router(h) + assert weights.shape == (5, cfg.moe_router_topk) + assert idx.shape == (5, cfg.moe_router_topk) + assert idx.min() >= 0 and idx.max() < cfg.num_moe_experts + + +def test_router_weights_sum_to_one_before_per_expert_scale(): + """With per_expert_scale all ones, top-k weights must sum to 1.0 per + token — the router normalises internally.""" + torch.manual_seed(1) + cfg = _make_router_config(num_experts=8, top_k=3) + router = Gemma4Router(cfg) + # Leave per_expert_scale as its default (all ones). 
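+    # Why exactly 1.0: per the router equation (see _hf_reference_router
+    # below), the top-k softmax probabilities are renormalized by their own
+    # sum before the per-expert multiplier, so a unit per_expert_scale
+    # leaves each token's weights summing to 1 no matter how much
+    # probability mass fell outside the top-k.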
+ h = torch.randn(6, cfg.hidden_size) + weights, _idx = router(h) + sums = weights.sum(dim=-1) + assert torch.allclose(sums, torch.ones_like(sums), atol=1e-6) + + +def test_router_per_expert_scale_multiplies_output(): + """Setting ``per_expert_scale`` to a constant c should scale weights by c.""" + torch.manual_seed(2) + cfg = _make_router_config(num_experts=4, top_k=2) + router = Gemma4Router(cfg) + # Fix per_expert_scale = 3.0 for all experts. + with torch.no_grad(): + router.per_expert_scale.fill_(3.0) + h = torch.randn(4, cfg.hidden_size) + weights, _idx = router(h) + sums = weights.sum(dim=-1) + assert torch.allclose(sums, torch.full_like(sums, 3.0), atol=1e-6) + + +def _make_moe_route_stub(): + """Build a Gemma4MoELayer stub that only exposes the ``route`` method. + + We bypass ``__init__`` (which builds a full Megatron MoE dispatcher) + and populate only the attributes read by ``route``: ``self.router`` + and ``self.config.num_moe_experts``.""" + obj = object.__new__(Gemma4MoELayer) + torch.nn.Module.__init__(obj) + cfg = _make_router_config(num_experts=6, top_k=2) + obj.router = Gemma4Router(cfg) + obj.config = cfg + return obj, cfg + + +def test_moe_route_packs_topk_into_dense_probs_and_routing_map(): + """route() must produce (probs [T, E], routing_map [T, E]) from the + compact (top_k_weights [T, K], top_k_index [T, K]) router output.""" + torch.manual_seed(3) + obj, cfg = _make_moe_route_stub() + h = torch.randn(4, cfg.hidden_size) + probs, routing_map = obj.route(h) + + T, E = 4, cfg.num_moe_experts + assert probs.shape == (T, E) + assert routing_map.shape == (T, E) + assert routing_map.dtype == torch.bool + + # Each row has exactly top_k nonzero entries; routing_map matches. + assert (probs != 0).sum(dim=-1).eq(cfg.moe_router_topk).all() + assert routing_map.eq(probs != 0).all() + + # Probs sum to the same total per row as the compact top-k weights (with + # default per_expert_scale=1, that's 1.0 per row). + expected_sums = probs.sum(dim=-1) + assert torch.allclose(expected_sums, torch.ones(T), atol=1e-6) + + +def test_moe_route_accepts_3d_input_by_flattening(): + """route() must flatten [S, B, H] (or any prefix dims) to [T, H] before + routing, so it works both in thd (2D) and [seq, batch, hidden] (3D) + layouts. Only the output's leading dimension is exercised here.""" + torch.manual_seed(4) + obj, cfg = _make_moe_route_stub() + h = torch.randn(3, 2, cfg.hidden_size) # [S, B, H] + probs, routing_map = obj.route(h) + # T = 3 * 2 = 6 + assert probs.shape == (6, cfg.num_moe_experts) + assert routing_map.shape == (6, cfg.num_moe_experts) + + +def _hf_reference_router(h, proj_w, scale, per_expert_scale, top_k, eps=1e-6): + """Reference implementation of the HF Gemma4 router equation: + + h_norm = rmsnorm_noscale(h) # no-learnable-scale RMSNorm + h_norm2 = h_norm * scale / sqrt(H) # per-hidden learnable scale + logits = proj_w @ h_norm2 # [T, E] + probs = softmax(logits) + top_w, top_i = topk(probs, k=top_k) + top_w = top_w / sum(top_w) # renormalize + top_w = top_w * per_expert_scale[top_i] # per-expert scale multiplier + + This closes the loop on what Gemma4Router computes: exercises every step + (RMSNorm without scale, per-hidden scale, proj, softmax, topk, renormalise, + per-expert scale) and guards against silent reordering of those ops in + future refactors. + """ + # RMSNorm (no scale), float-precision to match Gemma4Router.VNorm. 
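+    # Two shorthand clarifications vs the docstring above: (1) the
+    # pow(..., -0.5) below is 1/sqrt(mean(h^2) + eps), i.e. scale-free
+    # RMSNorm; (2) "logits = proj_w @ h_norm2" is computed row-per-token as
+    # F.linear(h_norm2, proj_w) = h_norm2 @ proj_w.T, with proj_w [E, H].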
+    h = h.float()
+    norm = h * torch.pow(h.pow(2).mean(-1, keepdim=True) + eps, -0.5)
+    h_norm2 = norm * scale * (h.shape[-1] ** -0.5)
+    logits = torch.nn.functional.linear(h_norm2, proj_w)
+    probs = torch.softmax(logits, dim=-1)
+    top_w, top_i = torch.topk(probs, k=top_k, dim=-1)
+    top_w = top_w / top_w.sum(dim=-1, keepdim=True)
+    top_w = top_w * per_expert_scale[top_i]
+    return top_w, top_i
+
+
+def test_router_matches_hf_reference_equation():
+    """Gemma4Router.forward must match the HF reference router: top-k
+    indices exactly, top-k weights to within float32 kernel noise. This
+    covers the full router equation."""
+    torch.manual_seed(42)
+    cfg = _make_router_config(hidden_size=32, num_experts=8, top_k=2)
+    router = Gemma4Router(cfg)
+    # Use realistic, non-trivial weights.
+    with torch.no_grad():
+        router.scale.copy_(torch.randn(cfg.hidden_size) * 0.1 + 1.0)
+        router.per_expert_scale.copy_(torch.randn(cfg.num_moe_experts) * 0.2 + 1.0)
+
+    h = torch.randn(5, cfg.hidden_size)
+    w, idx = router(h)
+    w_ref, idx_ref = _hf_reference_router(
+        h, router.proj.weight, router.scale, router.per_expert_scale,
+        cfg.moe_router_topk, eps=cfg.layernorm_epsilon,
+    )
+
+    # topk indices must match exactly.
+    assert torch.equal(idx, idx_ref), (
+        f"router top-k indices diverge: ours={idx}, ref={idx_ref}"
+    )
+    assert torch.allclose(w.float(), w_ref, atol=1e-5), (
+        "router top-k weights diverge from HF reference"
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
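+
+# Each file in tests/gemma4/ carries the same __main__ hook, so the suites
+# can be run either directly (python test_gemma4_router.py) or together via
+# `pytest tests/gemma4/ -v`.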