2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
@@ -450,7 +450,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        info: [{"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
        info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
    defaults:
      run:
        working-directory: ${{ github.workspace }}
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
@@ -54,6 +54,7 @@
        'always': True,
        'cpu': True,
        'tests': [
            {'test_file': 'test_megatron_argument_validation.py', 'num_gpus': 0},
            {'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0},
            {'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0},
            {'test_file': 'plugin_contracts/test_plugin_path_loading_contracts.py', 'num_gpus': 0},
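The committed pr-test.yml appears to be generated from the Jinja template pr-test.yml.j2, which is why the new CPU-only test file is added in both places. As a rough illustration (an assumption about the template's rendering, not code from this PR), the rendered `info:` matrix line is essentially the JSON serialization of the template's tests list:

import json

# Hypothetical sketch: how the 'tests' entries above could end up as the `info:` matrix in pr-test.yml.
tests = [
    {"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"},
    {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"},
]
print("info:", json.dumps(tests))
# info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, ...]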
33 changes: 33 additions & 0 deletions slime/backends/megatron_utils/arguments.py
@@ -11,6 +11,35 @@
logger = logging.getLogger(__name__)


_ALLGATHER_CP_DSA_ARCHITECTURES = {
    "DeepseekV32ForCausalLM",
    "GlmMoeDsaForCausalLM",
}


def _is_allgather_cp_dsa_model(hf_config):
    if hf_config is None:
        return False

    architecture_names = getattr(hf_config, "architectures", None) or []
    return any(name in _ALLGATHER_CP_DSA_ARCHITECTURES for name in architecture_names)


def _validate_allgather_cp_supported(args, hf_config=None):
    if not getattr(args, "allgather_cp", False) or getattr(args, "context_parallel_size", 1) <= 1:
        return

    if _is_allgather_cp_dsa_model(hf_config):
        return

    raise ValueError(
        "--allgather-cp with --context-parallel-size > 1 is currently only supported for "
        "DSA attention models (DeepSeek-V3.2 and GLM-5.1). Non-DSA models still use the "
        "zigzag CP layout and would silently scramble token order under allgather CP. "
        "Please remove --allgather-cp, set --context-parallel-size 1, or use a supported DSA model."
    )


def _has_dense_moe_layers(args):
    moe_layer_freq = getattr(args, "moe_layer_freq", None)
    if moe_layer_freq is None:
@@ -151,10 +180,14 @@ def megatron_parse_args(extra_args_provider, skip_hf_validate=False):
"""Parse megatron args, validate HF config, and set defaults."""
args = _megatron_parse_args(extra_args_provider=extra_args_provider, ignore_unknown_args=True)

hf_config = None
if args.hf_checkpoint and not skip_hf_validate:
hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
_hf_validate_args(args, hf_config)

if not skip_hf_validate:
_validate_allgather_cp_supported(args, hf_config)

args.rank = 0
args.world_size = args.actor_num_nodes * args.actor_num_gpus_per_node
args = _set_default_megatron_args(args)
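For readers skimming the diff, a minimal usage sketch of the new check (not part of the PR; it assumes slime.backends.megatron_utils.arguments is importable, i.e. megatron and transformers are available — the unit tests below instead stub those imports out):

import types

from slime.backends.megatron_utils.arguments import _validate_allgather_cp_supported

args = types.SimpleNamespace(allgather_cp=True, context_parallel_size=2)

# DSA architectures pass through untouched.
_validate_allgather_cp_supported(args, types.SimpleNamespace(architectures=["DeepseekV32ForCausalLM"]))

# Non-DSA models under allgather CP raise the ValueError shown above.
try:
    _validate_allgather_cp_supported(args, types.SimpleNamespace(architectures=["Qwen3ForCausalLM"]))
except ValueError as err:
    print(err)

# With context_parallel_size == 1 the check is a no-op, even without an HF config.
_validate_allgather_cp_supported(types.SimpleNamespace(allgather_cp=True, context_parallel_size=1))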
140 changes: 140 additions & 0 deletions tests/test_megatron_argument_validation.py
@@ -0,0 +1,140 @@
import importlib.util
import sys
import types
from pathlib import Path

import pytest


def load_arguments_module(monkeypatch):
    megatron_mod = types.ModuleType("megatron")
    training_mod = types.ModuleType("megatron.training")
    arguments_mod = types.ModuleType("megatron.training.arguments")
    tokenizer_pkg_mod = types.ModuleType("megatron.training.tokenizer")
    tokenizer_mod = types.ModuleType("megatron.training.tokenizer.tokenizer")
    transformers_mod = types.ModuleType("transformers")

    arguments_mod.parse_args = lambda *args, **kwargs: None
    arguments_mod.validate_args = lambda args: args
    tokenizer_mod._vocab_size_with_padding = lambda vocab_size, _args: vocab_size
    transformers_mod.AutoConfig = types.SimpleNamespace(from_pretrained=lambda *args, **kwargs: None)

    monkeypatch.setitem(sys.modules, "megatron", megatron_mod)
    monkeypatch.setitem(sys.modules, "megatron.training", training_mod)
    monkeypatch.setitem(sys.modules, "megatron.training.arguments", arguments_mod)
    monkeypatch.setitem(sys.modules, "megatron.training.tokenizer", tokenizer_pkg_mod)
    monkeypatch.setitem(sys.modules, "megatron.training.tokenizer.tokenizer", tokenizer_mod)
    monkeypatch.setitem(sys.modules, "transformers", transformers_mod)

    module_path = Path(__file__).resolve().parents[1] / "slime" / "backends" / "megatron_utils" / "arguments.py"
    module_name = "test_megatron_argument_validation_module"
    sys.modules.pop(module_name, None)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module


def make_qwen3_6_args(**overrides):
    values = dict(
        hidden_size=2048,
        num_attention_heads=16,
        num_layers=40,
        ffn_hidden_size=512,
        moe_ffn_hidden_size=512,
        moe_shared_expert_intermediate_size=512,
        moe_layer_freq=[1] * 40,
        untie_embeddings_and_output_weights=True,
        norm_epsilon=1e-6,
        layernorm_epsilon=1e-6,
        rotary_base=10000000,
    )
    values.update(overrides)
    return types.SimpleNamespace(**values)


def make_qwen3_6_hf_config():
    text_config = types.SimpleNamespace(
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=40,
        intermediate_size=5632,
        moe_intermediate_size=512,
        shared_expert_intermediate_size=512,
        num_experts=256,
        tie_word_embeddings=False,
        rms_norm_eps=1e-6,
        rope_parameters={"rope_theta": 10000000},
    )
    return types.SimpleNamespace(text_config=text_config)


def make_allgather_cp_args(**overrides):
    values = dict(
        allgather_cp=True,
        context_parallel_size=2,
    )
    values.update(overrides)
    return types.SimpleNamespace(**values)


@pytest.mark.unit
def test_hf_validate_all_moe_skips_dense_intermediate_size(monkeypatch):
    module = load_arguments_module(monkeypatch)

    module._hf_validate_args(make_qwen3_6_args(), make_qwen3_6_hf_config())


@pytest.mark.unit
def test_hf_validate_checks_moe_intermediate_size(monkeypatch):
    module = load_arguments_module(monkeypatch)

    with pytest.raises(AssertionError, match="moe_intermediate_size"):
        module._hf_validate_args(make_qwen3_6_args(moe_ffn_hidden_size=256), make_qwen3_6_hf_config())


@pytest.mark.unit
def test_hf_validate_checks_dense_intermediate_size_when_moe_has_dense_layers(monkeypatch):
    module = load_arguments_module(monkeypatch)

    args = make_qwen3_6_args(moe_layer_freq=[0] + [1] * 39)

    with pytest.raises(AssertionError, match="intermediate_size"):
        module._hf_validate_args(args, make_qwen3_6_hf_config())


@pytest.mark.unit
def test_allgather_cp_rejects_non_dsa_cp_models(monkeypatch):
    module = load_arguments_module(monkeypatch)
    args = make_allgather_cp_args()
    hf_config = types.SimpleNamespace(architectures=["Qwen3ForCausalLM"], model_type="qwen3")

    with pytest.raises(ValueError, match="only supported for DSA attention models"):
        module._validate_allgather_cp_supported(args, hf_config)


@pytest.mark.unit
@pytest.mark.parametrize(
    "hf_config",
    [
        types.SimpleNamespace(architectures=["DeepseekV32ForCausalLM"], model_type="deepseek_v3"),
        types.SimpleNamespace(architectures=["GlmMoeDsaForCausalLM"], model_type="glm"),
    ],
)
def test_allgather_cp_allows_dsa_architectures(monkeypatch, hf_config):
    module = load_arguments_module(monkeypatch)

    module._validate_allgather_cp_supported(make_allgather_cp_args(), hf_config)


@pytest.mark.unit
def test_allgather_cp_ignores_cp_size_one(monkeypatch):
    module = load_arguments_module(monkeypatch)
    args = make_allgather_cp_args(context_parallel_size=1)

    module._validate_allgather_cp_supported(args)


if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))
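Because every heavyweight dependency (megatron, transformers) is stubbed out in load_arguments_module, this suite needs no GPU and should run on the CPU-only CI shard added above; locally it can be invoked either through pytest or directly via the __main__ guard (python tests/test_megatron_argument_validation.py).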