diff --git a/conf/tir.yaml b/conf/tir.yaml new file mode 100644 index 00000000..606ab2dc --- /dev/null +++ b/conf/tir.yaml @@ -0,0 +1,96 @@ +defaults: + - base + - override rewards: success_and_format + - _self_ + +actor: + rollout_policy: pipelinerl.domains.tir.generate_tir_rollout + system_prompt: | + You are a math-focused AI Agent. Solve problems by combining clear symbolic reasoning + with short, deterministic Python code. + Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration. + Always present the final answer in LaTeX \boxed{}. + Do not express emotions or opinions about user questions. + + Workflow: + 1. Draft a brief plan in plain text. + 2. Execute one run_python_code call to compute or verify the result. + 3. Finalize by calling MathAnswer with the LaTeX-formatted answer. + + Python execution policy (run_python_code): + - Use Python strictly for pure computation to verify and validate the final answer. + - No network, file system, OS or environment access. + - Keep snippets minimal and self-contained; print only the final result. + + Validation: + - Cross-check results (alternative derivation, invariants, higher precision) before finalizing. + - If execution fails, propose the minimal fix and retry. + Always verify with run_python_code before invoking MathAnswer. + task_template: "{task}" + agent_max_loops: 3 + llm_max_rollouts: 128 + max_rollout_retries: 20 + rollout_workers: 8 + shared_memory_entry_size: 1000000000 + +rewards: + correct_answer_not_finished: 0.0 + buffer_tokens: 0 + +# Math verifier environment +environments: + - key: math + mode: remote + _target_: pipelinerl.domains.math.MathEnvironment +environment_key: math +dataset_loader: pipelinerl.domains.math.load_datasets + +train_dataset_names: + - open_reasoner_zero_57k + - open_reasoner_zero_extended_72k +test_dataset_names: + - aime_2025 + +# SandboxFusion config +sandbox_endpoint: ${oc.env:SANDBOX_ENDPOINT,http://127.0.0.1:8080} +sandbox_timeout: 10.0 + +# Optional reward shaping +python_tool_shaping: + bonus_on_correct_with_python: 0.1 + penalty_on_incorrect_without_python: 0.1 + max_abs: 0.2 + +# vLLM tool-call parser config +vllm_config: + vllm_kwargs: + enable-auto-tool-choice: "" + tool-call-parser: rl_tool + tool-parser-plugin: ${hydra:runtime.cwd}/pipelinerl/rl_tool_parser_plugin.py + max_model_len: 32000 + +llm: + parameters: + max_tokens: 16000 + temperature: 1.0 + +test_llm: + parameters: + max_tokens: 16000 + temperature: 1.0 + top_p: 0.95 + top_k: 50 + +finetune: + seq_length: 32000 + seq_parallel: 8 + gradient_accumulation_passes: 1024 + rl: + policy_loss: gspo + overlong_filtering: true + +preprocess: + input: actor + output: training_data + n_workers: 8 + shared_memory_entry_size: 1000000000 diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index d276747d..67c71a59 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -25,7 +25,7 @@ from pipelinerl.finetune_loop import calculate_train_steps from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb from pipelinerl.llm import TrainableLLM -from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.rollouts import BaseMetrics, RolloutResult, rollout_has_overflow from pipelinerl.shared_memory_array import SharedMemoryQueue from pipelinerl.state import TrainerState from pipelinerl.streams import ( @@ -399,7 +399,7 @@ def init_stats(self): def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} - metrics['overflow'] = all([not training_text.finished for training_text in result.training_texts ]) + metrics['overflow'] = rollout_has_overflow(result.training_texts) metrics['num_turns'] = len(result.training_texts) metrics['prompt_tokens'] = [training_text.prompt_tokens for training_text in result.training_texts] metrics['output_tokens'] = [training_text.output_tokens for training_text in result.training_texts] diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index b305c458..26ca3da1 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -3,12 +3,13 @@ import logging import aiohttp +import litellm import numpy as np from PIL import Image from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM from pipelinerl.finetune.data import MASKED_TOKEN_ID -from pipelinerl.rollouts import TrainingText +from pipelinerl.rollouts import TrainingText, apply_rollout_reward from pipelinerl.processor_factory import get_processor from omegaconf import DictConfig, ListConfig, OmegaConf @@ -83,7 +84,10 @@ def _is_retryable_abort_response(data: dict, collect_logprobs: bool) -> tuple[bo async def llm_async_generate( - llm: TrainableLLM, prompt: Prompt, session: aiohttp.ClientSession + llm: TrainableLLM, + prompt: Prompt, + session: aiohttp.ClientSession, + max_tokens_override: int | None = None, ) -> LLMCall: llm.load_tokenizer() headers = {"Content-Type": "application/json"} @@ -117,6 +121,12 @@ async def llm_async_generate( logger.debug(f"POST request to {llm.base_url}/v1/chat/completions") + if prompt.tools: + data["tools"] = _to_plain_obj(prompt.tools) + + if max_tokens_override is not None: + data["max_tokens"] = max_tokens_override + # Merge extra_parameters first so that data (model, messages, logprobs settings) takes precedence payload = _to_plain_obj({**extra_parameters, **data}) response_data = None @@ -161,7 +171,8 @@ async def llm_async_generate( try: content = response_data["choices"][0]["message"]["content"] - if not content: + raw_tool_calls = response_data["choices"][0]["message"].get("tool_calls", []) + if not content and not raw_tool_calls: logger.warning(f"Empty completion {response_data}") parsed_logprobs = [] @@ -188,7 +199,9 @@ async def llm_async_generate( logger.exception(f"Failed to parse llm response: {response_data}") raise - output = LLMOutput(content=content) + output = LLMOutput(content=content or "") + if raw_tool_calls: + output.tool_calls = [litellm.ChatCompletionMessageToolCall(**tc) for tc in raw_tool_calls] llm_call = llm.log_output(prompt, output, count_tokens=False) llm_call.prompt_length_tokens = response_data["usage"]["prompt_tokens"] llm_call.output_length_tokens = response_data["usage"]["completion_tokens"] @@ -210,9 +223,20 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: images = [] use_processor = False visual_features = None - full_messages = llm_call.prompt.messages + [ - {"role": "assistant", "content": llm_call.output.content} - ] + assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} + if llm_call.output.tool_calls: + assistant_msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in llm_call.output.tool_calls + ] + full_messages = llm_call.prompt.messages + [assistant_msg] if hasattr(llm_call.prompt, "messages"): images = extract_images_from_messages(llm_call.prompt.messages) @@ -265,6 +289,8 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: else: # Use tokenizer for text-only models chat_kwargs = _to_plain_obj(llm.chat_template_kwargs) if llm.chat_template_kwargs else {} + if llm_call.prompt.tools: + chat_kwargs = {**chat_kwargs, "tools": _to_plain_obj(llm_call.prompt.tools)} prompt_text = llm.tokenizer.apply_chat_template( conversation=llm_call.prompt.messages, tokenize=False, @@ -285,7 +311,6 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: output_text = text[len(prompt_text) :] - # Get the appropriate tokenizer (from processor if using vision model) tokenizer = processor.tokenizer if use_processor else llm.tokenizer if tokenizer.bos_token and text.startswith(tokenizer.bos_token): @@ -304,7 +329,7 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: finished = finish_reason != "length" else: eos_token = tokenizer.eos_token or "" - finished = bool(eos_token) and llm_call.output.content.endswith(eos_token) + finished = bool(eos_token) and (llm_call.output.content or "").endswith(eos_token) prompt_tokens = llm_call.prompt_length_tokens output_tokens = llm_call.output_length_tokens @@ -319,3 +344,14 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: output_tokens=output_tokens, visual_features=visual_features, ) + + +def make_training_texts_from_llm_calls( + llm: TrainableLLM, + llm_calls: list[LLMCall], + reward: float | None = None, +) -> list[TrainingText]: + training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] + if reward is not None: + training_texts = apply_rollout_reward(training_texts, reward) + return training_texts diff --git a/pipelinerl/domains/math/__init__.py b/pipelinerl/domains/math/__init__.py index 9aee0b8f..7a9809b7 100644 --- a/pipelinerl/domains/math/__init__.py +++ b/pipelinerl/domains/math/__init__.py @@ -1,3 +1,3 @@ from .load_datasets import load_datasets -from .rollouts import generate_math_rollout, RewardTable +from .rollouts import generate_math_rollout, RewardTable, get_reward, length_penalty from .verifier_api import MathEnvironment, verify_answer, verify_answer_rpc \ No newline at end of file diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 3e6560c0..8b712ff0 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -55,6 +55,28 @@ def log_config(self, domain: str = "unknown") -> None: f"buffer_tokens={self.buffer_tokens}" ) +def get_reward(answer_status: str, finished: bool, reward_table: RewardTable) -> float: + match (answer_status, finished): + case ("wrong", False): + return reward_table.wrong_answer_not_finished + case ("wrong", True): + return reward_table.wrong_answer_finished + case ("no_answer", False): + return reward_table.no_answer_not_finished + case ("no_answer", True): + return reward_table.no_answer_finished + case ("unparsable", False): + return reward_table.unparsable_not_finished + case ("unparsable", True): + return reward_table.unparsable_finished + case ("correct", False): + return reward_table.correct_answer_not_finished + case ("correct", True): + return reward_table.correct_answer_finished + case _: + raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{finished}") + + def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float: """ Compute the overlong penalty @@ -100,25 +122,7 @@ async def generate_math_rollout( trace = make_training_text(llm, llm_call) # Determine reward based on answer status and finished state - match (answer_status, trace.finished): - case ("wrong", False): - reward = rewards.wrong_answer_not_finished - case ("wrong", True): - reward = rewards.wrong_answer_finished - case ("no_answer", False): - reward = rewards.no_answer_not_finished - case ("no_answer", True): - reward = rewards.no_answer_finished - case ("unparsable", False): - reward = rewards.unparsable_not_finished - case ("unparsable", True): - reward = rewards.unparsable_finished - case ("correct", False): - reward = rewards.correct_answer_not_finished - case ("correct", True): - reward = rewards.correct_answer_finished - case _: - raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{trace.finished}") + reward = get_reward(answer_status, trace.finished, rewards) # Apply discount factor based on output length reward *= discount_factor**llm_call.output_length_tokens diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py index 8c489291..89e9cc5a 100644 --- a/pipelinerl/domains/miniwob/rollouts.py +++ b/pipelinerl/domains/miniwob/rollouts.py @@ -18,9 +18,9 @@ from tapeagents.remote_environment import AsyncRemoteEnvironment from tapeagents.tools.simple_browser import PageObservation -from pipelinerl.async_llm import make_training_text +from pipelinerl.async_llm import make_training_texts_from_llm_calls from pipelinerl.llm import LLMCall, TrainableLLM -from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.rollouts import BaseMetrics, RolloutResult, summarize_training_texts from pipelinerl.world import Job from .steps import WebTape @@ -271,13 +271,8 @@ async def _execute_rollout_with_timeout( ] # (4) # For each LLM interaction in the tape, make a training example. - all_finished = 1 - prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls] - output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls] - training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] - for text in training_texts: - text.reward = reward - all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0 + training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward) + training_summary = summarize_training_texts(training_texts) latency = time.time() - start_time agent_time = tape.metadata.result.get("agent_execution_time", -1.0) @@ -289,7 +284,7 @@ async def _execute_rollout_with_timeout( success=reward > 0.5, no_error=no_error, no_answer=reward < 0, - overflow=not all_finished, + overflow=training_summary.overflow, n_llm_calls=n_llm_calls, n_step_errors=n_step_errors, n_page_observations=n_page_observations, @@ -307,8 +302,6 @@ async def _execute_rollout_with_timeout( latency=latency, dataset_name=problem["dataset"], domain="miniwob", - prompt_tokens=prompt_tokens, - output_tokens=output_tokens, ) @@ -340,6 +333,4 @@ def _create_failed_rollout_result(problem: dict, start_time: float, error_type: latency=latency, dataset_name=problem["dataset"], domain="miniwob", - prompt_tokens=[], - output_tokens=[], ) diff --git a/pipelinerl/domains/tir/__init__.py b/pipelinerl/domains/tir/__init__.py new file mode 100644 index 00000000..4a658bd0 --- /dev/null +++ b/pipelinerl/domains/tir/__init__.py @@ -0,0 +1 @@ +from .rollouts import generate_tir_rollout diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py new file mode 100644 index 00000000..c25a30b5 --- /dev/null +++ b/pipelinerl/domains/tir/rollouts.py @@ -0,0 +1,385 @@ +import asyncio +import json +import logging +import random +import re +import time +from dataclasses import dataclass +from typing import Awaitable, Callable + +import aiohttp +from omegaconf import DictConfig + +from sandbox_fusion import RunCodeRequest, set_sandbox_endpoint, run_code_async + +from pipelinerl.async_llm import llm_async_generate, make_training_texts_from_llm_calls +from pipelinerl.domains.math import RewardTable, get_reward, length_penalty, verify_answer_rpc +from pipelinerl.llm import Prompt, TrainableLLM +from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.utils import get_environment_jobs, resolve_environment_key + +logger = logging.getLogger(__name__) + +_SANDBOX_CONFIGURED = False + +_BLOCKED_PATTERNS = [ + re.compile(r"\bsys\.exit\b"), + re.compile(r"\bos\._exit\b"), + re.compile(r"\bos\.system\b"), + re.compile(r"\bsubprocess\b"), + re.compile(r"\bos\.popen\b"), + re.compile(r"\bos\.exec\w*\b"), + re.compile(r"\bos\.spawn\w*\b"), + re.compile(r"\bos\.kill\b"), + re.compile(r"\bshutil\.rmtree\b"), + re.compile(r"\bos\.remove\b"), + re.compile(r"\bos\.unlink\b"), +] + + +def build_tool_definitions() -> list[dict]: + return [ + { + "type": "function", + "function": { + "name": "run_python_code", + "description": "Execute Python code. Print only the final result.", + "parameters": { + "type": "object", + "properties": {"code": {"type": "string", "description": "Python code to execute"}}, + "required": ["code"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "MathAnswer", + "description": "Submit the final answer in LaTeX \\boxed{} format.", + "parameters": { + "type": "object", + "properties": {"answer": {"type": "string", "description": "The final answer"}}, + "required": ["answer"], + }, + }, + }, + ] + + +def _check_code_safety(code: str) -> str | None: + for pattern in _BLOCKED_PATTERNS: + if pattern.search(code): + return f"Blocked: code contains forbidden pattern '{pattern.pattern}'" + return None + + +async def execute_python_sandbox(code: str, endpoint: str, timeout: float) -> str: + """Execute Python code via SandboxFusion and return formatted output.""" + global _SANDBOX_CONFIGURED + if not _SANDBOX_CONFIGURED: + set_sandbox_endpoint(endpoint) + _SANDBOX_CONFIGURED = True + logger.info("Configured SandboxFusion endpoint: %s", endpoint) + + rejection = _check_code_safety(code) + if rejection is not None: + return rejection + + try: + request = RunCodeRequest(code=code, language="python", run_timeout=timeout) + response = await run_code_async(request) + + stdout = "" + stderr = "" + if response.run_result: + stdout = response.run_result.stdout or "" + stderr = response.run_result.stderr or "" + + status = response.status.value if hasattr(response.status, "value") else str(response.status) + is_timeout = "timeout" in status.lower() or "timeout" in (response.message or "").lower() + + parts = [] + if stdout: + parts.append(stdout.rstrip()) + if stderr: + parts.append(f"[stderr]\n{stderr.rstrip()}") + if is_timeout: + parts.append("[execution timed out]") + if not parts: + parts.append("[no output]") + return "\n".join(parts) + + except asyncio.TimeoutError: + return "[execution timed out]" + except Exception as exc: + logger.warning("SandboxFusion error: %s", exc) + return f"[execution error: {exc}]" + + +def _serialize_tool_calls(tool_calls) -> list[dict]: + """Serialize litellm tool call objects to dicts for conversation history.""" + return [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ] + + +def _parse_tool_arguments(arguments: str, *, fallback_key: str | None = None) -> dict: + """Parse tool-call arguments into an object payload. + + Valid JSON that is not an object should not crash the rollout loop. A bare + string can still be recovered for simple single-field tool schemas. + """ + try: + parsed = json.loads(arguments) + except (json.JSONDecodeError, TypeError): + return {} + if isinstance(parsed, dict): + return parsed + if fallback_key is not None and isinstance(parsed, str): + return {fallback_key: parsed} + return {} + + +@dataclass +class _ToolContext: + sandbox_endpoint: str + sandbox_timeout: float + messages: list[dict] + final_answer: str | None = None + submitted_final_answer: bool = False + num_python_calls: int = 0 + + +ToolHandler = Callable[[object, _ToolContext], Awaitable[None]] + + +async def _handle_math_answer(tc, ctx: _ToolContext) -> None: + args = _parse_tool_arguments(tc.function.arguments, fallback_key="answer") + ctx.final_answer = args.get("answer", "") + ctx.submitted_final_answer = True + ctx.messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": f"Answer submitted: {ctx.final_answer}", + }) + + +async def _handle_run_python_code(tc, ctx: _ToolContext) -> None: + args = _parse_tool_arguments(tc.function.arguments, fallback_key="code") + code = args.get("code") or args.get("python_code", "") + result = await execute_python_sandbox(code, ctx.sandbox_endpoint, ctx.sandbox_timeout) + ctx.num_python_calls += 1 + ctx.messages.append({"role": "tool", "tool_call_id": tc.id, "content": result}) + + +async def _handle_unknown_tool(tc, ctx: _ToolContext) -> None: + ctx.messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": f"Unknown tool: {tc.function.name}", + }) + + +_TOOL_HANDLERS: dict[str, ToolHandler] = { + "MathAnswer": _handle_math_answer, + "run_python_code": _handle_run_python_code, +} + + +class RewardShaper: + def __init__(self, cfg: DictConfig, llm: TrainableLLM): + self._python_cfg = getattr(cfg, "python_tool_shaping", None) + self._length_cfg = getattr(cfg, "length_shaping", None) + self._max_gen_tokens = int(llm.parameters.get("max_tokens", 2048)) + + def compute(self, answer_status: str, num_python_calls: int, llm_calls: list) -> float: + return ( + self._python_tool_bonus(answer_status, num_python_calls) + + self._length_adjustment(answer_status, llm_calls) + ) + + def _python_tool_bonus(self, answer_status: str, num_python_calls: int) -> float: + cfg = self._python_cfg + if cfg is None: + return 0.0 + bonus = float(getattr(cfg, "bonus_on_correct_with_python", 0.0)) + penalty = float(getattr(cfg, "penalty_on_incorrect_without_python", 0.0)) + max_abs = float(getattr(cfg, "max_abs", 0.2)) + total = 0.0 + if answer_status == "correct" and num_python_calls >= 1: + total += bonus + if answer_status in ("wrong", "unparsable") and num_python_calls == 0: + total -= penalty + return max(-max_abs, min(max_abs, total)) + + def _length_adjustment(self, answer_status: str, llm_calls: list) -> float: + cfg = self._length_cfg + if cfg is None or not llm_calls: + return 0.0 + if hasattr(cfg, "target_ratio"): + ratio = float(getattr(cfg, "target_ratio")) + target = int(max(1, ratio * self._max_gen_tokens)) + target = max(int(getattr(cfg, "min_target_tokens", 0)), target) + target = min(int(getattr(cfg, "max_target_tokens", 10**9)), target) + else: + target = int(getattr(cfg, "target_output_tokens", 512)) + slope = float(getattr(cfg, "slope", 0.0)) + max_penalty = float(getattr(cfg, "max_penalty", 0.0)) + bonus_short_correct = float(getattr(cfg, "bonus_on_short_correct", 0.0)) + + avg_out = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) / len(llm_calls) + total = 0.0 + if slope > 0.0 and max_penalty > 0.0 and avg_out > target: + total -= min(max_penalty, slope * (avg_out - target)) + if bonus_short_correct > 0.0 and answer_status == "correct" and avg_out <= target: + total += bonus_short_correct + return total + + +class Metrics(BaseMetrics): + num_python_calls: int = 0 + num_steps: int = 0 + overflow: bool = False + + +async def generate_tir_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, +) -> RolloutResult: + start = time.perf_counter() + + messages: list[dict] = [] + if cfg.actor.system_prompt: + messages.append({"role": "system", "content": cfg.actor.system_prompt}) + messages.append({"role": "user", "content": cfg.actor.task_template.format(task=problem["task"])}) + + tools = build_tool_definitions() + + llm_calls = [] + ctx = _ToolContext( + sandbox_endpoint=str(cfg.sandbox_endpoint), + sandbox_timeout=float(cfg.sandbox_timeout), + messages=messages, + ) + agent_max_loops = int(getattr(cfg.actor, "agent_max_loops", 3)) + configured_max_tokens = int(llm.parameters.get("max_tokens", 16000)) + max_model_len = int(cfg.vllm_config.vllm_kwargs.get("max_model_len", 32000)) + min_generation_tokens = 256 + + for _turn in range(agent_max_loops): + prompt = Prompt(messages=list(messages), tools=tools) + + llm.load_tokenizer() + prompt_token_ids = llm.tokenizer.apply_chat_template( + messages, + add_special_tokens=True, + add_generation_prompt=True, + tools=tools, + ) + prompt_len = len(prompt_token_ids) + remaining = max_model_len - prompt_len + if remaining < min_generation_tokens: + logger.warning( + "Prompt length %d leaves only %d tokens for generation (max_model_len=%d), stopping loop", + prompt_len, remaining, max_model_len, + ) + break + max_tokens_this_turn = min(configured_max_tokens, remaining) + if max_tokens_this_turn < configured_max_tokens: + logger.warning( + "Turn %d: capping max_tokens from %d to %d (prompt_len=%d, max_model_len=%d)", + _turn, configured_max_tokens, max_tokens_this_turn, prompt_len, max_model_len, + ) + + llm_call = await llm_async_generate(llm, prompt, session, max_tokens_override=max_tokens_this_turn) + llm_calls.append(llm_call) + + if not llm_call.output.tool_calls: + break + + assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} + assistant_msg["tool_calls"] = _serialize_tool_calls(llm_call.output.tool_calls) + messages.append(assistant_msg) + + for tc in llm_call.output.tool_calls: + handler = _TOOL_HANDLERS.get(tc.function.name, _handle_unknown_tool) + await handler(tc, ctx) + if ctx.submitted_final_answer: + break + + if ctx.submitted_final_answer: + break + + if ctx.final_answer is not None: + prediction = ctx.final_answer + elif llm_calls: + prediction = llm_calls[-1].output.content or "" + else: + prediction = "" + + env_key = resolve_environment_key(cfg, default="math") + env_jobs = get_environment_jobs(cfg, env_key) + if not env_jobs: + raise RuntimeError("No environment servers available for math domain") + env_job = random.choice(env_jobs) + assert env_job.port is not None + answer_status = await verify_answer_rpc( + session=session, + host=env_job.hostname, + port=env_job.port, + prediction=prediction, + gold=problem["answer"], + strict=True, + ) + + reward_table = RewardTable(**dict(cfg.rewards)) + base_reward = get_reward(answer_status, ctx.submitted_final_answer, reward_table) + + discount_factor = float(getattr(cfg.actor, "discount_factor", 1.0)) + if discount_factor != 1.0: + total_generated_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) + base_reward *= discount_factor ** total_generated_tokens + + buffer_tokens = getattr(reward_table, "buffer_tokens", 0) + if buffer_tokens: + max_tokens = int(llm.parameters.get("max_tokens", 0)) + total_output_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) + if max_tokens > 0: + base_reward += length_penalty(max_tokens, total_output_tokens, buffer_tokens) + + shaping = RewardShaper(cfg, llm).compute(answer_status, ctx.num_python_calls, llm_calls) + reward = base_reward + shaping + + training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward) + for text in training_texts: + text.finished = ctx.submitted_final_answer + + latency = time.perf_counter() - start + + metrics = Metrics( + reward=reward, + success=answer_status == "correct", + no_error=answer_status != "unparsable", + no_answer=answer_status == "no_answer", + num_python_calls=ctx.num_python_calls, + num_steps=len(llm_calls), + overflow=not ctx.submitted_final_answer, + ) + + return RolloutResult( + training_texts=training_texts, + metrics=metrics, + latency=latency, + dataset_name=problem.get("dataset"), + domain="tir", + ) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 2ba28c5e..dea58599 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -1,7 +1,7 @@ import logging import os from functools import partial -from typing import Any +from typing import Any, TYPE_CHECKING from pydantic import BaseModel, Field import numpy as np @@ -9,10 +9,14 @@ import torch import torch.nn.functional as F from datasets import Dataset -from transformers import PreTrainedModel from pipelinerl.finetune.types import PipelineBatchEncoding from pipelinerl.finetune.rl.utils import per_segment_sums +if TYPE_CHECKING: + from transformers import PreTrainedModel +else: + PreTrainedModel = Any + from .utils import ( sum_sum, mean_sum, @@ -452,58 +456,72 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R df_init = pd.DataFrame(dataset) assert isinstance(df_init, pd.DataFrame) - # Step 1: calculate group-level statistics - df_stats = df_init[["group_id", "rollout_index", "step_index"]].copy() + # Step 1: calculate rollout-level token statistics and step-level reward statistics + df_stats = df_init[["group_id", "rollout_index", "step_index", "rewards"]].copy() df_stats["num_tokens"] = df_init["input_ids"].apply(len) - # We assume that rewards for all tokens are the same - df_stats["rollout_reward"] = df_init["rewards"].apply(lambda x: x[0]) - # Check that the reward is the same for each step in the rollout - assert df_stats.groupby(["group_id", "rollout_index"])["rollout_reward"].nunique().max() == 1 - # Only keep step_index == 0 - df_stats = df_stats[df_stats["step_index"] == 0].drop(columns=["step_index"]) + df_stats["step_reward"] = df_stats["rewards"].apply(lambda rewards: rewards[0]) + df_rollouts = ( + df_stats.groupby(["group_id", "rollout_index"]) + .agg( + rollout_tokens=("num_tokens", "sum"), + ) + .reset_index() + ) + df_group_tokens = ( + df_rollouts.groupby("group_id") + .agg( + group_tokens=("rollout_tokens", "mean"), + ) + .reset_index() + ) df_grouped = ( - df_stats.groupby("group_id") + df_stats.groupby(["group_id", "step_index"]) .agg( - rollout_reward_sum=("rollout_reward", "sum"), - rollout_reward_count=("rollout_reward", "count"), - rollout_reward_std=("rollout_reward", "std"), - group_tokens=("num_tokens", "mean"), + step_reward_sum=("step_reward", "sum"), + step_reward_count=("step_reward", "count"), + step_reward_std=("step_reward", "std"), ) .reset_index() ) - assert df_grouped.columns.tolist() == [ + assert df_group_tokens.columns.tolist() == [ "group_id", - "rollout_reward_sum", - "rollout_reward_count", - "rollout_reward_std", "group_tokens", ] + assert df_grouped.columns.tolist() == [ + "group_id", + "step_index", + "step_reward_sum", + "step_reward_count", + "step_reward_std", + ] # Step 2: calculate advantages for each sample df_advantages = pd.merge( - df_init[["group_id", "rollout_index", "step_index", "rewards"]], + df_stats[["group_id", "rollout_index", "step_index", "rewards", "step_reward"]], df_grouped, - on="group_id", + on=["group_id", "step_index"], how="left" ) + df_advantages = pd.merge(df_advantages, df_group_tokens, on="group_id", how="left") assert len(df_advantages) == len(df_init) + def calculate_advantages(row): rewards = row["rewards"] - group_sum = row["rollout_reward_sum"] - group_count = row["rollout_reward_count"] - current_reward = rewards[0] + group_sum = row["step_reward_sum"] + group_count = row["step_reward_count"] + current_reward = row["step_reward"] if group_count > 1: loo_mean = (group_sum - current_reward) / (group_count - 1) else: loo_mean = current_reward - std = row["rollout_reward_std"] + std = row["step_reward_std"] if config.divide_advantage_by_std: return [(r - loo_mean) / (np.nan_to_num(std) + 1e-4) for r in rewards] return [(r - loo_mean) for r in rewards] df_advantages["advantages"] = df_advantages.apply(calculate_advantages, axis=1) df_advantages = df_advantages.drop( - columns=["rewards", "rollout_reward_sum", "rollout_reward_count", "rollout_reward_std"] + columns=["rewards", "step_reward", "step_reward_sum", "step_reward_count", "step_reward_std"] ) assert df_advantages.columns.tolist() == [ "group_id", diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py new file mode 100644 index 00000000..e48ec22c --- /dev/null +++ b/pipelinerl/rl_tool_parser_plugin.py @@ -0,0 +1,202 @@ +""" +Tool parser plugin for RL tool calling format. +""" + +import json +import logging +import re + +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ExtractedToolCallInformation, + ToolCall, + FunctionCall, +) + +_JSON_SCALAR_TYPES = (dict, list, str, int, float, bool) + + +def _build_tool_call(index: int, parsed: dict, *, force_id: str | None = None) -> ToolCall | None: + try: + args_obj = parsed.get("arguments", {}) + if not isinstance(args_obj, _JSON_SCALAR_TYPES): + args_obj = {} + call_id = force_id if force_id is not None else parsed.get("id", f"call_{index}") + return ToolCall( + id=call_id, + type="function", + function=FunctionCall( + name=str(parsed.get("name", "")), + arguments=json.dumps(args_obj, ensure_ascii=False), + ), + ) + except Exception: + logging.getLogger("pipelinerl.tool_parser").debug( + "Skipping malformed tool call", exc_info=True + ) + return None + + +@ToolParserManager.register_module("rl_tool") +class HermesRLToolParser(ToolParser): + """ + Tool parser for RL tool calling format using markers. + Supports both standard format and Apriel-style formats: + - [{...}, {...}] (preferred if present) + - [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] wrapper + """ + + def __init__(self, tokenizer): + super().__init__(tokenizer) + + self.tool_call_start_token = "" + self.tool_call_end_token = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL + ) + + self.apriel_final_response_regex = re.compile( + r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL + ) + # Lenient match: case-insensitive and tolerate a missing closing tag. + self.apriel_tool_calls_regex = re.compile( + r"\s*(.*?)\s*(?:|$)", re.DOTALL | re.IGNORECASE + ) + + # vLLM streaming hooks expect these attributes on the parser instance. + self.current_tool_name_sent = False + self.prev_tool_call_arr = [] + self.current_tool_id = -1 + self.streamed_args_for_tool = [] + + def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: + logger = logging.getLogger("pipelinerl.tool_parser") + final_response_match = None + + try: + tool_calls_matches = list(self.apriel_tool_calls_regex.finditer(model_output)) + if tool_calls_matches: + last_match = tool_calls_matches[-1] + tool_calls_json = last_match.group(1).strip() + parsed_calls = [] + try: + parsed_calls = json.loads(tool_calls_json) if tool_calls_json else [] + except Exception: + logger.debug("Failed to parse aggregated JSON; falling back", exc_info=True) + parsed_calls = [] + + tool_calls = [ + tc for tc in (_build_tool_call(i, pc) for i, pc in enumerate(parsed_calls)) + if tc is not None + ] + + final_response_match = self.apriel_final_response_regex.search(model_output) + content = final_response_match.group(1).strip() if final_response_match else "" + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content, + ) + + try: + tools_declared = bool(getattr(request, "tools", None)) + except Exception: + tools_declared = False + + if tools_declared: + candidate_strings: list[str] = [] + final_response_match = self.apriel_final_response_regex.search(model_output) + if final_response_match: + candidate_strings.append(final_response_match.group(1).strip()) + candidate_strings.append(model_output.strip()) + + for candidate in candidate_strings: + try: + parsed = json.loads(candidate) + except Exception: + continue + parsed_list = [] + if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed: + parsed_list = [parsed] + elif isinstance(parsed, list) and all(isinstance(it, dict) for it in parsed): + parsed_list = [it for it in parsed if "name" in it and "arguments" in it] + if not parsed_list: + continue + tool_calls = [ + tc for tc in (_build_tool_call(i, pc) for i, pc in enumerate(parsed_list)) + if tc is not None + ] + content = final_response_match.group(1).strip() if final_response_match else "" + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content, + ) + + # Fallback: legacy blocks. + content_to_search = model_output + final_response_match = self.apriel_final_response_regex.search(model_output) + if final_response_match: + final_response_content = final_response_match.group(1).strip() + if self.tool_call_start_token in final_response_content: + content_to_search = final_response_content + elif self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_content + ) + + if self.tool_call_start_token not in content_to_search: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + + function_call_tuples = self.tool_call_regex.findall(content_to_search) + + tool_calls = [] + for i, match in enumerate(function_call_tuples): + json_str = match[0] if match[0] else match[1] + try: + parsed_call = json.loads(json_str.strip()) + except Exception: + logger.debug("Skipping malformed JSON", exc_info=True) + continue + tc = _build_tool_call(i, parsed_call, force_id=f"call_{i}") + if tc is not None: + tool_calls.append(tc) + + if tool_calls and final_response_match: + content = "" + elif final_response_match: + content = final_response_match.group(1).strip() + else: + content = model_output + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content + ) + + except Exception: + # Never propagate to the vLLM server. + logger.exception("Tool parser encountered an exception; returning safe fallback.") + if final_response_match: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_match.group(1).strip() + ) + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + \ No newline at end of file diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py index 1200ba23..4c71dda7 100644 --- a/pipelinerl/rollouts.py +++ b/pipelinerl/rollouts.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from pydantic import BaseModel, Field -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Sequence import numpy as np class BaseMetrics(BaseModel): @@ -65,3 +66,32 @@ class RolloutResult(BaseModel): dataset_name: str | None = None group_id: str | None = None domain: str | None = None + + +@dataclass(frozen=True) +class TrainingTextSummary: + prompt_tokens: list[int] + output_tokens: list[int] + overflow: bool + num_turns: int + + +def apply_rollout_reward(training_texts: Sequence[TrainingText], reward: float) -> list[TrainingText]: + texts = list(training_texts) + for training_text in texts: + training_text.reward = reward + return texts + + +def rollout_has_overflow(training_texts: Sequence[TrainingText]) -> bool: + return any(not training_text.finished for training_text in training_texts) + + +def summarize_training_texts(training_texts: Sequence[TrainingText]) -> TrainingTextSummary: + texts = list(training_texts) + return TrainingTextSummary( + prompt_tokens=[training_text.prompt_tokens for training_text in texts], + output_tokens=[training_text.output_tokens for training_text in texts], + overflow=rollout_has_overflow(texts), + num_turns=len(texts), + ) diff --git a/pyproject.toml b/pyproject.toml index 81ecbd42..13060791 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,9 +73,12 @@ ifeval = [ "langdetect", "absl-py", ] +tir = [ + "sandbox-fusion>=0.3.7", +] # Install all domain dependencies domains = [ - "pipelinerl[coding,fn_calling,logic,ifeval]", + "pipelinerl[coding,fn_calling,logic,ifeval,tir]", ] [tool.uv]