Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ package-lock.json
package.json
results
results/
results-batch-jobs
results-batch-jobs/
data/
cache/
dump.rdb
Expand Down
6 changes: 6 additions & 0 deletions conf/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ finetune:
seed: ${..seed}

actor:
launcher: asyncio
log_each_n_secs: 0
llm_max_rollouts: 64
rollout_workers: 1
Expand All @@ -22,6 +23,10 @@ actor:
result_queue_size: 64
throughput_window_size: 50
shared_memory_entry_size: 10000000
difficulty_aware_penalty:
enabled: false
gamma: 0.5
failure_scale: 0.5
environment: null
preprocess:
input: actor
Expand Down Expand Up @@ -113,6 +118,7 @@ pop_old_data: true
max_lag: null
attempts: ${finetune.attempts}
train_subset: null
test_subset: null
debug:
mode: ""
streams_from: null
Expand Down
106 changes: 106 additions & 0 deletions conf/cube_math_tool.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
defaults:
- base
- override rewards: success_and_format
- _self_

output_dir: results/cube_math_tool/${now:%Y-%m-%d}/${now:%H-%M-%S}
model_path: /mnt/llmd/base_models/Qwen2.5-7B-Instruct
litellm_logging_level: info
ray_debug: 0
ray_local_mode: false

actor:
launcher: ray
ray_num_cpus: null
cube_actor_num_cpus: 1.0
llm_max_rollouts: 128
ray_worker_log_enabled: true
ray_worker_log_path: null
ray_worker_log_level: WARNING
ray_worker_litellm_log_level: CRITICAL

llm:
parameters:
max_tokens: 16000
max_completion_tokens: 16000
temperature: 1.0

test_llm:
parameters:
max_tokens: 16000
max_completion_tokens: 16000
temperature: 1.0
top_p: 0.95

vllm_config:
vllm_kwargs:
max_model_len: 32000
served_model_name: Qwen2.5-7B-Instruct
enable-auto-tool-choice: ""
tool-call-parser: rl_tool
tool-parser-plugin: ${hydra:runtime.cwd}/pipelinerl/rl_tool_parser_plugin.py

finetune:
seq_length: 32000
seq_parallel: 8
gradient_accumulation_passes: 1024
rl:
policy_loss: gspo
overlong_filtering: true

preprocess:
input: actor
output: training_data
n_workers: 8
shared_memory_entry_size: 1000000000

cube_params:
name: tir
resource_guard:
actor_memory_gb: 1.25
memory_overhead_gb: 8.0
memory_usage_threshold: 0.90
benchmark:
_target_: math_tool_use.benchmark.MathToolUseBenchmark
default_tool_config:
_target_: math_tool_use.tool.MathToolUseToolConfig
sandbox_endpoint: http://dns-24e3447c-506e-4b21-92df-156e18db5087-sandboxfusion
agent:
_target_: cube_harness.agents.tir.TirAgentConfig
llm_config:
_target_: cube_harness.llm.RoutedLLMConfig
model_name: ${vllm_config.vllm_kwargs.served_model_name}
tokenizer_name: ${model_path}
timeout: 3600.0
num_retries: 1
extra_body:
return_token_ids: true
system_prompt: |
You are a math-focused AI Agent. Solve problems by combining clear symbolic reasoning
with short, deterministic Python code.
Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration.
Always present the final answer in LaTeX \boxed{}.
Do not express emotions or opinions about user questions.

Workflow:
1. Draft a brief plan in plain text.
2. Execute one run_python_code call to compute or verify the result.
3. Finalize by calling MathAnswer with the LaTeX-formatted answer.

Python execution policy (run_python_code):
- Use Python strictly for pure computation to verify and validate the final answer.
- No network, file system, OS or environment access.
- Keep snippets minimal and self-contained; print only the final result.

Validation:
- Cross-check results (alternative derivation, invariants, higher precision) before finalizing.
- If execution fails, propose the minimal fix and retry.
Always verify with run_python_code before invoking MathAnswer.
max_actions: 3
seed: ${seed}

train_dataset_names:
- open_reasoner_zero_57k
- open_reasoner_zero_extended_72k
test_dataset_names:
- aime_2025
56 changes: 56 additions & 0 deletions example_run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env bash
set -euo pipefail

seq_len=32000
max_tokens=16000
temperature=0.7
seq_parallel=4
output_dir_base=results
job_name=cube-ft
filter_zero_advantage_groups=true
max_train_steps=50
train_subset_end=-1
test_subset_end=-1
gradient_accumulation_passes=1024
attempts=8
wandb=true

uv run python -m pipelinerl.launch --config-name cube_math_tool.yaml \
output_dir=$output_dir_base/$job_name \
actor.llm_max_rollouts=32 \
actor.cube_actor_num_cpus=0.5 \
force_restart=true \
fp32_lm_head=true \
finetune.learning_rate=1e-6 \
finetune.attempts=$attempts \
finetune.rl.policy_loss=gspo \
finetune.rl.epsilon_low=3e-3 \
finetune.rl.epsilon_high=4e-3 \
+finetune.rl.filter_zero_advantage_groups=$filter_zero_advantage_groups \
finetune.max_train_steps=$max_train_steps \
finetune.seq_length=$seq_len \
finetune.seq_parallel=$seq_parallel \
finetune.gradient_accumulation_passes=$gradient_accumulation_passes \
vllm_config.vllm_kwargs.max_model_len=$seq_len \
llm.parameters.max_tokens=$max_tokens \
llm.parameters.temperature=$temperature \
llm.parameters.max_completion_tokens=$max_tokens \
+llm.parameters.max_model_len=$seq_len \
test_llm.parameters.max_tokens=$max_tokens \
test_llm.parameters.temperature=$temperature \
test_llm.parameters.max_completion_tokens=$max_tokens \
+test_llm.parameters.max_model_len=$seq_len \
world.actor_fraction=4 \
world.preprocessor_fraction=0 \
world.finetune_fraction=4 \
streams=files \
eval_every_n_versions=1 \
model_path=Qwen/Qwen3-4B-Instruct-2507 \
vllm_config.vllm_kwargs.served_model_name=Qwen3-4B-Instruct-2507 \
wandb.wandb_workspace_root=$output_dir_base \
wandb.wandb_project_name=watermelon \
wandb.use_wandb=$wandb \
+train_subset.begin=0 \
+train_subset.end=$train_subset_end \
+test_subset.begin=0 \
+test_subset.end=$test_subset_end \
78 changes: 16 additions & 62 deletions pipelinerl/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
import hydra
import uvloop
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field

import wandb
from pipelinerl.async_llm import RetryableLLMResponseError
from pipelinerl.domain_sampling import DomainWeightedSampler
from pipelinerl.domains.math.rollouts import length_penalty
from pipelinerl.finetune_loop import calculate_train_steps
from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
from pipelinerl.llm import TrainableLLM
from pipelinerl.metrics import SlidingWindowAggregator
from pipelinerl.rollouts import BaseMetrics, RolloutResult, rollout_has_overflow
from pipelinerl.shared_memory_array import SharedMemoryQueue
from pipelinerl.state import TrainerState
Expand All @@ -46,65 +47,6 @@
logger = logging.getLogger(__name__)


class SlidingWindowData(BaseModel):
prompt_tokens_window: list[list[int]] = Field(
default_factory=list,
description="Prompt token counts for each chunk in the window",
)
output_tokens_window: list[list[int]] = Field(
default_factory=list,
description="Output token counts for each chunk in the window",
)
timestamps: list[float] = Field(default_factory=list)


class SlidingWindowAggregator:
def __init__(self, window_size: int):
self.window_size = window_size
self.data = SlidingWindowData()

def update(self, prompt_tokens: list[int], output_tokens: list[int]):
self.data.prompt_tokens_window.append(prompt_tokens)
self.data.output_tokens_window.append(output_tokens)
self.data.timestamps.append(time.time())
if len(self.data.prompt_tokens_window) > self.window_size:
self.data.prompt_tokens_window.pop(0)
self.data.output_tokens_window.pop(0)
self.data.timestamps.pop(0)

def get_stats(self):
if len(self.data.prompt_tokens_window) < self.window_size:
return None

# 1. How many samples do we produce per second?
# 2. How many output tokens do we produce per second?
# 3. How many prompt tokens do we produce per second?
# 4. How many total tokens do we produce per second?
null_stats = {
"samples_per_second": 0,
"output_tokens_per_second": 0,
"prompt_tokens_per_second": 0,
"total_tokens_per_second": 0,
}
if not self.data.timestamps:
return null_stats

time_span = self.data.timestamps[-1] - self.data.timestamps[0]
if time_span < 1e-6:
return null_stats

num_samples = sum(len(tokens) for tokens in self.data.prompt_tokens_window)
total_output_tokens = sum(sum(tokens) for tokens in self.data.output_tokens_window)
total_prompt_tokens = sum(sum(tokens) for tokens in self.data.prompt_tokens_window)

return {
"samples_per_second": num_samples / time_span,
"output_tokens_per_second": total_output_tokens / time_span,
"prompt_tokens_per_second": total_prompt_tokens / time_span,
"total_tokens_per_second": (total_output_tokens + total_prompt_tokens) / time_span,
}



def make_stats_dict() -> dict:
return defaultdict(lambda: defaultdict(list))
Expand Down Expand Up @@ -142,7 +84,13 @@ async def schedule_rollouts(

final_steps = calculate_train_steps(cfg.finetune, cfg.finetune.interrupt_train_steps)
samples_target = final_steps * cfg.finetune.train_batch_size * cfg.finetune.gradient_accumulation_passes
retryable_rollout_exceptions = (aiohttp.ServerTimeoutError, asyncio.TimeoutError, TimeoutError)
retryable_rollout_exceptions = (
aiohttp.ClientConnectionError,
aiohttp.ServerTimeoutError,
RetryableLLMResponseError,
asyncio.TimeoutError,
TimeoutError,
)
max_rollout_retries = int(getattr(cfg.actor, "max_rollout_retries", -1)) # -1 means infinite retries
retry_initial_delay_s = float(getattr(cfg.actor, "rollout_retry_initial_delay_s", 1.0))
retry_max_delay_s = float(getattr(cfg.actor, "rollout_retry_max_delay_s", 30.0))
Expand Down Expand Up @@ -802,6 +750,8 @@ def run_actor_loop(cfg: DictConfig):
test_dataset = dataset_loader(cfg.test_dataset_names, **dataset_loader_params)
if cfg.train_subset:
train_dataset = train_dataset[cfg.train_subset.begin : cfg.train_subset.end]
if cfg.test_subset:
test_dataset = test_dataset[cfg.test_subset.begin : cfg.test_subset.end]
logger.info(f"Loaded {len(train_dataset)} training problems")
logger.info(f"Loaded {len(test_dataset)} test problems")

Expand All @@ -810,14 +760,17 @@ def run_actor_loop(cfg: DictConfig):
actor_model_path = finetune_model_path
else:
actor_model_path = cfg.model_path


served_model_name = cfg.vllm_config.vllm_kwargs.get("served_model_name") if cfg.vllm_config.vllm_kwargs else None

train_llms = [
TrainableLLM(
base_url=url,
model_name=str(actor_model_path),
tokenizer_name=str(actor_model_path),
parameters=cfg.llm.parameters,
collect_logprobs=True,
served_model_name=served_model_name,
)
for url in llm_urls
]
Expand All @@ -828,6 +781,7 @@ def run_actor_loop(cfg: DictConfig):
tokenizer_name=str(actor_model_path),
parameters=cfg.test_llm.parameters,
collect_logprobs=True,
served_model_name=served_model_name,
)
for url in llm_urls
]
Expand Down
Loading