From 617ffdd1a13f61b74b676b30dbb7d0b3657141d6 Mon Sep 17 00:00:00 2001
From: Xuan Wang
Date: Thu, 9 Apr 2026 17:49:56 -0700
Subject: [PATCH] feat: add GLM-5 SFT loss mask support (--loss-mask-type glm5)

---
 slime/utils/arguments.py                |   2 +-
 slime/utils/mask_utils.py               |  94 ++++++++
 tests/utils/test_loss_mask_type_glm5.py | 305 ++++++++++++++++++++++++
 3 files changed, 400 insertions(+), 1 deletion(-)
 create mode 100644 tests/utils/test_loss_mask_type_glm5.py

diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index a634d1f003..b5ab5e90cc 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -1252,7 +1252,7 @@ def add_rollout_buffer_arguments(parser):
         "--loss-mask-type",
         type=str,
         default="qwen",
-        choices=["qwen", "qwen3", "qwen3_5", "distill_qwen"],
+        choices=["qwen", "qwen3", "qwen3_5", "distill_qwen", "glm5"],
         help="Loss mask type",
     )
     parser.add_argument(
diff --git a/slime/utils/mask_utils.py b/slime/utils/mask_utils.py
index efe5e159f1..7cc2e14d51 100644
--- a/slime/utils/mask_utils.py
+++ b/slime/utils/mask_utils.py
@@ -213,6 +213,98 @@ def gen_multi_turn_loss_mask_distill_qwen(
             loss_mask = [0] * len(token_ids)
         return token_ids, loss_mask
 
+    def gen_multi_turn_loss_mask_glm5(
+        self, messages: list[dict], tools: list[dict] = None
+    ) -> tuple[list[int], list[int]]:
+        """Generate loss masks for the GLM-5 chat template.
+
+        GLM-5 uses role-token delimiters with no closing tags:
+            [gMASK]<|system|>...<|user|>...<|assistant|>content...
+
+        Assistant messages start with ``<|assistant|></think>`` (or ``<|assistant|><think>...</think>``
+        when thinking is enabled). We mask only the assistant content tokens
+        (everything after ``</think>`` until the next role token or end of sequence).
+        """
+        rendered_text = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, tools=tools, add_generation_prompt=False
+        )
+        tokenized = self.tokenizer(rendered_text, add_special_tokens=False, return_offsets_mapping=True)
+        token_ids = tokenized["input_ids"]
+        offset_mapping = tokenized.get("offset_mapping")
+
+        if offset_mapping is None:
+            raise ValueError(
+                "GLM-5 loss mask generation requires a fast tokenizer "
+                "with `return_offsets_mapping` support."
+            )
+
+        # Sanity check: apply_chat_template(tokenize=True) must match re-tokenizing the rendered text.
+        expected_token_ids = self.tokenizer.apply_chat_template(
+            messages, tokenize=True, tools=tools, add_generation_prompt=False
+        )
+        if token_ids != expected_token_ids:
+            raise ValueError(
+                "GLM-5 rendered text tokenization does not match "
+                "the chat template's tokenized output."
+            )
+
+        assistant_header = "<|assistant|>"
+        think_close = "</think>"
+        role_markers = ("<|user|>", "<|assistant|>", "<|system|>", "<|observation|>")
+
+        char_mask = [0] * len(rendered_text)
+        cursor = 0
+
+        for message in messages:
+            if message["role"] != "assistant":
+                continue
+
+            header_pos = rendered_text.find(assistant_header, cursor)
+            if header_pos < 0:
+                raise ValueError("Failed to locate <|assistant|> in rendered GLM-5 chat template output.")
+
+            content_start = header_pos + len(assistant_header)
+
+            # Find the end of this assistant message: next role token or end of string
+            end_pos = len(rendered_text)
+            for marker in role_markers:
+                marker_pos = rendered_text.find(marker, content_start)
+                if 0 <= marker_pos < end_pos:
+                    end_pos = marker_pos
+
+            cursor = end_pos
+
+            if message.get("step_loss_mask", 1) != 1:
+                continue
+
+            # Skip past </think> or <think>...</think> block at the start of assistant content
+            mask_start = content_start
+            if rendered_text[mask_start : mask_start + len(think_close)] == think_close:
+                # Simple case: </think> immediately after <|assistant|>
+                mask_start += len(think_close)
+            elif rendered_text[mask_start : mask_start + len("<think>")] == "<think>":
+                # Thinking enabled: <think>...</think>
+                think_end = rendered_text.find(think_close, mask_start)
+                if think_end >= 0 and think_end < end_pos:
+                    mask_start = think_end + len(think_close)
+
+            for pos in range(mask_start, end_pos):
+                char_mask[pos] = 1
+
+        # Convert char-level mask to token-level using offset mapping
+        char_mask_prefix_sum = [0]
+        for value in char_mask:
+            char_mask_prefix_sum.append(char_mask_prefix_sum[-1] + value)
+
+        loss_mask = []
+        for start, end in offset_mapping:
+            if end <= start:
+                loss_mask.append(0)
+            else:
+                loss_mask.append(1 if char_mask_prefix_sum[end] - char_mask_prefix_sum[start] > 0 else 0)
+
+        return token_ids, loss_mask
+
     def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple[list[int], list[int]]:
         if self.tokenizer_type == "qwen":
             if "<|Assistant|>" in self.tokenizer.get_added_vocab():
@@ -225,6 +317,8 @@ def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple
             return self.gen_multi_turn_loss_mask_qwen3_5(messages, tools)
         elif self.tokenizer_type == "distill_qwen":
            return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools)
+        elif self.tokenizer_type == "glm5":
+            return self.gen_multi_turn_loss_mask_glm5(messages, tools)
         else:
             raise ValueError(f"Unsupported tokenizer type: {self.tokenizer_type}")
 
diff --git a/tests/utils/test_loss_mask_type_glm5.py b/tests/utils/test_loss_mask_type_glm5.py
new file mode 100644
index 0000000000..287c6d24c3
--- /dev/null
+++ b/tests/utils/test_loss_mask_type_glm5.py
@@ -0,0 +1,305 @@
+from slime.utils.mask_utils import MultiTurnLossMaskGenerator
+
+
+class FakeGlm5Tokenizer:
+    """A tiny char-level tokenizer that models the GLM-5 chat template formatting.
+
+    GLM-5 uses role-token delimiters with no closing tags:
+        [gMASK]<|system|>...<|user|>...<|assistant|>content...
+
+    Key behaviors:
+    1. Sequence starts with ``[gMASK]``
+    2. Role tokens: ``<|system|>``, ``<|user|>``, ``<|assistant|>``, ``<|observation|>``
+    3. No closing tags — messages end at the next role token or end of string
+    4. Assistant turns always start with ``</think>`` (thinking disabled)
+    5. Tool calls use ``<tool_call>...<arg_key>...</arg_key><arg_value>...</arg_value></tool_call>``
+    6. Tool responses use ``<|observation|><tool_response>...</tool_response>``
+    """
+
+    def __call__(self, text, add_special_tokens=False, return_offsets_mapping=False):
+        encoded = {"input_ids": [ord(ch) for ch in text]}
+        if return_offsets_mapping:
+            encoded["offset_mapping"] = [(index, index + 1) for index in range(len(text))]
+        return encoded
+
+    def decode(self, token_ids):
+        return "".join(chr(token_id) for token_id in token_ids)
+
+    def apply_chat_template(
+        self,
+        messages,
+        tokenize=True,
+        tools=None,
+        add_generation_prompt=False,
+        return_dict=False,
+        add_special_tokens=False,
+        **kwargs,
+    ):
+        rendered = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt)
+        if tokenize:
+            return [ord(ch) for ch in rendered]
+        return rendered
+
+    def render(self, messages, tools=None, add_generation_prompt=False):
+        rendered, _ = self.render_with_expected_mask(
+            messages, tools=tools, add_generation_prompt=add_generation_prompt
+        )
+        return rendered
+
+    def render_with_expected_mask(self, messages, tools=None, add_generation_prompt=False):
+        pieces = []
+        mask = []
+
+        # GLM-5 prefix
+        prefix = "[gMASK]"
+        pieces.append(prefix)
+        mask.extend([0] * len(prefix))
+
+        # Tool instructions go into the system message
+        tool_instruction = self._build_tool_instructions(tools) if tools else ""
+
+        for index, message in enumerate(messages):
+            role = message["role"]
+
+            if role == "system":
+                piece = f"<|system|>{message['content']}{tool_instruction}"
+                pieces.append(piece)
+                mask.extend([0] * len(piece))
+
+            elif role == "user":
+                piece = f"<|user|>{message['content']}"
+                pieces.append(piece)
+                mask.extend([0] * len(piece))
+
+            elif role == "assistant":
+                header = "<|assistant|></think>"
+                content = message.get("content", "") or ""
+                tool_calls_text = self._render_tool_calls(message.get("tool_calls"))
+                target = f"{content}{tool_calls_text}"
+
+                pieces.append(header)
+                mask.extend([0] * len(header))
+
+                pieces.append(target)
+                if message.get("step_loss_mask", 1) != 1:
+                    mask.extend([0] * len(target))
+                else:
+                    mask.extend([1] * len(target))
+
+            elif role == "tool":
+                # First tool response in a group gets <|observation|>
+                if index == 0 or messages[index - 1]["role"] != "tool":
+                    piece = f"<|observation|><tool_response>{message['content']}</tool_response>"
+                else:
+                    piece = f"<tool_response>{message['content']}</tool_response>"
+                pieces.append(piece)
+                mask.extend([0] * len(piece))
+
+        if add_generation_prompt:
+            gen = "<|assistant|>"
+            pieces.append(gen)
+            mask.extend([0] * len(gen))
+
+        return "".join(pieces), mask
+
+    @staticmethod
+    def _build_tool_instructions(tools):
+        if not tools:
+            return ""
+        import json
+
+        tool_specs = "\n".join(json.dumps(tool) for tool in tools)
+        return (
+            "\n# Tools\n\n"
+            "You may call one or more functions to assist with the user query.\n\n"
+            "<tools>\n"
+            f"{tool_specs}\n"
+            "</tools>"
+        )
+
+    @staticmethod
+    def _render_tool_calls(tool_calls):
+        if not tool_calls:
+            return ""
+        pieces = []
+        for tc in tool_calls:
+            func = tc.get("function", tc)
+            pieces.append(f"<tool_call>{func['name']}")
+            for k, v in func.get("arguments", {}).items():
+                val = v if isinstance(v, str) else str(v)
+                pieces.append(f"<arg_key>{k}</arg_key><arg_value>{val}</arg_value>")
+            pieces.append("</tool_call>")
+        return "".join(pieces)
+
+
+def test_glm5_single_turn():
+    """Basic single-turn: only assistant content is masked."""
+    tokenizer = FakeGlm5Tokenizer()
+    messages = [
+        {"role": "system", "content": "SYSTEM"},
+        {"role": "user", "content": "USER"},
+        {"role": "assistant", "content": "ANSWER"},
+    ]
+
+    expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+    expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+    generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+    token_ids, loss_mask = generator.get_loss_mask(messages)
+
+    assert token_ids == expected_token_ids
+    assert loss_mask == expected_mask
+    selected = generator.get_text_from_loss_mask(token_ids, loss_mask)
+    assert selected == ["ANSWER"]
+
+
+def test_glm5_multi_turn():
+    """Multi-turn: each assistant turn is independently masked."""
+    tokenizer = FakeGlm5Tokenizer()
+    messages = [
+        {"role": "system", "content": "SYSTEM"},
+        {"role": "user", "content": "USER_1"},
+        {"role": "assistant", "content": "ANSWER_1"},
+        {"role": "user", "content": "USER_2"},
+        {"role": "assistant", "content": "ANSWER_2"},
+    ]
+
+    expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+    expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+    generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+    token_ids, loss_mask = generator.get_loss_mask(messages)
+
+    assert token_ids == expected_token_ids
+    assert loss_mask == expected_mask
+    selected = generator.get_text_from_loss_mask(token_ids, loss_mask)
+    assert selected == ["ANSWER_1", "ANSWER_2"]
+
+
+def test_glm5_step_loss_mask():
+    """step_loss_mask=0 suppresses loss on specific assistant turns."""
+    tokenizer = FakeGlm5Tokenizer()
+    messages = [
+        {"role": "user", "content": "USER_1"},
+        {"role": "assistant", "content": "SKIP_THIS", "step_loss_mask": 0},
+        {"role": "user", "content": "USER_2"},
+        {"role": "assistant", "content": "KEEP_THIS"},
+    ]
+
+    expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+    expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+    generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+    token_ids, loss_mask = generator.get_loss_mask(messages)
+
+    assert token_ids == expected_token_ids
+    assert loss_mask == expected_mask
+    selected = generator.get_text_from_loss_mask(token_ids, loss_mask)
+    assert selected == ["KEEP_THIS"]
+
+
+def test_glm5_tool_call_flow():
+    """Tool calling: assistant tool calls are masked, tool responses are not."""
+    tokenizer = FakeGlm5Tokenizer()
+    messages = [
+        {"role": "user", "content": "USER"},
+        {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [{"function": {"name": "get_weather", "arguments": {"city": "Paris"}}}],
+        },
+        {"role": "tool", "content": '{"temp": 22}'},
+        {"role": "assistant", "content": "It is 22C in Paris."},
+    ]
+
+    expected_text, expected_mask = tokenizer.render_with_expected_mask(messages)
+    expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"]
+
+    generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5")
+    token_ids, loss_mask = generator.get_loss_mask(messages)
+
+    assert token_ids == expected_token_ids
+    assert loss_mask == expected_mask
+    selected = generator.get_text_from_loss_mask(token_ids, loss_mask)
+    assert selected == [
+        "<tool_call>get_weather<arg_key>city</arg_key><arg_value>Paris</arg_value></tool_call>",
+        "It is 22C in Paris.",
+    ]
+
+
+def test_glm5_tool_call_with_tools_schema():
+    """Tool calling with tools schema injected into system message."""
+    tokenizer = FakeGlm5Tokenizer()
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
+            },
+        }
+    ]
+    messages = [
+        {"role": "system", "content": "You are helpful."},
+        {"role": "user", "content": "Weather?"},
+        {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [{"function": {"name": "get_weather", "arguments": {"city": "NYC"}}}],
[{"function": {"name": "get_weather", "arguments": {"city": "NYC"}}}], + }, + {"role": "tool", "content": '{"temp": 15}'}, + {"role": "assistant", "content": "15C in NYC."}, + ] + + expected_text, expected_mask = tokenizer.render_with_expected_mask(messages, tools=tools) + expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"] + + generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5") + token_ids, loss_mask = generator.get_loss_mask(messages, tools=tools) + + assert token_ids == expected_token_ids + assert loss_mask == expected_mask + selected = generator.get_text_from_loss_mask(token_ids, loss_mask) + assert selected == [ + "get_weathercityNYC", + "15C in NYC.", + ] + + +def test_glm5_no_system_message(): + """Conversation without system message.""" + tokenizer = FakeGlm5Tokenizer() + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + + expected_text, expected_mask = tokenizer.render_with_expected_mask(messages) + expected_token_ids = tokenizer(expected_text, add_special_tokens=False)["input_ids"] + + generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5") + token_ids, loss_mask = generator.get_loss_mask(messages) + + assert token_ids == expected_token_ids + assert loss_mask == expected_mask + selected = generator.get_text_from_loss_mask(token_ids, loss_mask) + assert selected == ["Hi!"] + + +def test_glm5_lengths_match(): + """token_ids and loss_mask always have the same length.""" + tokenizer = FakeGlm5Tokenizer() + messages = [ + {"role": "system", "content": "S"}, + {"role": "user", "content": "U1"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "U2"}, + {"role": "assistant", "content": "A2", "step_loss_mask": 0}, + {"role": "user", "content": "U3"}, + {"role": "assistant", "content": "A3"}, + ] + + generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type="glm5") + token_ids, loss_mask = generator.get_loss_mask(messages) + + assert len(token_ids) == len(loss_mask)