38 changes: 38 additions & 0 deletions slime/backends/megatron_utils/arguments.py
@@ -1,4 +1,6 @@
import ast
import logging

from megatron.training.arguments import parse_args as _megatron_parse_args
from megatron.training.arguments import validate_args as _megatron_validate_args
from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
@@ -9,6 +11,35 @@
logger = logging.getLogger(__name__)


def _has_dense_moe_layers(args):
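    """Return True if the ``moe_layer_freq`` pattern implies at least one dense (frequency 0) layer.

    A missing pattern is treated as dense; string patterns are parsed with ``ast.literal_eval``
    and fall back to a substring check when they are not a plain Python literal.
    """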
moe_layer_freq = getattr(args, "moe_layer_freq", None)
if moe_layer_freq is None:
return True

if isinstance(moe_layer_freq, str):
try:
moe_layer_freq = ast.literal_eval(moe_layer_freq)
except (SyntaxError, ValueError):
return "0" in moe_layer_freq

try:
return any(int(layer_freq) == 0 for layer_freq in moe_layer_freq)
except TypeError:
return int(moe_layer_freq) == 0


def _is_moe_config(hf_config):
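    """Heuristically detect an MoE model from MoE-specific attributes on the HF config."""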
return any(
hasattr(hf_config, attr)
for attr in (
"moe_intermediate_size",
"num_experts",
"n_routed_experts",
"num_local_experts",
)
)


def validate_args(args):
"""Run megatron's own validate_args plus slime-specific megatron validations."""

@@ -49,15 +80,22 @@ def equal(x, y):
else:
_hf_rope_theta = getattr(hf_config, "rope_theta", None)

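    # Only enforce the dense FFN size check when the model is dense, or when an MoE model still contains dense layers.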
validate_dense_ffn = not _is_moe_config(hf_config) or _has_dense_moe_layers(args)

for hf_config_name, megatron_config_name, compare_fn in [
("hidden_size", "hidden_size", equal),
("num_attention_heads", "num_attention_heads", equal),
("num_hidden_layers", "num_layers", equal),
("intermediate_size", "ffn_hidden_size", equal),
("moe_intermediate_size", "moe_ffn_hidden_size", equal),
("shared_expert_intermediate_size", "moe_shared_expert_intermediate_size", equal),
("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y),
("rms_norm_eps", "norm_epsilon", equal),
("rms_norm_eps", "layernorm_epsilon", equal),
]:
if hf_config_name == "intermediate_size" and not validate_dense_ffn:
continue

if hasattr(hf_config, hf_config_name) and hasattr(args, megatron_config_name):
if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)):
errors.append(
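For reference, here is a minimal sanity-check sketch (not part of the PR) showing how the new helpers behave for a few representative inputs. The argparse.Namespace arguments and the _FakeHFConfig class are illustrative assumptions; the pattern string mirrors Megatron-style --moe-layer-freq values.

# Hypothetical sanity checks for the new helpers (assumed inputs, not from the PR).
from argparse import Namespace

from slime.backends.megatron_utils.arguments import _has_dense_moe_layers, _is_moe_config

assert _has_dense_moe_layers(Namespace(moe_layer_freq=None))            # no pattern set -> treated as dense
assert _has_dense_moe_layers(Namespace(moe_layer_freq="[0]*3+[1]*29"))  # literal_eval fails, falls back to "0" in string
assert not _has_dense_moe_layers(Namespace(moe_layer_freq=[1, 1, 1]))   # every layer is MoE
assert _has_dense_moe_layers(Namespace(moe_layer_freq=0))               # scalar 0 -> dense

class _FakeHFConfig:  # stands in for a HF model config with an MoE attribute
    num_experts = 8

assert _is_moe_config(_FakeHFConfig())
assert not _is_moe_config(object())  # no MoE attributes -> dense model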