diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index fc78dd348..8cb815c1f 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -450,7 +450,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        info: [{"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
+        info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
     defaults:
       run:
         working-directory: ${{ github.workspace }}
diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2
index 509a3d19d..50fe47227 100644
--- a/.github/workflows/pr-test.yml.j2
+++ b/.github/workflows/pr-test.yml.j2
@@ -54,6 +54,7 @@
         'always': True,
         'cpu': True,
         'tests': [
+            {'test_file': 'test_megatron_argument_validation.py', 'num_gpus': 0},
             {'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0},
             {'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0},
             {'test_file': 'plugin_contracts/test_plugin_path_loading_contracts.py', 'num_gpus': 0},
diff --git a/slime/backends/megatron_utils/arguments.py b/slime/backends/megatron_utils/arguments.py
index b93198689..ab3f9efae 100644
--- a/slime/backends/megatron_utils/arguments.py
+++ b/slime/backends/megatron_utils/arguments.py
@@ -11,6 +11,35 @@
 logger = logging.getLogger(__name__)
 
 
+_ALLGATHER_CP_DSA_ARCHITECTURES = {
+    "DeepseekV32ForCausalLM",
+    "GlmMoeDsaForCausalLM",
+}
+
+
+def _is_allgather_cp_dsa_model(hf_config):
+    if hf_config is None:
+        return False
+
+    architecture_names = getattr(hf_config, "architectures", None) or []
+    return any(name in _ALLGATHER_CP_DSA_ARCHITECTURES for name in architecture_names)
+
+
+def _validate_allgather_cp_supported(args, hf_config=None):
+    if not getattr(args, "allgather_cp", False) or getattr(args, "context_parallel_size", 1) <= 1:
+        return
+
+    if _is_allgather_cp_dsa_model(hf_config):
+        return
+
+    raise ValueError(
+        "--allgather-cp with --context-parallel-size > 1 is currently only supported for "
+        "DSA attention models (DeepSeek-V3.2 and GLM-5.1). Non-DSA models still use the "
+        "zigzag CP layout and would silently scramble token order under allgather CP. "
+        "Please remove --allgather-cp, set --context-parallel-size 1, or use a supported DSA model."
+    )
+
+
 def _has_dense_moe_layers(args):
     moe_layer_freq = getattr(args, "moe_layer_freq", None)
     if moe_layer_freq is None:
@@ -151,10 +180,14 @@ def megatron_parse_args(extra_args_provider, skip_hf_validate=False):
     """Parse megatron args, validate HF config, and set defaults."""
     args = _megatron_parse_args(extra_args_provider=extra_args_provider, ignore_unknown_args=True)
 
+    hf_config = None
     if args.hf_checkpoint and not skip_hf_validate:
         hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
         _hf_validate_args(args, hf_config)
 
+    if not skip_hf_validate:
+        _validate_allgather_cp_supported(args, hf_config)
+
     args.rank = 0
     args.world_size = args.actor_num_nodes * args.actor_num_gpus_per_node
     args = _set_default_megatron_args(args)
diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py
new file mode 100644
index 000000000..83e8390ee
--- /dev/null
+++ b/tests/test_megatron_argument_validation.py
@@ -0,0 +1,140 @@
+import importlib.util
+import sys
+import types
+from pathlib import Path
+
+import pytest
+
+
+def load_arguments_module(monkeypatch):
+    megatron_mod = types.ModuleType("megatron")
+    training_mod = types.ModuleType("megatron.training")
+    arguments_mod = types.ModuleType("megatron.training.arguments")
+    tokenizer_pkg_mod = types.ModuleType("megatron.training.tokenizer")
+    tokenizer_mod = types.ModuleType("megatron.training.tokenizer.tokenizer")
+    transformers_mod = types.ModuleType("transformers")
+
+    arguments_mod.parse_args = lambda *args, **kwargs: None
+    arguments_mod.validate_args = lambda args: args
+    tokenizer_mod._vocab_size_with_padding = lambda vocab_size, _args: vocab_size
+    transformers_mod.AutoConfig = types.SimpleNamespace(from_pretrained=lambda *args, **kwargs: None)
+
+    monkeypatch.setitem(sys.modules, "megatron", megatron_mod)
+    monkeypatch.setitem(sys.modules, "megatron.training", training_mod)
+    monkeypatch.setitem(sys.modules, "megatron.training.arguments", arguments_mod)
+    monkeypatch.setitem(sys.modules, "megatron.training.tokenizer", tokenizer_pkg_mod)
+    monkeypatch.setitem(sys.modules, "megatron.training.tokenizer.tokenizer", tokenizer_mod)
+    monkeypatch.setitem(sys.modules, "transformers", transformers_mod)
+
+    module_path = Path(__file__).resolve().parents[1] / "slime" / "backends" / "megatron_utils" / "arguments.py"
+    module_name = "test_megatron_argument_validation_module"
+    sys.modules.pop(module_name, None)
+    spec = importlib.util.spec_from_file_location(module_name, module_path)
+    module = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    spec.loader.exec_module(module)
+    return module
+
+
+def make_qwen3_6_args(**overrides):
+    values = dict(
+        hidden_size=2048,
+        num_attention_heads=16,
+        num_layers=40,
+        ffn_hidden_size=512,
+        moe_ffn_hidden_size=512,
+        moe_shared_expert_intermediate_size=512,
+        moe_layer_freq=[1] * 40,
+        untie_embeddings_and_output_weights=True,
+        norm_epsilon=1e-6,
+        layernorm_epsilon=1e-6,
+        rotary_base=10000000,
+    )
+    values.update(overrides)
+    return types.SimpleNamespace(**values)
+
+
+def make_qwen3_6_hf_config():
+    text_config = types.SimpleNamespace(
+        hidden_size=2048,
+        num_attention_heads=16,
+        num_hidden_layers=40,
+        intermediate_size=5632,
+        moe_intermediate_size=512,
+        shared_expert_intermediate_size=512,
+        num_experts=256,
+        tie_word_embeddings=False,
+        rms_norm_eps=1e-6,
+        rope_parameters={"rope_theta": 10000000},
+    )
+    return types.SimpleNamespace(text_config=text_config)
+
+
+def make_allgather_cp_args(**overrides):
+    values = dict(
+        allgather_cp=True,
+        context_parallel_size=2,
+    )
+    values.update(overrides)
+    return types.SimpleNamespace(**values)
+
+
+@pytest.mark.unit
+def test_hf_validate_all_moe_skips_dense_intermediate_size(monkeypatch):
+    module = load_arguments_module(monkeypatch)
+
+    module._hf_validate_args(make_qwen3_6_args(), make_qwen3_6_hf_config())
+
+
+@pytest.mark.unit
+def test_hf_validate_checks_moe_intermediate_size(monkeypatch):
+    module = load_arguments_module(monkeypatch)
+
+    with pytest.raises(AssertionError, match="moe_intermediate_size"):
+        module._hf_validate_args(make_qwen3_6_args(moe_ffn_hidden_size=256), make_qwen3_6_hf_config())
+
+
+@pytest.mark.unit
+def test_hf_validate_checks_dense_intermediate_size_when_moe_has_dense_layers(monkeypatch):
+    module = load_arguments_module(monkeypatch)
+
+    args = make_qwen3_6_args(moe_layer_freq=[0] + [1] * 39)
+
+    with pytest.raises(AssertionError, match="intermediate_size"):
+        module._hf_validate_args(args, make_qwen3_6_hf_config())
+
+
+@pytest.mark.unit
+def test_allgather_cp_rejects_non_dsa_cp_models(monkeypatch):
+    module = load_arguments_module(monkeypatch)
+    args = make_allgather_cp_args()
+    hf_config = types.SimpleNamespace(architectures=["Qwen3ForCausalLM"], model_type="qwen3")
+
+    with pytest.raises(ValueError, match="only supported for DSA attention models"):
+        module._validate_allgather_cp_supported(args, hf_config)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize(
+    "hf_config",
+    [
+        types.SimpleNamespace(architectures=["DeepseekV32ForCausalLM"], model_type="deepseek_v3"),
+        types.SimpleNamespace(architectures=["GlmMoeDsaForCausalLM"], model_type="glm"),
+    ],
+)
+def test_allgather_cp_allows_dsa_architectures(monkeypatch, hf_config):
+    module = load_arguments_module(monkeypatch)
+
+    module._validate_allgather_cp_supported(make_allgather_cp_args(), hf_config)
+
+
+@pytest.mark.unit
+def test_allgather_cp_ignores_cp_size_one(monkeypatch):
+    module = load_arguments_module(monkeypatch)
+    args = make_allgather_cp_args(context_parallel_size=1)
+
+    module._validate_allgather_cp_supported(args)
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__]))
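
As a quick sanity check (not part of the diff): a minimal sketch of the new guard's decision table, assuming the megatron and transformers imports at the top of arguments.py resolve in your environment (the tests above stub them out instead of importing the real packages). The SimpleNamespace stand-ins mirror the test fixtures.

    import types

    from slime.backends.megatron_utils.arguments import _validate_allgather_cp_supported

    # allgather CP with context_parallel_size > 1: only DSA architectures pass.
    args = types.SimpleNamespace(allgather_cp=True, context_parallel_size=2)
    dsa = types.SimpleNamespace(architectures=["DeepseekV32ForCausalLM"])
    non_dsa = types.SimpleNamespace(architectures=["Qwen3ForCausalLM"])

    _validate_allgather_cp_supported(args, dsa)  # returns silently
    try:
        _validate_allgather_cp_supported(args, non_dsa)
    except ValueError as err:
        print(err)  # "--allgather-cp with --context-parallel-size > 1 is currently only supported ..."

    # With context_parallel_size <= 1 (or allgather_cp unset) the guard is a
    # no-op, even when no hf_config is available.
    _validate_allgather_cp_supported(types.SimpleNamespace(allgather_cp=True, context_parallel_size=1))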