# Provenance: added by commit 2176379e53ec88755c3aa7879463598bb2f2a4c3
# (Lin Chai, "[Tunix] Introduced the custom Gemma4ChatTemplateParser which
# parses the multi-turn messages correctly.", PiperOrigin-RevId: 923029241)
# to tunix/rl/agentic/parser/chat_template_parser/parser.py.

# NOTE(review): every closing delimiter below was an empty string in the
# received patch text (e.g. `text.split("")`, `endswith("")`), which cannot be
# what was committed: `str.split("")` raises ValueError. They are restored here
# as the mirror image of the opening tokens that did survive (`<|turn>`,
# `<|channel>`, `<|tool_call>`, `<|tool_response>`) -- confirm against the
# Gemma 4 tokenizer / Jinja chat template before merging.
_GEMMA4_TURN_END = "<turn|>"
_GEMMA4_CHANNEL_START = "<|channel>"
_GEMMA4_CHANNEL_END = "<channel|>"
_GEMMA4_TOOL_CALL_END = "<tool_call|>"
_GEMMA4_TOOL_RESPONSE_END = "<tool_response|>"


class Gemma4ChatTemplateParser(BaseChatTemplateParser):
  """Parser for Gemma 4 models.

  Renders chat messages using Gemma 4 style turn markers: each turn is opened
  with `<|turn>{role}\n` and closed with the turn-end token followed by a
  newline. Thinking blocks found in past assistant messages are removed before
  the message is rendered again.
  """

  def __init__(self, tokenizer, enable_thinking: bool = True) -> None:
    """Initializes the parser.

    Args:
      tokenizer: Tokenizer forwarded unchanged to the base parser.
      enable_thinking: Whether the rendered prompt should ask the model to
        think. It selects the system and generation-prompt tokens below.
    """
    super().__init__(tokenizer, enable_thinking=enable_thinking)
    # Also sanitize the base token (without trailing newline) to guard
    # against model-generated control tokens trailing in message contents.
    self._tokens_to_sanitize.add(_GEMMA4_TURN_END)

  def _init_tokens(self) -> TokenConfig:
    """Returns the control-token configuration used by this parser."""
    return TokenConfig(
        # NOTE(review): bos_token is empty in the received patch text and is
        # left untouched; Gemma tokenizers conventionally use "<bos>" --
        # confirm whether the literal was lost in extraction.
        bos_token="",
        eot_token=_GEMMA4_TURN_END + "\n",
        system_token=self._get_system_token(),
        user_token="<|turn>user\n",
        assistant_token=self._get_assistant_token(),
        tool_start_token="<|tool_call>",
        tool_end_token=_GEMMA4_TOOL_CALL_END,
        tool_response_start_token="<|tool_response>",
        tool_response_end_token=_GEMMA4_TOOL_RESPONSE_END,
        message_separator="",
    )

  def _get_system_token(self) -> str:
    """Returns the system-turn opener, with `<|think|>` when thinking is on."""
    token = "<|turn>system\n"
    if self.enable_thinking:
      token += "<|think|>"
    return token

  def _get_assistant_token(self) -> str:
    """Returns the model-turn opener used as the generation prompt.

    When thinking is disabled, an already-closed empty thought channel is
    appended, presumably so that generation starts after the thought block
    (TODO confirm against the Gemma 4 chat template).
    """
    token = "<|turn>model\n"
    if not self.enable_thinking:
      token += _GEMMA4_CHANNEL_START + "thought\n" + _GEMMA4_CHANNEL_END
    return token

  def _init_generation_prompt(self) -> str:
    """Returns the text appended to cue the model's next turn."""
    return self.tokens.assistant_token

  def _handle_first_message(self, messages: List[Dict[str, str]]) -> str:
    """Prepend bos and system think token if needed.

    Args:
      messages: Conversation messages; each one is read through its "role" key.

    Returns:
      The bos token, followed by a standalone system turn carrying only
      `<|think|>` when thinking is enabled and the conversation does not
      already start with a system or developer message.
    """
    prefix = self.tokens.bos_token
    starts_without_system = bool(messages) and messages[0]["role"] not in (
        "system",
        "developer",
    )
    if starts_without_system and self.enable_thinking:
      # The synthetic system turn is closed like any other turn.
      prefix += "<|turn>system\n<|think|>" + _GEMMA4_TURN_END + "\n"
    return prefix

  def _parse_tool(self, content: str) -> str:
    """Wraps a tool result in the tool-response start and end tokens."""
    return (
        self.tokens.tool_response_start_token
        + content
        + self.tokens.tool_response_end_token
    )

  def _parse_system(self, content: str) -> str:
    """Renders a system turn from whitespace-stripped content."""
    return self.tokens.system_token + content.strip() + self.tokens.eot_token

  def _parse_user(self, content: str) -> str:
    """Renders a user turn from whitespace-stripped content."""
    return self.tokens.user_token + content.strip() + self.tokens.eot_token

  @staticmethod
  def _strip_thinking(text: str) -> str:
    """Strip <|channel>...<channel|> blocks from text.

    Matches the Jinja template's strip_thinking macro: for each segment
    between <channel|> delimiters, drop everything from <|channel> onward.
    Applied to assistant content so past thinking is not re-fed to the model.

    Args:
      text: Raw assistant message content.

    Returns:
      The text with every thinking block removed. An unterminated
      `<|channel>` drops everything after it.
    """
    # partition() yields the whole segment as its first item when the opening
    # marker is absent, so segments without thinking pass through untouched.
    return "".join(
        segment.partition(_GEMMA4_CHANNEL_START)[0]
        for segment in text.split(_GEMMA4_CHANNEL_END)
    )

  def _parse_assistant(self, content: str) -> str:
    """Renders a past assistant turn with its thinking blocks removed.

    The literal model-turn opener is used instead of
    `self.tokens.assistant_token` because the latter may carry the empty
    thought channel that only belongs in the generation prompt.

    Args:
      content: Raw assistant message content, possibly ending with a
        model-generated turn-end token.

    Returns:
      The rendered model turn, closed by exactly one turn-end token.
    """
    cleaned_content = self._strip_thinking(content).strip()
    if cleaned_content.endswith(_GEMMA4_TURN_END):
      # The model already closed the turn; only add the trailing newline so
      # the turn-end token is not duplicated.
      return "<|turn>model\n" + cleaned_content + "\n"
    return "<|turn>model\n" + cleaned_content + self.tokens.eot_token