diff --git a/scripts/run-glm5-744B-A40B-sft.sh b/scripts/run-glm5-744B-A40B-sft.sh
new file mode 100644
index 0000000000..9f1a1fe859
--- /dev/null
+++ b/scripts/run-glm5-744B-A40B-sft.sh
@@ -0,0 +1,162 @@
+#!/bin/bash
+
+# Clean up leftovers from any previous run so the task can be rerun.
+pkill -9 sglang
+sleep 3
+ray stop --force
+pkill -9 ray
+pkill -9 python
+sleep 3
+pkill -9 ray
+pkill -9 python
+
+set -ex
+
+if [ -z "${BASE_DIR}" ]; then
+ echo "BASE_DIR is not set. Please set it to the base directory of your checkpoints."
+ exit 1
+fi
+
+if [ -z "${MASTER_ADDR}" ]; then
+ echo "MASTER_ADDR is not set. Please set it to the master node address."
+ exit 1
+fi
+
+# prevent Python from buffering stdout/stderr so Ray workers stream logs promptly
+export PYTHONUNBUFFERED=1
+
+NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
+if [ "$NVLINK_COUNT" -gt 0 ]; then
+ HAS_NVLINK=1
+else
+ HAS_NVLINK=0
+fi
+echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
+source "${SCRIPT_DIR}/models/glm5-744B-A40B.sh"
+
+CKPT_ARGS=(
+ --hf-checkpoint $BASE_DIR/GLM-5
+ --ref-load $BASE_DIR/GLM-5_torch_dist/
+ --load $BASE_DIR/GLM-5_slime/
+ --save $BASE_DIR/GLM-5_slime/
+ --save-interval 1000
+)
+
+SFT_ARGS=(
+ --rollout-function-path slime.rollout.sft_rollout.generate_rollout
+ --prompt-data $BASE_DIR/openhermes2_5.parquet
+ --input-key messages
+ --rollout-shuffle
+ --num-epoch 3
+ --rollout-batch-size 128
+ --global-batch-size 128
+
+ --loss-type sft_loss
+ --loss-mask-type glm5
+ --calculate-per-token-loss
+ --disable-compute-advantages-and-returns
+ --debug-train-only
+)
+
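+# Parallelism sketch, assuming the 32 nodes x 8 GPUs = 256 GPUs requested below:
+# TP(4) x PP(4) x CP(2) = 32-way model parallelism per replica, leaving
+# DP = 256 / 32 = 8 data-parallel replicas; EP(32) shards the MoE experts.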
+PERF_ARGS=(
+ --tensor-model-parallel-size 4
+ --sequence-parallel
+ --pipeline-model-parallel-size 4
+ --decoder-last-pipeline-num-layers 18
+ --expert-model-parallel-size 32
+ --expert-tensor-parallel-size 1
+ --context-parallel-size 2
+
+ --recompute-granularity full
+ --recompute-method uniform
+ --recompute-num-layers 1
+
+ --use-dynamic-batch-size
+ --max-tokens-per-gpu 16384
+ --data-pad-size-multiplier 4096
+)
+
+OPTIMIZER_ARGS=(
+ --optimizer adam
+ --lr 1e-5
+ --lr-decay-style cosine
+ --min-lr 1e-6
+ --lr-warmup-fraction 0.1
+ --weight-decay 0.1
+ --adam-beta1 0.9
+ --adam-beta2 0.98
+
+ --optimizer-cpu-offload
+ --overlap-cpu-optimizer-d2h-h2d
+ --use-precision-aware-optimizer
+)
+
+WANDB_ARGS=(
+ # --use-wandb
+ # --wandb-project slime-dev
+ # --wandb-group glm5-sft
+ # --wandb-key ${WANDB_KEY}
+)
+
+MISC_ARGS=(
+ # default dropout in megatron is 0.1
+ --attention-dropout 0.0
+ --hidden-dropout 0.0
+ # keep grad all-reduce and attention softmax in fp32 for numerical stability
+ --accumulate-allreduce-grads-in-fp32
+ --attention-softmax-in-fp32
+ # comment this out when using a model with MLA
+ --attention-backend flash
+
+ # use DeepEP for MoE token dispatch in Megatron
+ --moe-enable-deepep
+ --moe-token-dispatcher-type flex
+)
+
+# launch the Ray head node inside the container
+export no_proxy="127.0.0.1,${MASTER_ADDR}"
+ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do
+ if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then
+ continue
+ fi
+ echo "Starting Ray worker on ${WORKER_IP}"
+ ssh root@"${WORKER_IP}" \
+ "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" &
+done
+wait
+
+
+# Build the runtime environment JSON with proper variable substitution
+RUNTIME_ENV_JSON="{
+ \"env_vars\": {
+ \"PYTHONPATH\": \"/root/Megatron-LM/\",
+ \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
+ \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
+ \"no_proxy\": \"${no_proxy}\",
+ \"MASTER_ADDR\": \"${MASTER_ADDR}\",
+ \"INDEXER_ROPE_NEOX_STYLE\": \"0\",
+ \"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK\": \"32\",
+ \"NVSHMEM_DISABLE_NCCL\": \"1\",
+ \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\"
+ }
+}"
+
+ray job submit --address="http://127.0.0.1:8265" \
+ --runtime-env-json="${RUNTIME_ENV_JSON}" \
+ -- python3 train_async.py \
+ --actor-num-nodes 32 \
+ --actor-num-gpus-per-node 8 \
+ ${MODEL_ARGS[@]} \
+ ${CKPT_ARGS[@]} \
+ ${SFT_ARGS[@]} \
+ ${OPTIMIZER_ARGS[@]} \
+ ${WANDB_ARGS[@]} \
+ ${PERF_ARGS[@]} \
+ ${EVAL_ARGS[@]} \
+ ${MISC_ARGS[@]}
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index a634d1f003..e7e899320d 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -1252,7 +1252,7 @@ def add_rollout_buffer_arguments(parser):
"--loss-mask-type",
type=str,
default="qwen",
- choices=["qwen", "qwen3", "qwen3_5", "distill_qwen"],
+ choices=["qwen", "qwen3", "qwen3_5", "glm5", "distill_qwen"],
help="Loss mask type",
)
parser.add_argument(
diff --git a/slime/utils/mask_utils.py b/slime/utils/mask_utils.py
index efe5e159f1..277a60a60d 100644
--- a/slime/utils/mask_utils.py
+++ b/slime/utils/mask_utils.py
@@ -195,6 +195,122 @@ def gen_multi_turn_loss_mask_qwen3_5(
return token_ids, loss_mask
+ def gen_multi_turn_loss_mask_glm5(
+ self, messages: list[dict], tools: list[dict] = None
+ ) -> tuple[list[int], list[int]]:
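+ """Generate token ids and a per-token loss mask for GLM5 chat data.
+
+ The mask is built at character level over the rendered template text and
+ then projected onto tokens via the tokenizer's offset mapping, so any
+ token overlapping an assistant span is trained on. Minimal usage sketch
+ (messages are illustrative):
+
+ generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+ token_ids, loss_mask = generator.get_loss_mask(
+ [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
+ )
+ """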
+ rendered_text = self.tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, return_dict=False)
+ template_token_ids = self.tokenizer.apply_chat_template(
+ messages, tokenize=True, tools=tools, return_dict=False
+ )
+
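+ # A trainable final assistant turn should end in a stop token; when the
+ # template does not emit one, append "<|user|>" and keep it in the loss.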
+ stop_words = ("<|endoftext|>", "<|user|>", "<|observation|>")
+ final_stop = "<|user|>"
+ append_final_stop = (
+ bool(messages)
+ and messages[-1]["role"] == "assistant"
+ and messages[-1].get("step_loss_mask", 1) == 1
+ and not rendered_text.endswith(stop_words)
+ )
+ if append_final_stop:
+ rendered_text += final_stop
+
+ tokenized = self.tokenizer(rendered_text, add_special_tokens=False, return_offsets_mapping=True)
+ token_ids = tokenized["input_ids"]
+ offset_mapping = tokenized.get("offset_mapping")
+
+ if offset_mapping is None:
+ raise ValueError(
+ "GLM5 loss mask generation requires a fast tokenizer with `return_offsets_mapping` support."
+ )
+
+ if token_ids[: len(template_token_ids)] != template_token_ids:
+ raise ValueError(
+ "GLM5 rendered text tokenization does not match "
+ "`apply_chat_template(..., tokenize=True)` output."
+ )
+
+ assistant_header = "<|assistant|>"
+ stop_markers = stop_words
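+ # GLM5-style assistant bodies open either with a think block ("<think>",
+ # whose reasoning text stays in the loss) or with a bare "</think>" marker
+ # when no reasoning is retained; the markers themselves are masked out.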
+ think_prefix = ""
+ no_think_prefix = ""
+
+ char_mask = [0] * len(rendered_text)
+ cursor = 0
+
+ for message in messages:
+ if message["role"] != "assistant":
+ continue
+
+ header_pos = rendered_text.find(assistant_header, cursor)
+ if header_pos < 0:
+ raise ValueError("Failed to locate assistant message in rendered GLM5 chat template output.")
+
+ content_start = header_pos + len(assistant_header)
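+ # The assistant span ends at the earliest stop marker after the header,
+ # or at end-of-text when none is found.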
+ next_stop = min(
+ (
+ (marker_pos, marker)
+ for marker in stop_markers
+ if (marker_pos := rendered_text.find(marker, content_start)) >= 0
+ ),
+ default=None,
+ )
+ span_end = next_stop[0] if next_stop is not None else len(rendered_text)
+ cursor = span_end
+
+ if message.get("step_loss_mask", 1) != 1:
+ continue
+
+ mask_start = content_start
+ while mask_start < span_end and rendered_text[mask_start].isspace():
+ mask_start += 1
+
+ if rendered_text.startswith(think_prefix, mask_start):
+ mask_start += len(think_prefix)
+ elif rendered_text.startswith(no_think_prefix, mask_start):
+ mask_start += len(no_think_prefix)
+
+ for pos in range(mask_start, span_end):
+ char_mask[pos] = 1
+ if next_stop is not None:
+ stop_start, stop_marker = next_stop
+ for pos in range(stop_start, stop_start + len(stop_marker)):
+ char_mask[pos] = 1
+
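+ # Prefix sums over the character mask let each token's offset span be
+ # tested for overlap in O(1).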
+ char_mask_prefix_sum = [0]
+ for value in char_mask:
+ char_mask_prefix_sum.append(char_mask_prefix_sum[-1] + value)
+
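+ # A token receives loss iff any character in its offset span is masked.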
+ loss_mask = []
+ for start, end in offset_mapping:
+ if end <= start:
+ loss_mask.append(0)
+ else:
+ loss_mask.append(1 if char_mask_prefix_sum[end] - char_mask_prefix_sum[start] > 0 else 0)
+
+ return token_ids, loss_mask
+
def gen_multi_turn_loss_mask_distill_qwen(
self, messages: list[dict], tools: list[dict] = None
) -> tuple[list[int], list[int]]:
@@ -223,6 +339,8 @@ def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple
return self.gen_multi_turn_loss_mask_qwen3(messages, tools)
elif self.tokenizer_type == "qwen3_5":
return self.gen_multi_turn_loss_mask_qwen3_5(messages, tools)
+ elif self.tokenizer_type == "glm5":
+ return self.gen_multi_turn_loss_mask_glm5(messages, tools)
elif self.tokenizer_type == "distill_qwen":
return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools)
else:
diff --git a/tests/utils/test_loss_mask_type_glm5.py b/tests/utils/test_loss_mask_type_glm5.py
new file mode 100644
index 0000000000..0378a671eb
--- /dev/null
+++ b/tests/utils/test_loss_mask_type_glm5.py
@@ -0,0 +1,215 @@
+from slime.utils.mask_utils import MultiTurnLossMaskGenerator
+
+
+class FakeGLM5Tokenizer:
+ """A small char-level tokenizer that models GLM5 chat template boundaries."""
+
+ def __call__(self, text, add_special_tokens=False, return_offsets_mapping=False):
+ encoded = {"input_ids": [ord(ch) for ch in text]}
+ if return_offsets_mapping:
+ encoded["offset_mapping"] = [(index, index + 1) for index in range(len(text))]
+ return encoded
+
+ def decode(self, token_ids):
+ return "".join(chr(token_id) for token_id in token_ids)
+
+ def apply_chat_template(
+ self,
+ messages,
+ tokenize=True,
+ tools=None,
+ add_generation_prompt=False,
+ return_dict=False,
+ add_special_tokens=False,
+ **kwargs,
+ ):
+ rendered = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt)
+ if tokenize:
+ return [ord(ch) for ch in rendered]
+ return rendered
+
+ def render(self, messages, tools=None, add_generation_prompt=False):
+ rendered, _ = self.render_with_expected_mask(
+ messages, tools=tools, add_generation_prompt=add_generation_prompt
+ )
+ return rendered
+
+ def render_with_expected_mask(self, messages, tools=None, add_generation_prompt=False):
+ pieces = ["[gMASK]"]
+ mask = [0] * len(pieces[0])
+ last_user_index = self._find_last_user_index(messages)
+
+ if tools:
+ tools_text = "<|system|># Tools" + "".join(str(tool) for tool in tools)
+ pieces.append(tools_text)
+ mask.extend([0] * len(tools_text))
+
+ for index, message in enumerate(messages):
+ role = message["role"]
+
+ if role == "system":
+ marker = "<|system|>"
+ piece = f"{marker}{self._visible_text(message['content'])}"
+ pieces.append(piece)
+ mask.extend(self._role_mask(marker, piece, messages, index))
+ continue
+
+ if role == "user":
+ marker = "<|user|>"
+ piece = f"{marker}{self._visible_text(message['content'])}"
+ pieces.append(piece)
+ mask.extend(self._role_mask(marker, piece, messages, index))
+ continue
+
+ if role == "tool":
+ marker = "<|observation|>"
+ piece = f"{marker}{self._visible_text(message['content'])}"
+ pieces.append(piece)
+ mask.extend(self._role_mask(marker, piece, messages, index))
+ continue
+
+ if role != "assistant":
+ raise NotImplementedError(f"Unsupported role in test tokenizer: {role}")
+
+ prefix = "<|assistant|>"
+ pieces.append(prefix)
+ mask.extend([0] * len(prefix))
+
+ reasoning, content = self._split_assistant_content(self._visible_text(message.get("content", "")))
+ if reasoning and index > last_user_index:
+ think_prefix = ""
+ piece = f"{think_prefix}{reasoning}"
+ pieces.append(piece)
+ if message.get("step_loss_mask", 1) == 1:
+ mask.extend([0] * len(think_prefix))
+ mask.extend([1] * (len(piece) - len(think_prefix)))
+ else:
+ mask.extend([0] * len(piece))
+ else:
+ no_think_prefix = ""
+ pieces.append(no_think_prefix)
+ mask.extend([0] * len(no_think_prefix))
+
+ if content.strip():
+ piece = content.strip()
+ pieces.append(piece)
+ mask.extend([message.get("step_loss_mask", 1)] * len(piece))
+
+ tool_call_text = self._render_tool_calls(message.get("tool_calls"))
+ if tool_call_text:
+ pieces.append(tool_call_text)
+ mask.extend([message.get("step_loss_mask", 1)] * len(tool_call_text))
+
+ if add_generation_prompt:
+ piece = "<|assistant|>"
+ pieces.append(piece)
+ mask.extend([0] * len(piece))
+
+ return "".join(pieces), mask
+
+ @staticmethod
+ def _visible_text(content):
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ text = []
+ for item in content:
+ if isinstance(item, dict) and item.get("type") == "text":
+ text.append(item.get("text", ""))
+ elif isinstance(item, str):
+ text.append(item)
+ return "".join(text)
+ return "" if content is None else str(content)
+
+ @staticmethod
+ def _split_assistant_content(content):
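+ # Split "<think>reasoning</think>\nanswer" into (reasoning, answer);
+ # content without a "</think>" tag is treated as answer-only.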
+ if "" not in content:
+ return "", content
+ reasoning = content.split("")[0].rstrip("\n").split("")[-1].lstrip("\n")
+ answer = content.split("")[-1].lstrip("\n")
+ return reasoning, answer
+
+ @staticmethod
+ def _render_tool_calls(tool_calls):
+ if not tool_calls:
+ return ""
+
+ pieces = []
+ for tool_call in tool_calls:
+ function_call = tool_call.get("function", tool_call)
+ pieces.append(f"{function_call['name']}")
+ for key, value in function_call.get("arguments", {}).items():
+ pieces.append(f"{key}{value}")
+ pieces.append("")
+ return "".join(pieces)
+
+ @staticmethod
+ def _find_last_user_index(messages):
+ last_user_index = -1
+ for index, message in enumerate(messages):
+ if message["role"] == "user":
+ last_user_index = index
+ return last_user_index
+
+ @staticmethod
+ def _role_mask(marker, piece, messages, index):
+ if index > 0 and messages[index - 1]["role"] == "assistant" and messages[index - 1].get("step_loss_mask", 1) == 1:
+ return [1] * len(marker) + [0] * (len(piece) - len(marker))
+ return [0] * len(piece)
+
+
+def test_glm5_loss_mask_matches_multi_turn_rendering():
+ tokenizer = FakeGLM5Tokenizer()
+ messages = [
+ {"role": "system", "content": "SYSTEM"},
+ {"role": "user", "content": "USER_1"},
+ {"role": "assistant", "content": "OLD_REASONING\nANSWER_1"},
+ {"role": "user", "content": "USER_2"},
+ {"role": "assistant", "content": "REASONING_2\nANSWER_2"},
+ ]
+
+ expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+ expected_text += "<|user|>"
+ expected_mask += [1] * len("<|user|>")
+ expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+ generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+ token_ids, loss_mask = generator.get_loss_mask(messages)
+
+ assert token_ids == expected_token_ids
+ assert loss_mask == expected_mask
+ assert generator.get_text_from_loss_mask(token_ids, loss_mask) == [
+ "ANSWER_1<|user|>",
+ "REASONING_2ANSWER_2<|user|>",
+ ]
+
+
+def test_glm5_loss_mask_handles_tool_calls_and_step_loss_mask():
+ tokenizer = FakeGLM5Tokenizer()
+ messages = [
+ {"role": "user", "content": "USER"},
+ {
+ "role": "assistant",
+ "content": "CALL",
+ "tool_calls": [{"function": {"name": "terminal", "arguments": {"command": "ls"}}}],
+ },
+ {"role": "tool", "content": "README.md"},
+ {"role": "assistant", "content": "FINAL", "step_loss_mask": 0},
+ ]
+
+ expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+ expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+ generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+ token_ids, loss_mask = generator.get_loss_mask(messages)
+
+ assert token_ids == expected_token_ids
+ assert loss_mask == expected_mask
+ assert generator.get_text_from_loss_mask(token_ids, loss_mask) == [
+ "CALLterminalcommandls<|observation|>",
+ ]