Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tunix/rl/agentic/parser/chat_template_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,87 @@ def parse(

def _init_generation_prompt(self) -> str:
return self.tokens.assistant_token


class Gemma4ChatTemplateParser(BaseChatTemplateParser):
    """Chat template parser for Gemma 4 models.

    Each turn is rendered as a ``<|turn>{role}`` header line, the stripped
    message text, and a closing ``<turn|>`` followed by a newline.
    """

    def __init__(self, tokenizer, enable_thinking: bool = True):
        """Initialise the base parser and widen the sanitised token set.

        Args:
            tokenizer: Forwarded unchanged to the base parser.
            enable_thinking: Whether thinking mode is on; forwarded to the
                base parser.
        """
        super().__init__(tokenizer, enable_thinking=enable_thinking)
        # The end-of-turn token carries a trailing newline. Register the bare
        # form as well so a control token emitted by the model at the end of
        # a message is sanitised even without that newline.
        self._tokens_to_sanitize.add("<turn|>")

    def _init_tokens(self) -> TokenConfig:
        """Build the token configuration used for Gemma 4 turns."""
        # Both of these depend on ``self.enable_thinking``.
        system_token = self._get_system_token()
        assistant_token = self._get_assistant_token()
        return TokenConfig(
            bos_token="<bos>",
            eot_token="<turn|>\n",
            system_token=system_token,
            user_token="<|turn>user\n",
            assistant_token=assistant_token,
            tool_start_token="<|tool_call>",
            tool_end_token="<tool_call|>",
            tool_response_start_token="<|tool_response>",
            tool_response_end_token="<tool_response|>",
            message_separator="",
        )

    def _get_system_token(self) -> str:
        """Return the system turn header, plus the think marker if enabled."""
        think_marker = "<|think|>" if self.enable_thinking else ""
        return "<|turn>system\n" + think_marker

    def _get_assistant_token(self) -> str:
        """Return the model turn header.

        With thinking disabled, an empty thought channel is appended to the
        header.
        """
        empty_thought = "" if self.enable_thinking else "<|channel>thought\n<channel|>"
        return "<|turn>model\n" + empty_thought

    def _init_generation_prompt(self) -> str:
        """Return the assistant token as the generation prompt."""
        return self.tokens.assistant_token

    def _handle_first_message(self, messages: List[Dict[str, str]]) -> str:
        """Return the conversation prefix.

        The prefix always starts with the bos token. When the conversation
        does not open with a system or developer message and thinking is
        enabled, a system turn holding only the think marker is appended.
        """
        bos = self.tokens.bos_token
        lacks_system_turn = bool(messages) and messages[0]["role"] not in (
            "system",
            "developer",
        )
        if not (lacks_system_turn and self.enable_thinking):
            return bos
        return bos + "<|turn>system\n<|think|><turn|>\n"

    def _parse_tool(self, content: str) -> str:
        """Wrap a tool result in the tool-response start and end tokens."""
        return "".join(
            (
                self.tokens.tool_response_start_token,
                content,
                self.tokens.tool_response_end_token,
            )
        )

    def _parse_system(self, content: str) -> str:
        """Render a system turn from whitespace-stripped content."""
        header = self.tokens.system_token
        body = content.strip()
        return header + body + self.tokens.eot_token

    def _parse_user(self, content: str) -> str:
        """Render a user turn from whitespace-stripped content."""
        header = self.tokens.user_token
        body = content.strip()
        return header + body + self.tokens.eot_token

    @staticmethod
    def _strip_thinking(text: str) -> str:
        """Remove ``<|channel>...<channel|>`` blocks from text.

        Mirrors the Jinja template's strip_thinking macro: the text is split
        on ``<channel|>`` and, within each piece, everything from the first
        ``<|channel>`` onward is discarded. Used on assistant content so that
        earlier thinking is not fed back to the model.
        """
        return "".join(
            # partition()[0] is the text before the opener, or the whole
            # segment when no opener is present.
            segment.partition("<|channel>")[0]
            for segment in text.split("<channel|>")
        )

    def _parse_assistant(self, content: str) -> str:
        """Render a model turn with thinking blocks removed.

        If the cleaned content already ends with ``<turn|>``, only the
        newline is appended; otherwise the full end-of-turn token is.
        """
        visible = self._strip_thinking(content).strip()
        closing = "\n" if visible.endswith("<turn|>") else self.tokens.eot_token
        return "<|turn>model\n" + visible + closing



Loading