diff --git a/scripts/run-glm5-744B-A40B-sft.sh b/scripts/run-glm5-744B-A40B-sft.sh
new file mode 100644
index 0000000000..9f1a1fe859
--- /dev/null
+++ b/scripts/run-glm5-744B-A40B-sft.sh
@@ -0,0 +1,159 @@
+#!/bin/bash
+
+# clean up any previous run so the task can be rerun
+pkill -9 sglang
+sleep 3
+ray stop --force
+pkill -9 ray
+pkill -9 python
+sleep 3
+pkill -9 ray
+pkill -9 python
+
+set -ex
+
+if [ -z "${BASE_DIR}" ]; then
+    echo "BASE_DIR is not set. Please set it to the base directory of your checkpoints."
+    exit 1
+fi
+
+if [ -z "${MASTER_ADDR}" ]; then
+    echo "MASTER_ADDR is not set. Please set it to the master node address."
+    exit 1
+fi
+
+# prevent ray from buffering stdout/stderr
+export PYTHONUNBUFFERED=1
+
+NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
+if [ "$NVLINK_COUNT" -gt 0 ]; then
+    HAS_NVLINK=1
+else
+    HAS_NVLINK=0
+fi
+echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
+source "${SCRIPT_DIR}/models/glm5-744B-A40B.sh"
+
+CKPT_ARGS=(
+    --hf-checkpoint $BASE_DIR/GLM-5
+    --ref-load $BASE_DIR/GLM-5_torch_dist/
+    --load $BASE_DIR/GLM-5_slime/
+    --save $BASE_DIR/GLM-5_slime/
+    --save-interval 1000
+)
+
+SFT_ARGS=(
+    --rollout-function-path slime.rollout.sft_rollout.generate_rollout
+    --prompt-data $BASE_DIR/openhermes2_5.parquet
+    --input-key messages
+    --rollout-shuffle
+    --num-epoch 3
+    --rollout-batch-size 128
+    --global-batch-size 128
+
+    --loss-type sft_loss
+    --loss-mask-type glm5
+    --calculate-per-token-loss
+    --disable-compute-advantages-and-returns
+    --debug-train-only
+)
+
+PERF_ARGS=(
+    --tensor-model-parallel-size 4
+    --sequence-parallel
+    --pipeline-model-parallel-size 4
+    --decoder-last-pipeline-num-layers 18
+    --expert-model-parallel-size 32
+    --expert-tensor-parallel-size 1
+    --context-parallel-size 2
+
+    --recompute-granularity full
+    --recompute-method uniform
+    --recompute-num-layers 1
+
+    --use-dynamic-batch-size
+    --max-tokens-per-gpu 16384
+    --data-pad-size-multiplier 4096
+)
+
+OPTIMIZER_ARGS=(
+    --optimizer adam
+    --lr 1e-5
+    --lr-decay-style cosine
+    --min-lr 1e-6
+    --lr-warmup-fraction 0.1
+    --weight-decay 0.1
+    --adam-beta1 0.9
+    --adam-beta2 0.98
+
+    --optimizer-cpu-offload
+    --overlap-cpu-optimizer-d2h-h2d
+    --use-precision-aware-optimizer
+)
+
+WANDB_ARGS=(
+    # --use-wandb
+    # --wandb-project slime-dev
+    # --wandb-group glm5-sft
+    # --wandb-key ${WANDB_KEY}
+)
+
+MISC_ARGS=(
+    # default dropout in megatron is 0.1
+    --attention-dropout 0.0
+    --hidden-dropout 0.0
+    # should be good for model performance
+    --accumulate-allreduce-grads-in-fp32
+    --attention-softmax-in-fp32
+    # need to comment this out when using a model with MLA
+    --attention-backend flash
+
+    # use deepep for megatron
+    --moe-enable-deepep
+    --moe-token-dispatcher-type flex
+)
+
+# launch the master node of ray in container
+export no_proxy="127.0.0.1,${MASTER_ADDR}"
+ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+
+for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do
+    if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then
+        continue
+    fi
+    echo "Starting Ray worker on ${WORKER_IP}"
+    ssh root@"${WORKER_IP}" \
+        "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" &
+done
+wait
+
+# Build the runtime environment JSON with proper variable substitution
+RUNTIME_ENV_JSON="{
+  \"env_vars\": {
+    \"PYTHONPATH\": \"/root/Megatron-LM/\",
+    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
+    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
+    \"no_proxy\": \"${no_proxy}\",
+    \"MASTER_ADDR\": \"${MASTER_ADDR}\",
+    \"INDEXER_ROPE_NEOX_STYLE\": \"0\",
+    \"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK\": \"32\",
+    \"NVSHMEM_DISABLE_NCCL\": \"1\",
+    \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\"
+  }
+}"
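+
+# Submit the training job once to the Ray head. MODEL_ARGS (and EVAL_ARGS, if
+# defined) are expected to be provided by the model config sourced above
+# (models/glm5-744B-A40B.sh).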
+ray job submit --address="http://127.0.0.1:8265" \
+    --runtime-env-json="${RUNTIME_ENV_JSON}" \
+    -- python3 train_async.py \
+    --actor-num-nodes 32 \
+    --actor-num-gpus-per-node 8 \
+    ${MODEL_ARGS[@]} \
+    ${CKPT_ARGS[@]} \
+    ${SFT_ARGS[@]} \
+    ${OPTIMIZER_ARGS[@]} \
+    ${WANDB_ARGS[@]} \
+    ${PERF_ARGS[@]} \
+    ${EVAL_ARGS[@]} \
+    ${MISC_ARGS[@]}
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index a634d1f003..e7e899320d 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -1252,7 +1252,7 @@ def add_rollout_buffer_arguments(parser):
         "--loss-mask-type",
         type=str,
         default="qwen",
-        choices=["qwen", "qwen3", "qwen3_5", "distill_qwen"],
+        choices=["qwen", "qwen3", "qwen3_5", "glm5", "distill_qwen"],
         help="Loss mask type",
     )
     parser.add_argument(
diff --git a/slime/utils/mask_utils.py b/slime/utils/mask_utils.py
index efe5e159f1..277a60a60d 100644
--- a/slime/utils/mask_utils.py
+++ b/slime/utils/mask_utils.py
@@ -195,6 +195,100 @@ def gen_multi_turn_loss_mask_qwen3_5(
 
         return token_ids, loss_mask
 
+    def gen_multi_turn_loss_mask_glm5(
+        self, messages: list[dict], tools: list[dict] = None
+    ) -> tuple[list[int], list[int]]:
+        rendered_text = self.tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, return_dict=False)
+        template_token_ids = self.tokenizer.apply_chat_template(
+            messages, tokenize=True, tools=tools, return_dict=False
+        )
+
+        stop_words = ("<|endoftext|>", "<|user|>", "<|observation|>")
+        final_stop = "<|user|>"
+        append_final_stop = (
+            bool(messages)
+            and messages[-1]["role"] == "assistant"
+            and messages[-1].get("step_loss_mask", 1) == 1
+            and not rendered_text.endswith(stop_words)
+        )
+        if append_final_stop:
+            rendered_text += final_stop
+
+        tokenized = self.tokenizer(rendered_text, add_special_tokens=False, return_offsets_mapping=True)
+        token_ids = tokenized["input_ids"]
+        offset_mapping = tokenized.get("offset_mapping")
+
+        if offset_mapping is None:
+            raise ValueError(
+                "GLM5 loss mask generation requires a fast tokenizer with `return_offsets_mapping` support."
+            )
+
+        if token_ids[: len(template_token_ids)] != template_token_ids:
+            raise ValueError(
+                "GLM5 rendered text tokenization does not match "
+                "`apply_chat_template(..., tokenize=True)` output."
+            )
+
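+        # Build the mask at character level over the rendered text; it is
+        # projected back onto tokens afterwards through the offset mapping, so
+        # a token is trained on iff it overlaps at least one masked character.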
+        assistant_header = "<|assistant|>"
+        stop_markers = ("<|endoftext|>", "<|user|>", "<|observation|>")
+        think_prefix = "<think>"
+        no_think_prefix = "</think>"
+
+        char_mask = [0] * len(rendered_text)
+        cursor = 0
+
+        for message in messages:
+            if message["role"] != "assistant":
+                continue
+
+            header_pos = rendered_text.find(assistant_header, cursor)
+            if header_pos < 0:
+                raise ValueError("Failed to locate assistant message in rendered GLM5 chat template output.")
+
+            content_start = header_pos + len(assistant_header)
+            next_stop = min(
+                (
+                    (marker_pos, marker)
+                    for marker in stop_markers
+                    if (marker_pos := rendered_text.find(marker, content_start)) >= 0
+                ),
+                default=None,
+            )
+            span_end = next_stop[0] if next_stop is not None else len(rendered_text)
+            cursor = span_end
+
+            if message.get("step_loss_mask", 1) != 1:
+                continue
+
+            mask_start = content_start
+            while mask_start < span_end and rendered_text[mask_start].isspace():
+                mask_start += 1
+
+            if rendered_text.startswith(think_prefix, mask_start):
+                mask_start += len(think_prefix)
+            elif rendered_text.startswith(no_think_prefix, mask_start):
+                mask_start += len(no_think_prefix)
+
+            for pos in range(mask_start, span_end):
+                char_mask[pos] = 1
+            if next_stop is not None:
+                stop_start, stop_marker = next_stop
+                for pos in range(stop_start, stop_start + len(stop_marker)):
+                    char_mask[pos] = 1
+
+        char_mask_prefix_sum = [0]
+        for value in char_mask:
+            char_mask_prefix_sum.append(char_mask_prefix_sum[-1] + value)
+
+        loss_mask = []
+        for start, end in offset_mapping:
+            if end <= start:
+                loss_mask.append(0)
+            else:
+                loss_mask.append(1 if char_mask_prefix_sum[end] - char_mask_prefix_sum[start] > 0 else 0)
+
+        return token_ids, loss_mask
+
     def gen_multi_turn_loss_mask_distill_qwen(
         self, messages: list[dict], tools: list[dict] = None
     ) -> tuple[list[int], list[int]]:
@@ -223,6 +317,8 @@ def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple
             return self.gen_multi_turn_loss_mask_qwen3(messages, tools)
         elif self.tokenizer_type == "qwen3_5":
            return self.gen_multi_turn_loss_mask_qwen3_5(messages, tools)
+        elif self.tokenizer_type == "glm5":
+            return self.gen_multi_turn_loss_mask_glm5(messages, tools)
         elif self.tokenizer_type == "distill_qwen":
             return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools)
         else:
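For context (not part of the patch): a minimal usage sketch of the new mask type, assuming a local GLM-5 checkpoint path (a placeholder here) and the `(tokenizer, tokenizer_type)` constructor exercised by the tests below.

    from transformers import AutoTokenizer

    from slime.utils.mask_utils import MultiTurnLossMaskGenerator

    # "/path/to/GLM-5" is a placeholder; any GLM-5 checkpoint with a fast
    # tokenizer (offset-mapping support) should work here.
    tokenizer = AutoTokenizer.from_pretrained("/path/to/GLM-5")
    generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")

    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "<think>greet</think>\nHello!"},
    ]
    token_ids, loss_mask = generator.get_loss_mask(messages)
    # Only assistant response tokens (plus the trailing stop marker) carry 1.
    assert len(token_ids) == len(loss_mask)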
diff --git a/tests/utils/test_loss_mask_type_glm5.py b/tests/utils/test_loss_mask_type_glm5.py
new file mode 100644
index 0000000000..0378a671eb
--- /dev/null
+++ b/tests/utils/test_loss_mask_type_glm5.py
@@ -0,0 +1,210 @@
+from slime.utils.mask_utils import MultiTurnLossMaskGenerator
+
+
+class FakeGLM5Tokenizer:
+    """A small char-level tokenizer that models GLM5 chat template boundaries."""
+
+    def __call__(self, text, add_special_tokens=False, return_offsets_mapping=False):
+        encoded = {"input_ids": [ord(ch) for ch in text]}
+        if return_offsets_mapping:
+            encoded["offset_mapping"] = [(index, index + 1) for index in range(len(text))]
+        return encoded
+
+    def decode(self, token_ids):
+        return "".join(chr(token_id) for token_id in token_ids)
+
+    def apply_chat_template(
+        self,
+        messages,
+        tokenize=True,
+        tools=None,
+        add_generation_prompt=False,
+        return_dict=False,
+        add_special_tokens=False,
+        **kwargs,
+    ):
+        rendered = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt)
+        if tokenize:
+            return [ord(ch) for ch in rendered]
+        return rendered
+
+    def render(self, messages, tools=None, add_generation_prompt=False):
+        rendered, _ = self.render_with_expected_mask(
+            messages, tools=tools, add_generation_prompt=add_generation_prompt
+        )
+        return rendered
+
+    def render_with_expected_mask(self, messages, tools=None, add_generation_prompt=False):
+        pieces = ["[gMASK]"]
+        mask = [0] * len(pieces[0])
+        last_user_index = self._find_last_user_index(messages)
+
+        if tools:
+            tools_text = "<|system|># Tools" + "".join(str(tool) for tool in tools)
+            pieces.append(tools_text)
+            mask.extend([0] * len(tools_text))
+
+        for index, message in enumerate(messages):
+            role = message["role"]
+
+            if role == "system":
+                marker = "<|system|>"
+                piece = f"{marker}{self._visible_text(message['content'])}"
+                pieces.append(piece)
+                mask.extend(self._role_mask(marker, piece, messages, index))
+                continue
+
+            if role == "user":
+                marker = "<|user|>"
+                piece = f"{marker}{self._visible_text(message['content'])}"
+                pieces.append(piece)
+                mask.extend(self._role_mask(marker, piece, messages, index))
+                continue
+
+            if role == "tool":
+                marker = "<|observation|>"
+                piece = f"{marker}{self._visible_text(message['content'])}"
+                pieces.append(piece)
+                mask.extend(self._role_mask(marker, piece, messages, index))
+                continue
+
+            if role != "assistant":
+                raise NotImplementedError(f"Unsupported role in test tokenizer: {role}")
+
+            prefix = "<|assistant|>"
+            pieces.append(prefix)
+            mask.extend([0] * len(prefix))
+
+            reasoning, content = self._split_assistant_content(self._visible_text(message.get("content", "")))
+            if reasoning and index > last_user_index:
+                think_prefix = "<think>"
+                piece = f"{think_prefix}{reasoning}</think>"
+                pieces.append(piece)
+                if message.get("step_loss_mask", 1) == 1:
+                    mask.extend([0] * len(think_prefix))
+                    mask.extend([1] * (len(piece) - len(think_prefix)))
+                else:
+                    mask.extend([0] * len(piece))
+            else:
+                no_think_prefix = "</think>"
+                pieces.append(no_think_prefix)
+                mask.extend([0] * len(no_think_prefix))
+
+            if content.strip():
+                piece = content.strip()
+                pieces.append(piece)
+                mask.extend([message.get("step_loss_mask", 1)] * len(piece))
+
+            tool_call_text = self._render_tool_calls(message.get("tool_calls"))
+            if tool_call_text:
+                pieces.append(tool_call_text)
+                mask.extend([message.get("step_loss_mask", 1)] * len(tool_call_text))
+
+        if add_generation_prompt:
+            piece = "<|assistant|>"
+            pieces.append(piece)
+            mask.extend([0] * len(piece))
+
+        return "".join(pieces), mask
+
+    @staticmethod
+    def _visible_text(content):
+        if isinstance(content, str):
+            return content
+        if isinstance(content, list):
+            text = []
+            for item in content:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    text.append(item.get("text", ""))
+                elif isinstance(item, str):
+                    text.append(item)
+            return "".join(text)
+        return "" if content is None else str(content)
+
+    @staticmethod
+    def _split_assistant_content(content):
+        if "</think>" not in content:
+            return "", content
+        reasoning = content.split("</think>")[0].rstrip("\n").split("<think>")[-1].lstrip("\n")
+        answer = content.split("</think>")[-1].lstrip("\n")
+        return reasoning, answer
+
+    @staticmethod
+    def _render_tool_calls(tool_calls):
+        if not tool_calls:
+            return ""
+
+        pieces = []
+        for tool_call in tool_calls:
+            function_call = tool_call.get("function", tool_call)
+            pieces.append(f"<tool_call>{function_call['name']}")
+            for key, value in function_call.get("arguments", {}).items():
+                pieces.append(f"<arg_key>{key}</arg_key><arg_value>{value}</arg_value>")
+            pieces.append("</tool_call>")
+        return "".join(pieces)
+
+    @staticmethod
+    def _find_last_user_index(messages):
+        last_user_index = -1
+        for index, message in enumerate(messages):
+            if message["role"] == "user":
+                last_user_index = index
+        return last_user_index
+
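+    # A role marker (<|user|> / <|observation|>) that directly follows a
+    # trained assistant turn acts as that turn's stop token, so the marker
+    # itself carries loss mask 1, mirroring the production masking rule.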
+    @staticmethod
+    def _role_mask(marker, piece, messages, index):
+        if index > 0 and messages[index - 1]["role"] == "assistant" and messages[index - 1].get("step_loss_mask", 1) == 1:
+            return [1] * len(marker) + [0] * (len(piece) - len(marker))
+        return [0] * len(piece)
+
+
+def test_glm5_loss_mask_matches_multi_turn_rendering():
+    tokenizer = FakeGLM5Tokenizer()
+    messages = [
+        {"role": "system", "content": "SYSTEM"},
+        {"role": "user", "content": "USER_1"},
+        {"role": "assistant", "content": "<think>OLD_REASONING</think>\nANSWER_1"},
+        {"role": "user", "content": "USER_2"},
+        {"role": "assistant", "content": "<think>REASONING_2</think>\nANSWER_2"},
+    ]
+
+    expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+    expected_text += "<|user|>"
+    expected_mask += [1] * len("<|user|>")
+    expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+    generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+    token_ids, loss_mask = generator.get_loss_mask(messages)
+
+    assert token_ids == expected_token_ids
+    assert loss_mask == expected_mask
+    assert generator.get_text_from_loss_mask(token_ids, loss_mask) == [
+        "ANSWER_1<|user|>",
+        "REASONING_2</think>ANSWER_2<|user|>",
+    ]
+
+
+def test_glm5_loss_mask_handles_tool_calls_and_step_loss_mask():
+    tokenizer = FakeGLM5Tokenizer()
+    messages = [
+        {"role": "user", "content": "USER"},
+        {
+            "role": "assistant",
+            "content": "CALL",
+            "tool_calls": [{"function": {"name": "terminal", "arguments": {"command": "ls"}}}],
+        },
+        {"role": "tool", "content": "README.md"},
+        {"role": "assistant", "content": "FINAL", "step_loss_mask": 0},
+    ]
+
+    expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+    expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+    generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+    token_ids, loss_mask = generator.get_loss_mask(messages)
+
+    assert token_ids == expected_token_ids
+    assert loss_mask == expected_mask
+    assert generator.get_text_from_loss_mask(token_ids, loss_mask) == [
+        "CALL<tool_call>terminal<arg_key>command</arg_key><arg_value>ls</arg_value></tool_call><|observation|>",
+    ]
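For reference, the character-to-token projection that `gen_multi_turn_loss_mask_glm5` performs after building the character mask reduces to the prefix-sum check below (toy values, not part of the patch):

    # A token gets mask 1 iff any character in its [start, end) span is masked.
    char_mask = [0, 1, 1, 0]
    prefix = [0]
    for value in char_mask:
        prefix.append(prefix[-1] + value)
    offsets = [(0, 2), (2, 4)]  # two tokens covering chars [0, 2) and [2, 4)
    token_mask = [1 if prefix[end] - prefix[start] > 0 else 0 for start, end in offsets]
    assert token_mask == [1, 1]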