diff --git a/conf/tir.yaml b/conf/tir.yaml
new file mode 100644
index 00000000..606ab2dc
--- /dev/null
+++ b/conf/tir.yaml
@@ -0,0 +1,96 @@
+defaults:
+ - base
+ - override rewards: success_and_format
+ - _self_
+
+actor:
+ # Rollout coroutine used for this domain (pipelinerl/domains/tir/rollouts.py).
+ # It builds the initial conversation from system_prompt (added as the system message when
+ # non-empty) and task_template (formatted with task=<the problem's "task" field>).
+ rollout_policy: pipelinerl.domains.tir.generate_tir_rollout
+ system_prompt: |
+ You are a math-focused AI Agent. Solve problems by combining clear symbolic reasoning
+ with short, deterministic Python code.
+ Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration.
+ Always present the final answer in LaTeX \boxed{}.
+ Do not express emotions or opinions about user questions.
+
+ Workflow:
+ 1. Draft a brief plan in plain text.
+ 2. Execute one run_python_code call to compute or verify the result.
+ 3. Finalize by calling MathAnswer with the LaTeX-formatted answer.
+
+ Python execution policy (run_python_code):
+ - Use Python strictly for pure computation to verify and validate the final answer.
+ - No network, file system, OS or environment access.
+ - Keep snippets minimal and self-contained; print only the final result.
+
+ Validation:
+ - Cross-check results (alternative derivation, invariants, higher precision) before finalizing.
+ - If execution fails, propose the minimal fix and retry.
+ Always verify with run_python_code before invoking MathAnswer.
+ task_template: "{task}"
+ # Upper bound on LLM turns per rollout; generate_tir_rollout stops earlier when a turn
+ # contains no tool calls or once MathAnswer has been called.
+ agent_max_loops: 3
+ llm_max_rollouts: 128
+ max_rollout_retries: 20
+ rollout_workers: 8
+ shared_memory_entry_size: 1000000000
+
+# Values set next to the `success_and_format` rewards group chosen in `defaults`; with
+# `_self_` listed last these presumably take precedence (confirm against Hydra composition).
+rewards:
+ # Reward for a correct answer when the rollout is not finished; generate_tir_rollout counts
+ # a rollout as finished only if MathAnswer was called.
+ correct_answer_not_finished: 0.0
+ # A falsy value (0) makes generate_tir_rollout skip the length_penalty term entirely.
+ buffer_tokens: 0
+
+# Math verifier environment
+environments:
+ - key: math
+ mode: remote
+ _target_: pipelinerl.domains.math.MathEnvironment
+# generate_tir_rollout resolves the environment key (default "math"), picks one of that
+# environment's jobs at random and sends the prediction to it through verify_answer_rpc.
+environment_key: math
+dataset_loader: pipelinerl.domains.math.load_datasets
+
+train_dataset_names:
+ - open_reasoner_zero_57k
+ - open_reasoner_zero_extended_72k
+test_dataset_names:
+ - aime_2025
+
+# SandboxFusion config
+# Endpoint comes from the SANDBOX_ENDPOINT environment variable, with the local URL as default.
+# NOTE(review): execute_python_sandbox configures the endpoint only on its first call.
+sandbox_endpoint: ${oc.env:SANDBOX_ENDPOINT,http://127.0.0.1:8080}
+# Passed as `run_timeout` of every RunCodeRequest (unit presumably seconds — confirm).
+sandbox_timeout: 10.0
+
+# Optional reward shaping
+# Read by RewardShaper._python_tool_bonus in pipelinerl/domains/tir/rollouts.py; when this
+# section is absent the term is 0.0.
+python_tool_shaping:
+ # Added when the answer is correct and at least one run_python_code call was handled.
+ bonus_on_correct_with_python: 0.1
+ # Subtracted when the answer is wrong or unparsable and no run_python_code call was handled.
+ penalty_on_incorrect_without_python: 0.1
+ # The resulting adjustment is clamped to [-max_abs, max_abs].
+ max_abs: 0.2
+
+# vLLM tool-call parser config
+vllm_config:
+ vllm_kwargs:
+ enable-auto-tool-choice: ""
+ # Parser name registered in pipelinerl/rl_tool_parser_plugin.py via
+ # ToolParserManager.register_module("rl_tool").
+ tool-call-parser: rl_tool
+ tool-parser-plugin: ${hydra:runtime.cwd}/pipelinerl/rl_tool_parser_plugin.py
+ # Also read by generate_tir_rollout, which caps each turn's max_tokens to
+ # max_model_len minus the tokenized prompt length.
+ max_model_len: 32000
+
+llm:
+ parameters:
+ max_tokens: 16000
+ temperature: 1.0
+
+test_llm:
+ parameters:
+ max_tokens: 16000
+ temperature: 1.0
+ top_p: 0.95
+ top_k: 50
+
+finetune:
+ seq_length: 32000
+ seq_parallel: 8
+ gradient_accumulation_passes: 1024
+ rl:
+ policy_loss: gspo
+ overlong_filtering: true
+
+preprocess:
+ input: actor
+ output: training_data
+ n_workers: 8
+ shared_memory_entry_size: 1000000000
diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index d276747d..67c71a59 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -25,7 +25,7 @@
from pipelinerl.finetune_loop import calculate_train_steps
from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
from pipelinerl.llm import TrainableLLM
-from pipelinerl.rollouts import BaseMetrics, RolloutResult
+from pipelinerl.rollouts import BaseMetrics, RolloutResult, rollout_has_overflow
from pipelinerl.shared_memory_array import SharedMemoryQueue
from pipelinerl.state import TrainerState
from pipelinerl.streams import (
@@ -399,7 +399,7 @@ def init_stats(self):
def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]:
metrics = {}
- metrics['overflow'] = all([not training_text.finished for training_text in result.training_texts ])
+ metrics['overflow'] = rollout_has_overflow(result.training_texts)
metrics['num_turns'] = len(result.training_texts)
metrics['prompt_tokens'] = [training_text.prompt_tokens for training_text in result.training_texts]
metrics['output_tokens'] = [training_text.output_tokens for training_text in result.training_texts]
diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py
index b305c458..26ca3da1 100644
--- a/pipelinerl/async_llm.py
+++ b/pipelinerl/async_llm.py
@@ -3,12 +3,13 @@
import logging
import aiohttp
+import litellm
import numpy as np
from PIL import Image
from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM
from pipelinerl.finetune.data import MASKED_TOKEN_ID
-from pipelinerl.rollouts import TrainingText
+from pipelinerl.rollouts import TrainingText, apply_rollout_reward
from pipelinerl.processor_factory import get_processor
from omegaconf import DictConfig, ListConfig, OmegaConf
@@ -83,7 +84,10 @@ def _is_retryable_abort_response(data: dict, collect_logprobs: bool) -> tuple[bo
async def llm_async_generate(
- llm: TrainableLLM, prompt: Prompt, session: aiohttp.ClientSession
+ llm: TrainableLLM,
+ prompt: Prompt,
+ session: aiohttp.ClientSession,
+ max_tokens_override: int | None = None,
) -> LLMCall:
llm.load_tokenizer()
headers = {"Content-Type": "application/json"}
@@ -117,6 +121,12 @@ async def llm_async_generate(
logger.debug(f"POST request to {llm.base_url}/v1/chat/completions")
+ if prompt.tools:
+ data["tools"] = _to_plain_obj(prompt.tools)
+
+ if max_tokens_override is not None:
+ data["max_tokens"] = max_tokens_override
+
# Merge extra_parameters first so that data (model, messages, logprobs settings) takes precedence
payload = _to_plain_obj({**extra_parameters, **data})
response_data = None
@@ -161,7 +171,8 @@ async def llm_async_generate(
try:
content = response_data["choices"][0]["message"]["content"]
- if not content:
+ raw_tool_calls = response_data["choices"][0]["message"].get("tool_calls", [])
+ if not content and not raw_tool_calls:
logger.warning(f"Empty completion {response_data}")
parsed_logprobs = []
@@ -188,7 +199,9 @@ async def llm_async_generate(
logger.exception(f"Failed to parse llm response: {response_data}")
raise
- output = LLMOutput(content=content)
+ output = LLMOutput(content=content or "")
+ if raw_tool_calls:
+ output.tool_calls = [litellm.ChatCompletionMessageToolCall(**tc) for tc in raw_tool_calls]
llm_call = llm.log_output(prompt, output, count_tokens=False)
llm_call.prompt_length_tokens = response_data["usage"]["prompt_tokens"]
llm_call.output_length_tokens = response_data["usage"]["completion_tokens"]
@@ -210,9 +223,20 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText:
images = []
use_processor = False
visual_features = None
- full_messages = llm_call.prompt.messages + [
- {"role": "assistant", "content": llm_call.output.content}
- ]
+ assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""}
+ if llm_call.output.tool_calls:
+ assistant_msg["tool_calls"] = [
+ {
+ "id": tc.id,
+ "type": "function",
+ "function": {
+ "name": tc.function.name,
+ "arguments": tc.function.arguments,
+ },
+ }
+ for tc in llm_call.output.tool_calls
+ ]
+ full_messages = llm_call.prompt.messages + [assistant_msg]
if hasattr(llm_call.prompt, "messages"):
images = extract_images_from_messages(llm_call.prompt.messages)
@@ -265,6 +289,8 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText:
else:
# Use tokenizer for text-only models
chat_kwargs = _to_plain_obj(llm.chat_template_kwargs) if llm.chat_template_kwargs else {}
+ if llm_call.prompt.tools:
+ chat_kwargs = {**chat_kwargs, "tools": _to_plain_obj(llm_call.prompt.tools)}
prompt_text = llm.tokenizer.apply_chat_template(
conversation=llm_call.prompt.messages,
tokenize=False,
@@ -285,7 +311,6 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText:
output_text = text[len(prompt_text) :]
- # Get the appropriate tokenizer (from processor if using vision model)
tokenizer = processor.tokenizer if use_processor else llm.tokenizer
if tokenizer.bos_token and text.startswith(tokenizer.bos_token):
@@ -304,7 +329,7 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText:
finished = finish_reason != "length"
else:
eos_token = tokenizer.eos_token or ""
- finished = bool(eos_token) and llm_call.output.content.endswith(eos_token)
+ finished = bool(eos_token) and (llm_call.output.content or "").endswith(eos_token)
prompt_tokens = llm_call.prompt_length_tokens
output_tokens = llm_call.output_length_tokens
@@ -319,3 +344,14 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText:
output_tokens=output_tokens,
visual_features=visual_features,
)
+
+
+def make_training_texts_from_llm_calls(
+    llm: TrainableLLM,
+    llm_calls: list[LLMCall],
+    reward: float | None = None,
+) -> list[TrainingText]:
+    """Turn every LLM call of a rollout into a training text.
+
+    Args:
+        llm: Model wrapper forwarded to ``make_training_text`` for each call.
+        llm_calls: Calls to convert, processed in order.
+        reward: When not ``None``, the converted texts and this value are
+            handed to ``apply_rollout_reward`` and its result is returned.
+
+    Returns:
+        One training text per call, or whatever ``apply_rollout_reward``
+        returns when a reward is given.
+    """
+    texts = []
+    for call in llm_calls:
+        texts.append(make_training_text(llm, call))
+    if reward is None:
+        return texts
+    return apply_rollout_reward(texts, reward)
diff --git a/pipelinerl/domains/math/__init__.py b/pipelinerl/domains/math/__init__.py
index 9aee0b8f..7a9809b7 100644
--- a/pipelinerl/domains/math/__init__.py
+++ b/pipelinerl/domains/math/__init__.py
@@ -1,3 +1,3 @@
from .load_datasets import load_datasets
-from .rollouts import generate_math_rollout, RewardTable
+from .rollouts import generate_math_rollout, RewardTable, get_reward, length_penalty
from .verifier_api import MathEnvironment, verify_answer, verify_answer_rpc
\ No newline at end of file
diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py
index 3e6560c0..8b712ff0 100644
--- a/pipelinerl/domains/math/rollouts.py
+++ b/pipelinerl/domains/math/rollouts.py
@@ -55,6 +55,28 @@ def log_config(self, domain: str = "unknown") -> None:
f"buffer_tokens={self.buffer_tokens}"
)
+def get_reward(answer_status: str, finished: bool, reward_table: RewardTable) -> float:
+    """Look up the reward for an answer status / finished combination.
+
+    Args:
+        answer_status: One of ``"wrong"``, ``"no_answer"``, ``"unparsable"``
+            or ``"correct"``.
+        finished: Whether the rollout counts as finished; must be a real
+            ``True``/``False``.
+        reward_table: Object exposing the ``<status>_finished`` and
+            ``<status>_not_finished`` reward attributes.
+
+    Returns:
+        The matching attribute of ``reward_table``.
+
+    Raises:
+        ValueError: If the status is unknown or ``finished`` is not a bool.
+    """
+    # Status -> stem of the reward attribute; the suffix is derived from ``finished``.
+    stems = {
+        "wrong": "wrong_answer",
+        "no_answer": "no_answer",
+        "unparsable": "unparsable",
+        "correct": "correct_answer",
+    }
+    stem = stems.get(answer_status) if isinstance(answer_status, str) else None
+    # Identity checks: truthy/falsy stand-ins such as 1 or 0 are rejected, not coerced.
+    is_real_bool = finished is True or finished is False
+    if stem is None or not is_real_bool:
+        raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{finished}")
+    suffix = "finished" if finished else "not_finished"
+    return getattr(reward_table, f"{stem}_{suffix}")
+
+
def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float:
"""
Compute the overlong penalty
@@ -100,25 +122,7 @@ async def generate_math_rollout(
trace = make_training_text(llm, llm_call)
# Determine reward based on answer status and finished state
- match (answer_status, trace.finished):
- case ("wrong", False):
- reward = rewards.wrong_answer_not_finished
- case ("wrong", True):
- reward = rewards.wrong_answer_finished
- case ("no_answer", False):
- reward = rewards.no_answer_not_finished
- case ("no_answer", True):
- reward = rewards.no_answer_finished
- case ("unparsable", False):
- reward = rewards.unparsable_not_finished
- case ("unparsable", True):
- reward = rewards.unparsable_finished
- case ("correct", False):
- reward = rewards.correct_answer_not_finished
- case ("correct", True):
- reward = rewards.correct_answer_finished
- case _:
- raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{trace.finished}")
+ reward = get_reward(answer_status, trace.finished, rewards)
# Apply discount factor based on output length
reward *= discount_factor**llm_call.output_length_tokens
diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py
index 8c489291..89e9cc5a 100644
--- a/pipelinerl/domains/miniwob/rollouts.py
+++ b/pipelinerl/domains/miniwob/rollouts.py
@@ -18,9 +18,9 @@
from tapeagents.remote_environment import AsyncRemoteEnvironment
from tapeagents.tools.simple_browser import PageObservation
-from pipelinerl.async_llm import make_training_text
+from pipelinerl.async_llm import make_training_texts_from_llm_calls
from pipelinerl.llm import LLMCall, TrainableLLM
-from pipelinerl.rollouts import BaseMetrics, RolloutResult
+from pipelinerl.rollouts import BaseMetrics, RolloutResult, summarize_training_texts
from pipelinerl.world import Job
from .steps import WebTape
@@ -271,13 +271,8 @@ async def _execute_rollout_with_timeout(
]
# (4) # For each LLM interaction in the tape, make a training example.
- all_finished = 1
- prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls]
- output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls]
- training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls]
- for text in training_texts:
- text.reward = reward
- all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0
+ training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward)
+ training_summary = summarize_training_texts(training_texts)
latency = time.time() - start_time
agent_time = tape.metadata.result.get("agent_execution_time", -1.0)
@@ -289,7 +284,7 @@ async def _execute_rollout_with_timeout(
success=reward > 0.5,
no_error=no_error,
no_answer=reward < 0,
- overflow=not all_finished,
+ overflow=training_summary.overflow,
n_llm_calls=n_llm_calls,
n_step_errors=n_step_errors,
n_page_observations=n_page_observations,
@@ -307,8 +302,6 @@ async def _execute_rollout_with_timeout(
latency=latency,
dataset_name=problem["dataset"],
domain="miniwob",
- prompt_tokens=prompt_tokens,
- output_tokens=output_tokens,
)
@@ -340,6 +333,4 @@ def _create_failed_rollout_result(problem: dict, start_time: float, error_type:
latency=latency,
dataset_name=problem["dataset"],
domain="miniwob",
- prompt_tokens=[],
- output_tokens=[],
)
diff --git a/pipelinerl/domains/tir/__init__.py b/pipelinerl/domains/tir/__init__.py
new file mode 100644
index 00000000..4a658bd0
--- /dev/null
+++ b/pipelinerl/domains/tir/__init__.py
@@ -0,0 +1 @@
+from .rollouts import generate_tir_rollout
diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py
new file mode 100644
index 00000000..c25a30b5
--- /dev/null
+++ b/pipelinerl/domains/tir/rollouts.py
@@ -0,0 +1,385 @@
+import asyncio
+import json
+import logging
+import random
+import re
+import time
+from dataclasses import dataclass
+from typing import Awaitable, Callable
+
+import aiohttp
+from omegaconf import DictConfig
+
+from sandbox_fusion import RunCodeRequest, set_sandbox_endpoint, run_code_async
+
+from pipelinerl.async_llm import llm_async_generate, make_training_texts_from_llm_calls
+from pipelinerl.domains.math import RewardTable, get_reward, length_penalty, verify_answer_rpc
+from pipelinerl.llm import Prompt, TrainableLLM
+from pipelinerl.rollouts import BaseMetrics, RolloutResult
+from pipelinerl.utils import get_environment_jobs, resolve_environment_key
+
+logger = logging.getLogger(__name__)
+
+# Module-level flag set to True by execute_python_sandbox after its first
+# set_sandbox_endpoint call; later calls do not reconfigure the endpoint.
+_SANDBOX_CONFIGURED = False
+
+# Regexes checked by _check_code_safety; code matching any of them is rejected before it
+# is sent to the sandbox.
+_BLOCKED_PATTERNS = [
+ re.compile(r"\bsys\.exit\b"),
+ re.compile(r"\bos\._exit\b"),
+ re.compile(r"\bos\.system\b"),
+ re.compile(r"\bsubprocess\b"),
+ re.compile(r"\bos\.popen\b"),
+ re.compile(r"\bos\.exec\w*\b"),
+ re.compile(r"\bos\.spawn\w*\b"),
+ re.compile(r"\bos\.kill\b"),
+ re.compile(r"\bshutil\.rmtree\b"),
+ re.compile(r"\bos\.remove\b"),
+ re.compile(r"\bos\.unlink\b"),
+]
+
+
+def build_tool_definitions() -> list[dict]:
+    """Return the tool schemas offered to the model.
+
+    Two function tools are described, each taking one required string
+    argument: ``run_python_code`` (``code``) and ``MathAnswer`` (``answer``).
+    A fresh list is built on every call.
+    """
+
+    def single_string_tool(name, description, param, param_description):
+        # Schema of a function tool whose only (required) argument is a string.
+        return {
+            "type": "function",
+            "function": {
+                "name": name,
+                "description": description,
+                "parameters": {
+                    "type": "object",
+                    "properties": {param: {"type": "string", "description": param_description}},
+                    "required": [param],
+                },
+            },
+        }
+
+    return [
+        single_string_tool(
+            "run_python_code",
+            "Execute Python code. Print only the final result.",
+            "code",
+            "Python code to execute",
+        ),
+        single_string_tool(
+            "MathAnswer",
+            "Submit the final answer in LaTeX \\boxed{} format.",
+            "answer",
+            "The final answer",
+        ),
+    ]
+
+
+def _check_code_safety(code: str) -> str | None:
+    """Return a rejection message if ``code`` matches a blocked pattern, else ``None``.
+
+    Patterns are tried in the order they appear in ``_BLOCKED_PATTERNS`` and
+    the first match is the one reported.
+    """
+    offending = next((rx for rx in _BLOCKED_PATTERNS if rx.search(code)), None)
+    if offending is None:
+        return None
+    return f"Blocked: code contains forbidden pattern '{offending.pattern}'"
+
+
+async def execute_python_sandbox(code: str, endpoint: str, timeout: float) -> str:
+    """Run ``code`` through SandboxFusion and return a text summary of the result.
+
+    Args:
+        code: Python source to execute.
+        endpoint: SandboxFusion URL; applied only the first time this
+            function runs (guarded by the module-level ``_SANDBOX_CONFIGURED``).
+        timeout: Passed as ``run_timeout`` of the request.
+
+    Returns:
+        The rejection message when the code matches a blocked pattern;
+        otherwise stdout, a ``[stderr]`` section and a timeout marker (each
+        only when present) joined by newlines, or ``[no output]``. Failures
+        are reported as ``[execution timed out]`` / ``[execution error: ...]``
+        strings rather than raised.
+    """
+    global _SANDBOX_CONFIGURED
+    if not _SANDBOX_CONFIGURED:
+        set_sandbox_endpoint(endpoint)
+        _SANDBOX_CONFIGURED = True
+        logger.info("Configured SandboxFusion endpoint: %s", endpoint)
+
+    blocked_reason = _check_code_safety(code)
+    if blocked_reason is not None:
+        return blocked_reason
+
+    try:
+        response = await run_code_async(
+            RunCodeRequest(code=code, language="python", run_timeout=timeout)
+        )
+
+        run_result = response.run_result
+        stdout = (run_result.stdout or "") if run_result else ""
+        stderr = (run_result.stderr or "") if run_result else ""
+
+        raw_status = response.status
+        status = raw_status.value if hasattr(raw_status, "value") else str(raw_status)
+        # A timeout may be signalled through either the status or the message text.
+        timed_out = "timeout" in status.lower() or "timeout" in (response.message or "").lower()
+
+        candidates = (
+            (stdout, stdout.rstrip()),
+            (stderr, f"[stderr]\n{stderr.rstrip()}"),
+            (timed_out, "[execution timed out]"),
+        )
+        sections = [text for enabled, text in candidates if enabled]
+        return "\n".join(sections or ["[no output]"])
+
+    except asyncio.TimeoutError:
+        return "[execution timed out]"
+    except Exception as exc:
+        # Best effort: the model sees the failure as tool output instead of the rollout aborting.
+        logger.warning("SandboxFusion error: %s", exc)
+        return f"[execution error: {exc}]"
+
+
+def _serialize_tool_calls(tool_calls) -> list[dict]:
+    """Convert tool-call objects into plain dicts for the conversation history.
+
+    Each input must expose ``id``, ``function.name`` and
+    ``function.arguments``; every output entry has ``type`` set to
+    ``"function"``.
+    """
+    serialized = []
+    for call in tool_calls:
+        fn = call.function
+        serialized.append(
+            {
+                "id": call.id,
+                "type": "function",
+                "function": {"name": fn.name, "arguments": fn.arguments},
+            }
+        )
+    return serialized
+
+
+def _parse_tool_arguments(arguments: str, *, fallback_key: str | None = None) -> dict:
+    """Decode a tool call's JSON ``arguments`` into a dict without raising on bad input.
+
+    Args:
+        arguments: Raw JSON text produced by the model.
+        fallback_key: When given and the JSON decodes to a bare string, that
+            string is wrapped as ``{fallback_key: value}``.
+
+    Returns:
+        The decoded object when it is a dict, the wrapped string described
+        above, or ``{}`` for undecodable input and any other JSON value.
+    """
+    try:
+        payload = json.loads(arguments)
+    except (json.JSONDecodeError, TypeError):
+        payload = None
+    else:
+        if isinstance(payload, str) and fallback_key is not None:
+            payload = {fallback_key: payload}
+    return payload if isinstance(payload, dict) else {}
+
+
+@dataclass
+class _ToolContext:
+ """Mutable per-rollout state shared between generate_tir_rollout and the tool handlers."""
+ # Forwarded to execute_python_sandbox by the run_python_code handler.
+ sandbox_endpoint: str
+ sandbox_timeout: float
+ # Conversation history; handlers append their "tool" role replies to this list.
+ messages: list[dict]
+ # Answer taken from a MathAnswer call; stays None while no such call was handled.
+ final_answer: str | None = None
+ # Set to True by the MathAnswer handler; generate_tir_rollout stops looping once it is set.
+ submitted_final_answer: bool = False
+ # Incremented once per handled run_python_code call.
+ num_python_calls: int = 0
+
+
+# Signature shared by the tool handlers: (tool_call, context) -> awaitable resolving to None.
+ToolHandler = Callable[[object, _ToolContext], Awaitable[None]]
+
+
+async def _handle_math_answer(tc, ctx: _ToolContext) -> None:
+    """Record the final answer submitted through the ``MathAnswer`` tool.
+
+    Stores the answer on ``ctx``, marks the rollout as submitted and appends
+    the tool reply to ``ctx.messages``.
+
+    Args:
+        tc: Tool call exposing ``id`` and ``function.arguments`` (JSON text).
+        ctx: Rollout state to update.
+    """
+    args = _parse_tool_arguments(tc.function.arguments, fallback_key="answer")
+    answer = args.get("answer")
+    # The model may send a bare number or null instead of a string. Normalise it so
+    # final_answer always honours its ``str`` contract once an answer was submitted;
+    # a None here would make the caller fall back to the free-text of the last turn.
+    ctx.final_answer = "" if answer is None else str(answer)
+    ctx.submitted_final_answer = True
+    ctx.messages.append({
+        "role": "tool",
+        "tool_call_id": tc.id,
+        "content": f"Answer submitted: {ctx.final_answer}",
+    })
+
+
+async def _handle_run_python_code(tc, ctx: _ToolContext) -> None:
+    """Execute a ``run_python_code`` tool call and append its output as the tool reply.
+
+    The code is read from the ``code`` argument, falling back to
+    ``python_code``. Every handled call increments ``ctx.num_python_calls``,
+    including calls whose code is rejected.
+
+    Args:
+        tc: Tool call exposing ``id`` and ``function.arguments`` (JSON text).
+        ctx: Rollout state providing the sandbox settings and message list.
+    """
+    args = _parse_tool_arguments(tc.function.arguments, fallback_key="code")
+    code = args.get("code") or args.get("python_code", "")
+    if isinstance(code, str):
+        result = await execute_python_sandbox(code, ctx.sandbox_endpoint, ctx.sandbox_timeout)
+    else:
+        # A number, list or null would raise TypeError inside the regex safety check and
+        # abort the whole rollout; report it to the model as a tool error instead.
+        result = "[execution error: 'code' must be a string]"
+    ctx.num_python_calls += 1
+    ctx.messages.append({"role": "tool", "tool_call_id": tc.id, "content": result})
+
+
+async def _handle_unknown_tool(tc, ctx: _ToolContext) -> None:
+    """Answer a call to an unrecognised tool with an ``Unknown tool`` reply.
+
+    Args:
+        tc: Tool call exposing ``id`` and ``function.name``.
+        ctx: Rollout state whose ``messages`` list receives the reply.
+    """
+    reply = {
+        "role": "tool",
+        "tool_call_id": tc.id,
+        "content": f"Unknown tool: {tc.function.name}",
+    }
+    ctx.messages.append(reply)
+
+
+# Maps the tool name in a model tool call to its handler; generate_tir_rollout falls back
+# to _handle_unknown_tool for names that are not listed here.
+_TOOL_HANDLERS: dict[str, ToolHandler] = {
+ "MathAnswer": _handle_math_answer,
+ "run_python_code": _handle_run_python_code,
+}
+
+
+class RewardShaper:
+    """Optional additive reward terms driven by two config sections.
+
+    ``python_tool_shaping`` rewards or penalises use of the Python tool and
+    ``length_shaping`` adjusts for average output length. A missing section
+    contributes ``0.0``.
+    """
+
+    def __init__(self, cfg: DictConfig, llm: TrainableLLM):
+        self._python_cfg = getattr(cfg, "python_tool_shaping", None)
+        self._length_cfg = getattr(cfg, "length_shaping", None)
+        # Generation budget used to turn ``target_ratio`` into a token target.
+        self._max_gen_tokens = int(llm.parameters.get("max_tokens", 2048))
+
+    def compute(self, answer_status: str, num_python_calls: int, llm_calls: list) -> float:
+        """Return the sum of the Python-tool term and the length term."""
+        tool_term = self._python_tool_bonus(answer_status, num_python_calls)
+        length_term = self._length_adjustment(answer_status, llm_calls)
+        return tool_term + length_term
+
+    def _python_tool_bonus(self, answer_status: str, num_python_calls: int) -> float:
+        """Bonus for correct answers that used Python, penalty for wrong/unparsable ones that did not."""
+        section = self._python_cfg
+        if section is None:
+            return 0.0
+        bonus = float(getattr(section, "bonus_on_correct_with_python", 0.0))
+        penalty = float(getattr(section, "penalty_on_incorrect_without_python", 0.0))
+        limit = float(getattr(section, "max_abs", 0.2))
+
+        adjustment = 0.0
+        if answer_status == "correct" and num_python_calls >= 1:
+            adjustment += bonus
+        if answer_status in ("wrong", "unparsable") and num_python_calls == 0:
+            adjustment -= penalty
+        # Clamp to [-limit, limit].
+        capped = min(limit, adjustment)
+        return max(-limit, capped)
+
+    def _target_tokens(self, section) -> int:
+        """Resolve the output-token target from ``target_ratio`` or ``target_output_tokens``."""
+        if not hasattr(section, "target_ratio"):
+            return int(getattr(section, "target_output_tokens", 512))
+        ratio = float(getattr(section, "target_ratio"))
+        target = int(max(1, ratio * self._max_gen_tokens))
+        target = max(int(getattr(section, "min_target_tokens", 0)), target)
+        return min(int(getattr(section, "max_target_tokens", 10**9)), target)
+
+    def _length_adjustment(self, answer_status: str, llm_calls: list) -> float:
+        """Penalise average output above the target; optionally reward short correct rollouts."""
+        section = self._length_cfg
+        if section is None or not llm_calls:
+            return 0.0
+        target = self._target_tokens(section)
+        slope = float(getattr(section, "slope", 0.0))
+        max_penalty = float(getattr(section, "max_penalty", 0.0))
+        short_bonus = float(getattr(section, "bonus_on_short_correct", 0.0))
+
+        # Mean generated tokens per LLM call; calls lacking the attribute count as 0.
+        mean_tokens = sum(getattr(call, "output_length_tokens", 0) for call in llm_calls) / len(llm_calls)
+        adjustment = 0.0
+        if slope > 0.0 and max_penalty > 0.0 and mean_tokens > target:
+            adjustment -= min(max_penalty, slope * (mean_tokens - target))
+        if short_bonus > 0.0 and answer_status == "correct" and mean_tokens <= target:
+            adjustment += short_bonus
+        return adjustment
+
+
+class Metrics(BaseMetrics):
+ """Rollout metrics for the TIR domain, adding tool-use fields to BaseMetrics."""
+ # Number of run_python_code calls handled during the rollout.
+ num_python_calls: int = 0
+ # Number of LLM calls made during the rollout.
+ num_steps: int = 0
+ # True when the rollout ended without a MathAnswer submission.
+ overflow: bool = False
+
+
+async def generate_tir_rollout(
+ cfg: DictConfig,
+ llm: TrainableLLM,
+ problem: dict,
+ session: aiohttp.ClientSession,
+) -> RolloutResult:
+ """Run one tool-integrated-reasoning rollout for a math problem and score it.
+
+ The model is queried for up to ``cfg.actor.agent_max_loops`` turns with the
+ run_python_code / MathAnswer tools. Tool calls are dispatched through
+ ``_TOOL_HANDLERS``; the loop ends when a turn has no tool calls, when
+ MathAnswer is called, or when the prompt leaves too little room to generate.
+ The prediction is then checked with ``verify_answer_rpc`` and the reward is
+ the table reward (optionally discounted and length-penalised) plus the
+ ``RewardShaper`` term.
+
+ Args:
+ cfg: Reads ``actor.system_prompt``, ``actor.task_template``,
+ ``actor.agent_max_loops``, ``actor.discount_factor``, ``sandbox_endpoint``,
+ ``sandbox_timeout``, ``vllm_config.vllm_kwargs.max_model_len`` and ``rewards``.
+ llm: Model used for generation; ``llm.parameters["max_tokens"]`` is the per-turn budget.
+ problem: Must contain ``"task"`` and ``"answer"``; ``"dataset"`` is optional.
+ session: HTTP session used for generation and for the verifier RPC.
+
+ Returns:
+ A RolloutResult with one training text per LLM call, all carrying the same reward.
+
+ Raises:
+ RuntimeError: If no environment job is available for the resolved environment key.
+ """
+ start = time.perf_counter()
+
+ messages: list[dict] = []
+ if cfg.actor.system_prompt:
+ messages.append({"role": "system", "content": cfg.actor.system_prompt})
+ messages.append({"role": "user", "content": cfg.actor.task_template.format(task=problem["task"])})
+
+ tools = build_tool_definitions()
+
+ llm_calls = []
+ # ctx shares the same `messages` list, so handler replies land in the conversation.
+ ctx = _ToolContext(
+ sandbox_endpoint=str(cfg.sandbox_endpoint),
+ sandbox_timeout=float(cfg.sandbox_timeout),
+ messages=messages,
+ )
+ agent_max_loops = int(getattr(cfg.actor, "agent_max_loops", 3))
+ configured_max_tokens = int(llm.parameters.get("max_tokens", 16000))
+ max_model_len = int(cfg.vllm_config.vllm_kwargs.get("max_model_len", 32000))
+ # Below this much remaining context the loop stops instead of issuing another call.
+ min_generation_tokens = 256
+
+ for _turn in range(agent_max_loops):
+ # Snapshot the history: `messages` keeps growing after this prompt is built.
+ prompt = Prompt(messages=list(messages), tools=tools)
+
+ # NOTE(review): loop-invariant; llm_async_generate also calls load_tokenizer itself.
+ llm.load_tokenizer()
+ # Tokenize the prompt locally to measure how much of max_model_len is left.
+ # NOTE(review): llm.chat_template_kwargs is not passed here — confirm the count
+ # matches what the server renders.
+ prompt_token_ids = llm.tokenizer.apply_chat_template(
+ messages,
+ add_special_tokens=True,
+ add_generation_prompt=True,
+ tools=tools,
+ )
+ prompt_len = len(prompt_token_ids)
+ remaining = max_model_len - prompt_len
+ if remaining < min_generation_tokens:
+ logger.warning(
+ "Prompt length %d leaves only %d tokens for generation (max_model_len=%d), stopping loop",
+ prompt_len, remaining, max_model_len,
+ )
+ break
+ max_tokens_this_turn = min(configured_max_tokens, remaining)
+ if max_tokens_this_turn < configured_max_tokens:
+ logger.warning(
+ "Turn %d: capping max_tokens from %d to %d (prompt_len=%d, max_model_len=%d)",
+ _turn, configured_max_tokens, max_tokens_this_turn, prompt_len, max_model_len,
+ )
+
+ llm_call = await llm_async_generate(llm, prompt, session, max_tokens_override=max_tokens_this_turn)
+ llm_calls.append(llm_call)
+
+ # A turn without tool calls ends the rollout; its text may serve as the prediction.
+ if not llm_call.output.tool_calls:
+ break
+
+ assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""}
+ assistant_msg["tool_calls"] = _serialize_tool_calls(llm_call.output.tool_calls)
+ messages.append(assistant_msg)
+
+ # Handle tool calls in order; calls listed after MathAnswer are not executed.
+ for tc in llm_call.output.tool_calls:
+ handler = _TOOL_HANDLERS.get(tc.function.name, _handle_unknown_tool)
+ await handler(tc, ctx)
+ if ctx.submitted_final_answer:
+ break
+
+ if ctx.submitted_final_answer:
+ break
+
+ # Prefer the MathAnswer payload; otherwise fall back to the last turn's text.
+ if ctx.final_answer is not None:
+ prediction = ctx.final_answer
+ elif llm_calls:
+ prediction = llm_calls[-1].output.content or ""
+ else:
+ prediction = ""
+
+ env_key = resolve_environment_key(cfg, default="math")
+ env_jobs = get_environment_jobs(cfg, env_key)
+ if not env_jobs:
+ raise RuntimeError("No environment servers available for math domain")
+ env_job = random.choice(env_jobs)
+ assert env_job.port is not None
+ answer_status = await verify_answer_rpc(
+ session=session,
+ host=env_job.hostname,
+ port=env_job.port,
+ prediction=prediction,
+ gold=problem["answer"],
+ strict=True,
+ )
+
+ reward_table = RewardTable(**dict(cfg.rewards))
+ # "finished" for the reward table means MathAnswer was submitted.
+ base_reward = get_reward(answer_status, ctx.submitted_final_answer, reward_table)
+
+ # Optional discount over the tokens generated across all turns.
+ discount_factor = float(getattr(cfg.actor, "discount_factor", 1.0))
+ if discount_factor != 1.0:
+ total_generated_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls)
+ base_reward *= discount_factor ** total_generated_tokens
+
+ # Optional overlong penalty; skipped when buffer_tokens is 0.
+ # NOTE(review): compares output tokens summed over all turns against the per-call
+ # max_tokens — confirm this is the intended budget.
+ buffer_tokens = getattr(reward_table, "buffer_tokens", 0)
+ if buffer_tokens:
+ max_tokens = int(llm.parameters.get("max_tokens", 0))
+ total_output_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls)
+ if max_tokens > 0:
+ base_reward += length_penalty(max_tokens, total_output_tokens, buffer_tokens)
+
+ shaping = RewardShaper(cfg, llm).compute(answer_status, ctx.num_python_calls, llm_calls)
+ reward = base_reward + shaping
+
+ training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward)
+ # Replace each text's own finished flag with the rollout-level submission flag.
+ for text in training_texts:
+ text.finished = ctx.submitted_final_answer
+
+ latency = time.perf_counter() - start
+
+ metrics = Metrics(
+ reward=reward,
+ success=answer_status == "correct",
+ no_error=answer_status != "unparsable",
+ no_answer=answer_status == "no_answer",
+ num_python_calls=ctx.num_python_calls,
+ num_steps=len(llm_calls),
+ overflow=not ctx.submitted_final_answer,
+ )
+
+ return RolloutResult(
+ training_texts=training_texts,
+ metrics=metrics,
+ latency=latency,
+ dataset_name=problem.get("dataset"),
+ domain="tir",
+ )
diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py
index 2ba28c5e..dea58599 100644
--- a/pipelinerl/finetune/rl/__init__.py
+++ b/pipelinerl/finetune/rl/__init__.py
@@ -1,7 +1,7 @@
import logging
import os
from functools import partial
-from typing import Any
+from typing import Any, TYPE_CHECKING
from pydantic import BaseModel, Field
import numpy as np
@@ -9,10 +9,14 @@
import torch
import torch.nn.functional as F
from datasets import Dataset
-from transformers import PreTrainedModel
from pipelinerl.finetune.types import PipelineBatchEncoding
from pipelinerl.finetune.rl.utils import per_segment_sums
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+else:
+ PreTrainedModel = Any
+
from .utils import (
sum_sum,
mean_sum,
@@ -452,58 +456,72 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R
df_init = pd.DataFrame(dataset)
assert isinstance(df_init, pd.DataFrame)
- # Step 1: calculate group-level statistics
- df_stats = df_init[["group_id", "rollout_index", "step_index"]].copy()
+ # Step 1: calculate rollout-level token statistics and step-level reward statistics
+ df_stats = df_init[["group_id", "rollout_index", "step_index", "rewards"]].copy()
df_stats["num_tokens"] = df_init["input_ids"].apply(len)
- # We assume that rewards for all tokens are the same
- df_stats["rollout_reward"] = df_init["rewards"].apply(lambda x: x[0])
- # Check that the reward is the same for each step in the rollout
- assert df_stats.groupby(["group_id", "rollout_index"])["rollout_reward"].nunique().max() == 1
- # Only keep step_index == 0
- df_stats = df_stats[df_stats["step_index"] == 0].drop(columns=["step_index"])
+ df_stats["step_reward"] = df_stats["rewards"].apply(lambda rewards: rewards[0])
+ df_rollouts = (
+ df_stats.groupby(["group_id", "rollout_index"])
+ .agg(
+ rollout_tokens=("num_tokens", "sum"),
+ )
+ .reset_index()
+ )
+ df_group_tokens = (
+ df_rollouts.groupby("group_id")
+ .agg(
+ group_tokens=("rollout_tokens", "mean"),
+ )
+ .reset_index()
+ )
df_grouped = (
- df_stats.groupby("group_id")
+ df_stats.groupby(["group_id", "step_index"])
.agg(
- rollout_reward_sum=("rollout_reward", "sum"),
- rollout_reward_count=("rollout_reward", "count"),
- rollout_reward_std=("rollout_reward", "std"),
- group_tokens=("num_tokens", "mean"),
+ step_reward_sum=("step_reward", "sum"),
+ step_reward_count=("step_reward", "count"),
+ step_reward_std=("step_reward", "std"),
)
.reset_index()
)
- assert df_grouped.columns.tolist() == [
+ assert df_group_tokens.columns.tolist() == [
"group_id",
- "rollout_reward_sum",
- "rollout_reward_count",
- "rollout_reward_std",
"group_tokens",
]
+ assert df_grouped.columns.tolist() == [
+ "group_id",
+ "step_index",
+ "step_reward_sum",
+ "step_reward_count",
+ "step_reward_std",
+ ]
# Step 2: calculate advantages for each sample
df_advantages = pd.merge(
- df_init[["group_id", "rollout_index", "step_index", "rewards"]],
+ df_stats[["group_id", "rollout_index", "step_index", "rewards", "step_reward"]],
df_grouped,
- on="group_id",
+ on=["group_id", "step_index"],
how="left"
)
+ df_advantages = pd.merge(df_advantages, df_group_tokens, on="group_id", how="left")
assert len(df_advantages) == len(df_init)
+
def calculate_advantages(row):
rewards = row["rewards"]
- group_sum = row["rollout_reward_sum"]
- group_count = row["rollout_reward_count"]
- current_reward = rewards[0]
+ group_sum = row["step_reward_sum"]
+ group_count = row["step_reward_count"]
+ current_reward = row["step_reward"]
if group_count > 1:
loo_mean = (group_sum - current_reward) / (group_count - 1)
else:
loo_mean = current_reward
- std = row["rollout_reward_std"]
+ std = row["step_reward_std"]
if config.divide_advantage_by_std:
return [(r - loo_mean) / (np.nan_to_num(std) + 1e-4) for r in rewards]
return [(r - loo_mean) for r in rewards]
df_advantages["advantages"] = df_advantages.apply(calculate_advantages, axis=1)
df_advantages = df_advantages.drop(
- columns=["rewards", "rollout_reward_sum", "rollout_reward_count", "rollout_reward_std"]
+ columns=["rewards", "step_reward", "step_reward_sum", "step_reward_count", "step_reward_std"]
)
assert df_advantages.columns.tolist() == [
"group_id",
diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py
new file mode 100644
index 00000000..e48ec22c
--- /dev/null
+++ b/pipelinerl/rl_tool_parser_plugin.py
@@ -0,0 +1,202 @@
+"""
+Tool parser plugin for RL tool calling format.
+"""
+
+import json
+import logging
+import re
+
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ExtractedToolCallInformation,
+ ToolCall,
+ FunctionCall,
+)
+
+# Python types accepted as the "arguments" value of a parsed tool call.
+# Despite the name, this includes the container types dict and list; any other
+# value (for example None) is replaced by an empty dict in _build_tool_call.
+_JSON_SCALAR_TYPES = (dict, list, str, int, float, bool)
+
+
+def _build_tool_call(index: int, parsed: dict, *, force_id: str | None = None) -> ToolCall | None:
+    """Turn one parsed JSON object into a ToolCall.
+
+    Args:
+        index: position of the call; used to build the default id ``call_<index>``.
+        parsed: decoded JSON object, read for its "name", "arguments" and "id" keys.
+        force_id: when given, used as the call id instead of anything in ``parsed``.
+
+    Returns:
+        The ToolCall, or None if anything raised while building it (the failure is
+        logged at debug level and never propagated).
+    """
+    try:
+        arguments = parsed.get("arguments", {})
+        if not isinstance(arguments, _JSON_SCALAR_TYPES):
+            # Unsupported values (e.g. null) are sent as an empty argument object.
+            arguments = {}
+
+        if force_id is None:
+            call_id = parsed.get("id", f"call_{index}")
+        else:
+            call_id = force_id
+
+        function = FunctionCall(
+            name=str(parsed.get("name", "")),
+            arguments=json.dumps(arguments, ensure_ascii=False),
+        )
+        return ToolCall(id=call_id, type="function", function=function)
+    except Exception:
+        log = logging.getLogger("pipelinerl.tool_parser")
+        log.debug("Skipping malformed tool call", exc_info=True)
+        return None
+
+
+@ToolParserManager.register_module("rl_tool")
+class HermesRLToolParser(ToolParser):
+    """
+    Tool parser for the RL tool calling format using markers.
+
+    Supports both the standard format and Apriel-style formats:
+    - <tool_calls>[{...}, {...}]</tool_calls> (preferred if present)
+    - [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] wrapper
+    - bare JSON tool calls (only when the request declares tools)
+    - <tool_call>{...}</tool_call> blocks (legacy fallback)
+    """
+
+    def __init__(self, tokenizer):
+        """Set up the marker strings, the compiled patterns and the streaming state.
+
+        Args:
+            tokenizer: passed through unchanged to ``ToolParser.__init__``.
+        """
+        super().__init__(tokenizer)
+
+        # Markers around a single JSON tool call (legacy format).
+        self.tool_call_start_token = "<tool_call>"
+        self.tool_call_end_token = "</tool_call>"
+
+        # Group 1: a properly closed block. Group 2: an unterminated block that
+        # runs to the end of the text.
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
+        )
+
+        self.apriel_final_response_regex = re.compile(
+            r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL
+        )
+        # Lenient match: case-insensitive and tolerate a missing closing tag.
+        self.apriel_tool_calls_regex = re.compile(
+            r"<tool_calls>\s*(.*?)\s*(?:</tool_calls>|$)", re.DOTALL | re.IGNORECASE
+        )
+
+        # vLLM streaming hooks expect these attributes on the parser instance.
+        self.current_tool_name_sent = False
+        self.prev_tool_call_arr = []
+        self.current_tool_id = -1
+        self.streamed_args_for_tool = []
+
+    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+        """Extract tool calls and the remaining content from a complete model output.
+
+        Formats are tried in this order; the first one that applies decides the result:
+        1. The last <tool_calls>...</tool_calls> block, holding a JSON list of calls.
+        2. If the request declares tools: bare JSON found in the final-response
+           wrapper or in the whole output.
+        3. Legacy <tool_call>...</tool_call> blocks.
+
+        Args:
+            model_output: full text generated by the model.
+            request: the chat completion request; only its ``tools`` attribute is read.
+
+        Returns:
+            ExtractedToolCallInformation with the parsed calls and the content. Any
+            unexpected exception is logged and turned into a result without tool
+            calls, so nothing propagates to the caller.
+        """
+        logger = logging.getLogger("pipelinerl.tool_parser")
+        final_response_match = None
+
+        try:
+            # 1. Aggregated block; only the last occurrence is used.
+            tool_calls_matches = list(self.apriel_tool_calls_regex.finditer(model_output))
+            if tool_calls_matches:
+                last_match = tool_calls_matches[-1]
+                tool_calls_json = last_match.group(1).strip()
+                parsed_calls = []
+                try:
+                    parsed_calls = json.loads(tool_calls_json) if tool_calls_json else []
+                except Exception:
+                    logger.debug("Failed to parse aggregated JSON; falling back", exc_info=True)
+                    parsed_calls = []
+
+                tool_calls = [
+                    tc for tc in (_build_tool_call(i, pc) for i, pc in enumerate(parsed_calls))
+                    if tc is not None
+                ]
+
+                # With an aggregated block the content is the final response, or empty.
+                final_response_match = self.apriel_final_response_regex.search(model_output)
+                content = final_response_match.group(1).strip() if final_response_match else ""
+
+                return ExtractedToolCallInformation(
+                    tools_called=bool(tool_calls),
+                    tool_calls=tool_calls,
+                    content=content,
+                )
+
+            try:
+                tools_declared = bool(getattr(request, "tools", None))
+            except Exception:
+                tools_declared = False
+
+            # 2. Bare JSON, accepted only when the request declares tools.
+            if tools_declared:
+                candidate_strings: list[str] = []
+                final_response_match = self.apriel_final_response_regex.search(model_output)
+                if final_response_match:
+                    candidate_strings.append(final_response_match.group(1).strip())
+                candidate_strings.append(model_output.strip())
+
+                for candidate in candidate_strings:
+                    try:
+                        parsed = json.loads(candidate)
+                    except Exception:
+                        continue
+                    parsed_list = []
+                    if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed:
+                        parsed_list = [parsed]
+                    elif isinstance(parsed, list) and all(isinstance(it, dict) for it in parsed):
+                        parsed_list = [it for it in parsed if "name" in it and "arguments" in it]
+                    if not parsed_list:
+                        continue
+                    tool_calls = [
+                        tc for tc in (_build_tool_call(i, pc) for i, pc in enumerate(parsed_list))
+                        if tc is not None
+                    ]
+                    content = final_response_match.group(1).strip() if final_response_match else ""
+                    return ExtractedToolCallInformation(
+                        tools_called=bool(tool_calls),
+                        tool_calls=tool_calls,
+                        content=content,
+                    )
+
+            # Fallback: legacy <tool_call> blocks.
+            content_to_search = model_output
+            final_response_match = self.apriel_final_response_regex.search(model_output)
+            if final_response_match:
+                final_response_content = final_response_match.group(1).strip()
+                if self.tool_call_start_token in final_response_content:
+                    # Calls inside the final response take precedence over the rest.
+                    content_to_search = final_response_content
+                elif self.tool_call_start_token not in model_output:
+                    return ExtractedToolCallInformation(
+                        tools_called=False,
+                        tool_calls=[],
+                        content=final_response_content
+                    )
+
+            if self.tool_call_start_token not in content_to_search:
+                return ExtractedToolCallInformation(
+                    tools_called=False,
+                    tool_calls=[],
+                    content=model_output
+                )
+
+            function_call_tuples = self.tool_call_regex.findall(content_to_search)
+
+            tool_calls = []
+            for i, match in enumerate(function_call_tuples):
+                # match[0] is a closed block, match[1] an unterminated one.
+                json_str = match[0] if match[0] else match[1]
+                try:
+                    parsed_call = json.loads(json_str.strip())
+                except Exception:
+                    logger.debug("Skipping malformed JSON", exc_info=True)
+                    continue
+                tc = _build_tool_call(i, parsed_call, force_id=f"call_{i}")
+                if tc is not None:
+                    tool_calls.append(tc)
+
+            if tool_calls and final_response_match:
+                content = ""
+            elif final_response_match:
+                content = final_response_match.group(1).strip()
+            else:
+                content = model_output
+
+            return ExtractedToolCallInformation(
+                tools_called=bool(tool_calls),
+                tool_calls=tool_calls,
+                content=content
+            )
+
+        except Exception:
+            # Never propagate to the vLLM server.
+            logger.exception("Tool parser encountered an exception; returning safe fallback.")
+            if final_response_match:
+                return ExtractedToolCallInformation(
+                    tools_called=False,
+                    tool_calls=[],
+                    content=final_response_match.group(1).strip()
+                )
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output
+            )
+
\ No newline at end of file
diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py
index 1200ba23..4c71dda7 100644
--- a/pipelinerl/rollouts.py
+++ b/pipelinerl/rollouts.py
@@ -1,5 +1,6 @@
+from dataclasses import dataclass
from pydantic import BaseModel, Field
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Sequence
import numpy as np
class BaseMetrics(BaseModel):
@@ -65,3 +66,32 @@ class RolloutResult(BaseModel):
dataset_name: str | None = None
group_id: str | None = None
domain: str | None = None
+
+
+@dataclass(frozen=True)
+class TrainingTextSummary:
+ # Immutable summary of the training texts of one rollout, as assembled by
+ # summarize_training_texts.
+ # prompt_tokens value of each training text, in input order
+ prompt_tokens: list[int]
+ # output_tokens value of each training text, in input order
+ output_tokens: list[int]
+ # True when at least one training text is not marked as finished
+ overflow: bool
+ # number of training texts that were summarized
+ num_turns: int
+
+
+def apply_rollout_reward(training_texts: Sequence[TrainingText], reward: float) -> list[TrainingText]:
+    """Set ``reward`` on every training text and return them as a list.
+
+    The training texts are modified in place; only the returned list is new.
+    """
+    rewarded = [*training_texts]
+    for item in rewarded:
+        item.reward = reward
+    return rewarded
+
+
+def rollout_has_overflow(training_texts: Sequence[TrainingText]) -> bool:
+    """Return True when at least one training text is not marked as finished."""
+    every_text_finished = all(item.finished for item in training_texts)
+    return not every_text_finished
+
+
+def summarize_training_texts(training_texts: Sequence[TrainingText]) -> TrainingTextSummary:
+    """Collect per-text token counts, the overflow flag and the turn count.
+
+    The input is copied into a list first, so it is iterated only once even
+    though several fields are derived from it.
+    """
+    items = list(training_texts)
+    prompt_token_counts = [item.prompt_tokens for item in items]
+    output_token_counts = [item.output_tokens for item in items]
+    has_overflow = rollout_has_overflow(items)
+    return TrainingTextSummary(
+        prompt_tokens=prompt_token_counts,
+        output_tokens=output_token_counts,
+        overflow=has_overflow,
+        num_turns=len(items),
+    )
diff --git a/pyproject.toml b/pyproject.toml
index 81ecbd42..13060791 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,9 +73,12 @@ ifeval = [
"langdetect",
"absl-py",
]
+tir = [
+ "sandbox-fusion>=0.3.7",
+]
# Install all domain dependencies
domains = [
- "pipelinerl[coding,fn_calling,logic,ifeval]",
+ "pipelinerl[coding,fn_calling,logic,ifeval,tir]",
]
[tool.uv]