From fe8910ea3dedeceb69fa794ab526e86a00865f9b Mon Sep 17 00:00:00 2001 From: ehsk Date: Fri, 6 Feb 2026 19:05:00 +0000 Subject: [PATCH 1/3] self-evolving curriculum added --- conf/sec.yaml | 110 +++++++++ pipelinerl/actor.py | 78 ++++++- pipelinerl/curriculum/__init__.py | 18 ++ pipelinerl/curriculum/bandit.py | 367 ++++++++++++++++++++++++++++++ pipelinerl/curriculum/feedback.py | 250 ++++++++++++++++++++ pipelinerl/curriculum/iterator.py | 159 +++++++++++++ pipelinerl/curriculum/state.py | 218 ++++++++++++++++++ pipelinerl/preprocess.py | 27 +++ 8 files changed, 1224 insertions(+), 3 deletions(-) create mode 100644 conf/sec.yaml create mode 100644 pipelinerl/curriculum/__init__.py create mode 100644 pipelinerl/curriculum/bandit.py create mode 100644 pipelinerl/curriculum/feedback.py create mode 100644 pipelinerl/curriculum/iterator.py create mode 100644 pipelinerl/curriculum/state.py diff --git a/conf/sec.yaml b/conf/sec.yaml new file mode 100644 index 00000000..08fe20d4 --- /dev/null +++ b/conf/sec.yaml @@ -0,0 +1,110 @@ +defaults: + - base + - /domain_rollouts@domain_rollouts: base + - _self_ + +actor: + shared_memory_entry_size: 2000000000 + rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout + # No system prompt - model's chat template provides guidance + system_prompt: "" + # Minimal task template - each problem contains its own instructions + task_template: |- + {task} + task_prompt: "" + + domain_rollouts: + math: ${domain_rollouts.math} + coding: ${domain_rollouts.coding} + + domain_mix: + math: 0.5 + coding: 0.5 + +# SandboxFusion verification settings +sandbox_endpoint: ${oc.env:SANDBOX_ENDPOINT,http://127.0.0.1:8080} +sandbox_timeout: 10.0 +max_tests_per_problem: 5 + +preprocess: + shared_memory_entry_size: 2000000000 + +finetune: + seq_length: 32000 + +vllm_config: + vllm_kwargs: + max_model_len: 32000 + +llm: + parameters: + max_tokens: 16000 + +test_llm: + parameters: + max_tokens: 16000 + +# Bandit-based curriculum learning 
configuration +curriculum: + # Enable curriculum learning + enabled: true + # How difficulty is determined: "field" (from difficulty_field) or "estimated" (from success rates) + difficulty_source: "field" + # Field name for difficulty/level (used when difficulty_source="field") + difficulty_field: null + # Additional field(s) for categorization beyond difficulty (optional) + # Can be a single string: "dataset" or a list: ["dataset", "type"] + category_fields: ["domain"] + # Softmax temperature (higher = more exploration, lower = more exploitation) + temperature: 0.4 + # Q-value update learning rate + learning_rate: 0.5 + # Initial Q-value for new categories + initial_q_value: 0.0 + # Signal for Q-update: "advantage", "reward", or "success" + update_signal: "advantage" + # Number of difficulty buckets (only used when difficulty_source="estimated") + # Problems are grouped by success rate into buckets: + # e.g., 5 buckets: [0-0.2), [0.2-0.4), [0.4-0.6), [0.6-0.8), [0.8-1.0] + num_difficulty_buckets: 5 + # How often to reassign problems to buckets (only used when difficulty_source="estimated") + # Counted in preprocessor batches - each batch updates success rates and may trigger reindex + # Lower = more responsive to changing success rates, higher = more stable bucket assignments + reindex_interval: 5 + +dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets +dataset_loader_params: + per_domain_params: + coding: + # TACO + APPS + taco_split: train + apps_split: train + subset: train + train_ratio: 0.9 + max_tests_per_problem: 5 + taco_excluded_difficulties: [VERY_HARD] + skip_apps: false + max_examples: null + seed: 42 + huggingface_token: ${oc.env:HF_TOKEN, null} + +environments: + - key: math + mode: remote + replicas_per_actor: ${world.env_replicas_per_actor} + _target_: pipelinerl.domains.math.MathEnvironment + +environment_key: null + +world: + env_replicas_per_actor: 1 + +train_dataset_names: + - coding::taco + - coding::apps + - 
math::open_reasoner_zero_57k + - math::open_reasoner_zero_extended_72k + +test_dataset_names: + - math::aime_2025 + - coding::livecodebench_v5 \ No newline at end of file diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 2d7d0167..d62bd0a9 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -10,7 +10,7 @@ from multiprocessing.managers import SharedMemoryManager from pathlib import Path from queue import Empty -from typing import Dict, List +from typing import Dict, List, Optional import aiohttp import hydra @@ -41,6 +41,7 @@ wait_for_environments, wait_for_inference_servers, ) +from .curriculum import BanditConfig, CurriculumState logger = logging.getLogger(__name__) @@ -184,6 +185,17 @@ async def rollout_and_maybe_produce_result( sample.metadata["model_version"] = model_version sample.metadata["rollout_index"] = rollout_index sample.metadata["step_index"] = step_index + # Propagate curriculum metadata if present + if "_selected_category" in problem: + sample.metadata["_selected_category"] = problem["_selected_category"] + if "id" in problem: + sample.metadata["id"] = problem["id"] + # Propagate all configured category fields for stats tracking + if cfg.get("curriculum") and cfg.curriculum.get("enabled"): + curriculum_config = BanditConfig(**cfg.curriculum) + for field in curriculum_config.get_all_category_fields(): + if field in problem: + sample.metadata[f"_curriculum_{field}"] = problem[field] sample.group_id = full_group_id group_rollouts[group_id].append(rollout_result) if len(group_rollouts[group_id]) == attempts: @@ -308,6 +320,25 @@ def __init__( self.is_scheduling_paused = False self.debug_mode = bool(cfg.debug.mode) + # Initialize curriculum learning components if enabled (only for training) + self.curriculum_state: Optional[CurriculumState] = None + self._curriculum_config: Optional[BanditConfig] = None # Keep for validation + if is_training and cfg.get("curriculum") and cfg.curriculum.get("enabled", False): + exp_path = 
Path(cfg.output_dir) + self._curriculum_config = BanditConfig(**cfg.curriculum) + + # Validate: "estimated" difficulty_source requires GRPO-like policy (attempts > 1) + # In estimated mode, per-group success rate is used as difficulty proxy + if self._curriculum_config.difficulty_source == "estimated" and cfg.attempts <= 1: + raise ValueError( + f"Curriculum difficulty_source='estimated' requires attempts > 1 (GRPO-like policy) " + f"to estimate difficulty from per-group success rate. Got attempts={cfg.attempts}" + ) + + # Feedback stream - listener will be started AFTER forking scheduler processes (fork safety) + self._curriculum_feedback_stream = SingleStreamSpec(exp_path=exp_path, topic="curriculum_feedback") + logger.info(f"Initialized curriculum learning with config: {self._curriculum_config}") + # Determine the number of processes to use num_processes = min(self.cfg.actor.rollout_workers, len(self.llms)) attempts = self.cfg.attempts if is_training else 1 @@ -345,12 +376,29 @@ def __init__( process.start() self.rollout_processes.append(process) + # Start curriculum feedback listener AFTER forking scheduler processes (fork safety) + # Starting it before fork can cause multiprocessing.Queue corruption + if self._curriculum_config is not None and hasattr(self, '_curriculum_feedback_stream'): + self.curriculum_state = CurriculumState( + self._curriculum_config, + self._curriculum_feedback_stream, + ) + self.curriculum_state.start_listening() + logger.info("Started curriculum feedback listener (after fork)") + def init_stats(self): self.stats = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) self.latency_list = [] self.model_versions_list = [] self.sliding_stats = defaultdict(list) self.domain_counts = defaultdict(int) + + # Curriculum batch-level tracking + self.curriculum_categories = [] + # Track numeric values for each feature field (for computing averages) + self.curriculum_feature_values: Dict[str, List[float]] = defaultdict(list) + # Track 
categorical values for each feature field (for distribution plots) + self.curriculum_feature_categories: Dict[str, List[str]] = defaultdict(list) def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} @@ -398,7 +446,14 @@ def update_stats(self, rollout_results: List[RolloutResult]): for k, v in sliding_window_stats.items(): self.sliding_stats[k].append(v) - + # Track curriculum categories and feature values from rollout metadata + if self.curriculum_state is not None: + cats, feat_vals, feat_cats = self.curriculum_state.track_rollout_results(rollout_results) + self.curriculum_categories.extend(cats) + for field, values in feat_vals.items(): + self.curriculum_feature_values[field].extend(values) + for field, values in feat_cats.items(): + self.curriculum_feature_categories[field].extend(values) def run(self, dataset: list[tuple[str, dict]]): loop_start_time = time.time() @@ -418,7 +473,12 @@ def run(self, dataset: list[tuple[str, dict]]): # for test samples, loop through the dataset once domain_sampler = None if self.is_training: - problem_iter = random_iter(dataset) + if self.curriculum_state is not None: + problem_iter = self.curriculum_state.create_iterator(dataset) + logger.info("Using curriculum learning for sampling") + else: + problem_iter = random_iter(dataset) + domain_mix_cfg = getattr(self.cfg.actor, "domain_mix", None) if domain_mix_cfg: mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True) @@ -491,6 +551,8 @@ def run(self, dataset: list[tuple[str, dict]]): else: problem = next(problem_iter) self.problem_queue.put(problem, block=False) + if "_selected_category" in problem: + logger.debug(f"Actor submitting problem with category: {problem['_selected_category']}") submitted_groups += 1 except queue.Full: assert False, "Problem queue was not full just a moment ago, but now it is full" @@ -622,6 +684,16 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): for k, v in 
self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 + + # Add curriculum stats if available + if self.curriculum_state is not None: + curriculum_stats = self.curriculum_state.compute_batch_stats( + self.curriculum_categories, + self.curriculum_feature_values, + self.curriculum_feature_categories, + ) + stats |= curriculum_stats + if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) diff --git a/pipelinerl/curriculum/__init__.py b/pipelinerl/curriculum/__init__.py new file mode 100644 index 00000000..c70e33e0 --- /dev/null +++ b/pipelinerl/curriculum/__init__.py @@ -0,0 +1,18 @@ +"""Bandit-based curriculum learning for PipelineRL.""" + +from .bandit import BanditConfig, BanditState, BanditCurriculum, SuccessRateTracker +from .iterator import BanditIterator +from .feedback import CategoryFeedback, compute_category_feedback, CurriculumFeedbackTracker +from .state import CurriculumState + +__all__ = [ + "BanditConfig", + "BanditState", + "BanditCurriculum", + "SuccessRateTracker", + "BanditIterator", + "CategoryFeedback", + "compute_category_feedback", + "CurriculumFeedbackTracker", + "CurriculumState", +] diff --git a/pipelinerl/curriculum/bandit.py b/pipelinerl/curriculum/bandit.py new file mode 100644 index 00000000..b01a2f03 --- /dev/null +++ b/pipelinerl/curriculum/bandit.py @@ -0,0 +1,367 @@ +"""Multi-armed bandit for curriculum learning.""" + +import logging +from typing import Dict, List, Optional, Union + +import numpy as np +from pydantic import BaseModel, Field, model_validator + +logger = logging.getLogger(__name__) + + +class BanditConfig(BaseModel): + """Configuration for bandit-based curriculum learning.""" + + # Whether curriculum learning is enabled + enabled: bool = Field(default=False, description="Enable curriculum learning") + + # How difficulty is determined for curriculum learning: + # - "field": use the value from difficulty_field directly as category + # - "estimated": 
estimate difficulty from per-group success rates + # (requires GRPO-like policy with attempts > 1) + difficulty_source: str = Field( + default="field", + description="How difficulty is determined: 'field' (from difficulty_field) or 'estimated' (from success rates)", + ) + + # Field name for difficulty/level (e.g., "level", "difficulty") + # Used when difficulty_source="field". If None, uses category_fields only. + difficulty_field: Optional[str] = Field( + default=None, + description="Field in problem dict containing difficulty/level. " + "Takes precedence over category_fields when set.", + ) + + # Additional field(s) for categorization beyond difficulty + # Can be a single string or list of strings for joint categories + # e.g., "dataset" or ["dataset", "type"] for joint category "math|algebra" + category_fields: Union[str, List[str]] = Field( + default_factory=list, + description="Additional field(s) in problem dict for categorization. " + "Can be a single string or list of strings.", + ) + + # Bandit algorithm parameters + temperature: float = Field( + default=1.0, description="Softmax temperature (higher = more exploration)" + ) + learning_rate: float = Field( + default=0.1, description="Q-value update learning rate" + ) + initial_q_value: float = Field( + default=0.0, description="Initial Q-value for new categories" + ) + + # Update signal: which metric to use + update_signal: str = Field( + default="advantage", + description="Signal for Q-update: 'advantage', 'reward', or 'success'", + ) + + # Number of difficulty buckets for estimated mode + # Problems are assigned to buckets based on their estimated success rate + # e.g., 5 buckets: [0-0.2), [0.2-0.4), [0.4-0.6), [0.6-0.8), [0.8-1.0] + num_difficulty_buckets: int = Field( + default=5, + description="Number of difficulty buckets for estimated mode", + ) + + # How often to re-assign problems to buckets based on updated success rates + # Only applies when difficulty_source="estimated" + # Each preprocessor 
batch sends one feedback message, so this is roughly + # "reindex every N optimization steps" + reindex_interval: int = Field( + default=5, + description="Re-index problems every N preprocessor feedback messages", + ) + + @model_validator(mode="after") + def validate_category_fields(self) -> "BanditConfig": + """Validate that category fields are properly configured.""" + if not self.enabled: + return self + + has_difficulty_field = bool(self.difficulty_field) + has_category_fields = bool(self.category_fields) + + if self.difficulty_source == "field": + # When using field-based difficulty, must have at least one field + if not has_difficulty_field and not has_category_fields: + raise ValueError( + "When difficulty_source='field', either difficulty_field or " + "category_fields must be provided" + ) + # When difficulty_source="estimated", no fields required (difficulty comes from success rates) + # but category_fields can still be used for grouping + + return self + + def get_all_category_fields(self) -> List[str]: + """Get all category fields as a list (difficulty_field + category_fields). + + Converts category_fields to list if it's a string, and prepends + difficulty_field if set. + """ + fields = [] + + # Add difficulty_field first if set + if self.difficulty_field: + fields.append(self.difficulty_field) + + # Add category_fields (normalize string to list) + if isinstance(self.category_fields, str): + if self.category_fields: # non-empty string + fields.append(self.category_fields) + else: + fields.extend(self.category_fields) + + return fields + + def success_rate_to_bucket(self, success_rate: float) -> str: + """Convert a success rate to a difficulty bucket name. 
+ + Args: + success_rate: Success rate in [0, 1] + + Returns: + Bucket name like "bucket_0" (hardest) to "bucket_4" (easiest) + """ + # Clamp to [0, 1] + success_rate = max(0.0, min(1.0, success_rate)) + # Bucket index: 0 = hardest (low success), num_buckets-1 = easiest (high success) + bucket_idx = min( + int(success_rate * self.num_difficulty_buckets), + self.num_difficulty_buckets - 1 + ) + return f"bucket_{bucket_idx}" + + def get_bucket_names(self) -> List[str]: + """Get all bucket names for estimated mode.""" + return [f"bucket_{i}" for i in range(self.num_difficulty_buckets)] + + +class BanditState(BaseModel): + """Serializable state for the bandit.""" + + q_values: Dict[str, float] = Field(default_factory=dict) + sample_counts: Dict[str, int] = Field(default_factory=dict) + + +class SuccessRateTracker: + """Tracks per-problem success rates for estimated difficulty mode. + + Uses exponential moving average to track success rates over time. + """ + + def __init__(self, learning_rate: float = 0.3, default_rate: float = 0.5): + """ + Args: + learning_rate: EMA learning rate for success rate updates + default_rate: Default success rate for unseen problems + """ + self.learning_rate = learning_rate + self.default_rate = default_rate + # problem_id -> estimated success rate + self._success_rates: Dict[str, float] = {} + # problem_id -> number of observations + self._observation_counts: Dict[str, int] = {} + + def get_success_rate(self, problem_id: str) -> float: + """Get estimated success rate for a problem.""" + return self._success_rates.get(problem_id, self.default_rate) + + def update(self, problem_id: str, success_rate: float) -> None: + """Update success rate estimate for a problem. 
+ + Args: + problem_id: Unique problem identifier + success_rate: Observed success rate (e.g., from batch) + """ + if problem_id not in self._success_rates: + # First observation - use it directly + self._success_rates[problem_id] = success_rate + self._observation_counts[problem_id] = 1 + else: + # EMA update + old_rate = self._success_rates[problem_id] + new_rate = (1 - self.learning_rate) * old_rate + self.learning_rate * success_rate + self._success_rates[problem_id] = new_rate + self._observation_counts[problem_id] += 1 + + def get_state(self) -> Dict[str, Dict[str, float]]: + """Get serializable state.""" + return { + "success_rates": dict(self._success_rates), + "observation_counts": {k: float(v) for k, v in self._observation_counts.items()}, + } + + def get_stats(self, config: "BanditConfig") -> Dict[str, float]: + """Get statistics for wandb logging. + + Args: + config: BanditConfig to compute bucket assignments + + Returns: + Dict of stat_name -> value + """ + stats = {} + + if not self._success_rates: + return stats + + rates = list(self._success_rates.values()) + stats["estimated/num_problems_tracked"] = len(rates) + stats["estimated/mean_success_rate"] = np.mean(rates) + stats["estimated/std_success_rate"] = np.std(rates) if len(rates) > 1 else 0.0 + stats["estimated/min_success_rate"] = np.min(rates) + stats["estimated/max_success_rate"] = np.max(rates) + + # Distribution across buckets + bucket_counts: Dict[str, int] = {name: 0 for name in config.get_bucket_names()} + for rate in rates: + bucket = config.success_rate_to_bucket(rate) + bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1 + + total = len(rates) + for bucket, count in bucket_counts.items(): + stats[f"estimated/bucket_fraction/{bucket}"] = count / total if total > 0 else 0.0 + + # Average observations per problem + if self._observation_counts: + obs_counts = list(self._observation_counts.values()) + stats["estimated/mean_observations"] = np.mean(obs_counts) + + return stats + + def 
set_state(self, state: Dict[str, Dict[str, float]]) -> None: + """Set state from serialized form.""" + self._success_rates = dict(state.get("success_rates", {})) + self._observation_counts = {k: int(v) for k, v in state.get("observation_counts", {}).items()} + + def __len__(self) -> int: + return len(self._success_rates) + + +class BanditCurriculum: + """ + Multi-armed bandit for curriculum learning. + + Uses softmax/Boltzmann selection over Q-values representing category value. + Q-values are updated based on mean advantage feedback from the preprocessor. + """ + + def __init__(self, config: BanditConfig, categories: Optional[List[str]] = None): + self.config = config + self.state = BanditState() + + # Initialize categories if provided + if categories: + for cat in categories: + self._ensure_category(cat) + + def _ensure_category(self, category: str) -> None: + """Ensure category exists in state.""" + if category not in self.state.q_values: + self.state.q_values[category] = self.config.initial_q_value + self.state.sample_counts[category] = 0 + + def get_selection_probabilities(self) -> Dict[str, float]: + """ + Compute softmax selection probabilities over categories. + + Returns dict mapping category -> probability + """ + if not self.state.q_values: + return {} + + categories = list(self.state.q_values.keys()) + q_vals = np.array([self.state.q_values[c] for c in categories]) + + # Softmax with temperature + # Subtract max for numerical stability + q_vals = q_vals - np.max(q_vals) + exp_vals = np.exp(q_vals / max(self.config.temperature, 1e-8)) + probs = exp_vals / exp_vals.sum() + + return {cat: float(prob) for cat, prob in zip(categories, probs)} + + def select_category(self, category_to_problems: Optional[Dict[str, list]] = None) -> Optional[str]: + """ + Select a category using softmax/Boltzmann selection. + + Args: + category_to_problems: If provided, only select from non-empty categories. + + Returns selected category name or None if no categories available. 
+ """ + probs = self.get_selection_probabilities() + if not probs: + return None + + # Filter to non-empty categories + if category_to_problems is not None: + probs = {c: p for c, p in probs.items() if category_to_problems.get(c)} + if not probs: + return None + total = sum(probs.values()) + probs = {c: p / total for c, p in probs.items()} + + categories = list(probs.keys()) + probabilities = [probs[c] for c in categories] + + return np.random.choice(categories, p=probabilities) + + def update_q_value(self, category: str, signal: float) -> None: + """ + Update Q-value for a category based on feedback signal. + + Uses exponential moving average: Q = (1-lr)*Q + lr*signal + + Args: + category: The category to update + signal: The feedback signal (e.g., mean advantage) + """ + if not np.isfinite(signal): + logger.warning(f"Ignoring non-finite signal {signal} for category {category}") + return + + self._ensure_category(category) + + old_q = self.state.q_values[category] + lr = self.config.learning_rate + new_q = (1 - lr) * old_q + lr * signal + + self.state.q_values[category] = new_q + self.state.sample_counts[category] += 1 + + logger.debug( + f"Updated Q[{category}]: {old_q:.4f} -> {new_q:.4f} (signal={signal:.4f})" + ) + + def get_state(self) -> BanditState: + """Get serializable state for persistence/communication.""" + return self.state.model_copy(deep=True) + + def set_state(self, state: BanditState) -> None: + """Set state from serialized form.""" + self.state = state.model_copy(deep=True) + + def get_stats(self) -> Dict[str, float]: + """Get statistics for logging.""" + stats = {} + probs = self.get_selection_probabilities() + + for cat in self.state.q_values: + stats[f"curriculum/q_value/{cat}"] = self.state.q_values[cat] + stats[f"curriculum/samples/{cat}"] = self.state.sample_counts[cat] + if cat in probs: + stats[f"curriculum/prob/{cat}"] = probs[cat] + + # Also log aggregate stats + if probs: + # Entropy of selection distribution (higher = more uniform) + 
probs_arr = np.array(list(probs.values())) + entropy = -np.sum(probs_arr * np.log(probs_arr + 1e-10)) + stats["curriculum/selection_entropy"] = entropy + + return stats diff --git a/pipelinerl/curriculum/feedback.py b/pipelinerl/curriculum/feedback.py new file mode 100644 index 00000000..213de207 --- /dev/null +++ b/pipelinerl/curriculum/feedback.py @@ -0,0 +1,250 @@ +"""Curriculum feedback computation and message types.""" + +import logging +import time +from collections import defaultdict +from typing import Dict, List, Literal, Union + +import numpy as np +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class CategoryFeedback(BaseModel): + """Feedback message from preprocessor to actor about category performance.""" + + kind: Literal["category_feedback"] = "category_feedback" + + # Map from category -> mean advantage for that category in this batch + category_advantages: Dict[str, float] = Field(default_factory=dict) + + # Map from category -> number of samples with non-zero advantages in this batch + category_counts: Dict[str, int] = Field(default_factory=dict) + + # Map from category -> total number of rollouts in this batch + category_total_rollouts: Dict[str, int] = Field(default_factory=dict) + + # Map from category -> list of unique problem IDs in this batch + category_problem_ids: Dict[str, List[str]] = Field(default_factory=dict) + + # Map from category -> success rate in this batch + category_success_rates: Dict[str, float] = Field(default_factory=dict) + + # Map from problem_id -> success rate (for estimated mode) + # Used to update per-problem success rate estimates in the actor + problem_success_rates: Dict[str, float] = Field(default_factory=dict) + + # Timestamp for ordering/staleness detection + timestamp: float = Field(default_factory=time.time) + + # Model version these stats correspond to + model_version: int = 0 + + +class BanditStateUpdate(BaseModel): + """Full bandit state update (for synchronization).""" + + kind: 
Literal["bandit_state_update"] = "bandit_state_update" + + q_values: Dict[str, float] = Field(default_factory=dict) + sample_counts: Dict[str, int] = Field(default_factory=dict) + timestamp: float = Field(default_factory=time.time) + + +CurriculumMessage = Union[CategoryFeedback, BanditStateUpdate] + + +class CurriculumFeedbackTracker: + """Tracks curriculum feedback stats over multiple batches for periodic logging.""" + + def __init__(self, log_every_n_samples: int = 128): + self.log_every_n_samples = log_every_n_samples + self.cumulative_rollouts_per_category: Dict[str, int] = defaultdict(int) + self.cumulative_problems_per_category: Dict[str, set] = defaultdict(set) + self.reset() + + def reset(self): + """Reset batch stats (not cumulative).""" + self.batch_nonzero_adv_per_category: Dict[str, int] = defaultdict(int) + self.batch_rollouts_per_category: Dict[str, int] = defaultdict(int) + self.batch_problems_per_category: Dict[str, set] = defaultdict(set) + self.batch_count = 0 + + def update(self, feedback: CategoryFeedback): + """Accumulate stats from a feedback message.""" + for category, count in feedback.category_counts.items(): + self.batch_nonzero_adv_per_category[category] += count + for category, count in feedback.category_total_rollouts.items(): + self.batch_rollouts_per_category[category] += count + self.cumulative_rollouts_per_category[category] += count + for category, problem_ids in feedback.category_problem_ids.items(): + self.batch_problems_per_category[category].update(problem_ids) + self.cumulative_problems_per_category[category].update(problem_ids) + self.batch_count += 1 + + def has_data(self) -> bool: + """Check if there are accumulated stats to report.""" + return self.batch_count > 0 + + def time_to_log(self) -> bool: + """Check if enough samples have accumulated for periodic logging.""" + total_rollouts = sum(self.batch_rollouts_per_category.values()) + return total_rollouts >= self.log_every_n_samples + + def get_summary_and_reset(self) -> 
str: + """Get a formatted summary string and reset batch stats.""" + parts = [f"{self.batch_count} chunks"] + for cat in sorted(self.batch_rollouts_per_category.keys()): + total = self.batch_rollouts_per_category[cat] + nonzero = self.batch_nonzero_adv_per_category.get(cat, 0) + zero = total - nonzero + problems = len(self.batch_problems_per_category.get(cat, set())) + cum_total = self.cumulative_rollouts_per_category[cat] + cum_problems = len(self.cumulative_problems_per_category[cat]) + parts.append( + f"{cat}: {total} rollouts ({zero} zero-adv), {problems} problems " + f"(cumulative: {cum_total} rollouts, {cum_problems} problems)" + ) + self.reset() + return " | ".join(parts) + + +def compute_category_feedback( + dataset: List[Dict], + category_fields: Union[str, List[str]] = None, +) -> CategoryFeedback: + """ + Compute feedback statistics from a preprocessed batch. + + Called in preprocess.py after populate_rl_data() computes advantages. + + Args: + dataset: List of processed entries with 'advantages', 'metadata', etc. + category_fields: Field(s) to use for category grouping. Can be a single + field name or a list of fields for joint categories. If None or empty, + uses "_selected_category" from metadata only. 
+ + Returns: + CategoryFeedback message with aggregated statistics + """ + # Normalize category_fields to list + if category_fields is None: + fields = [] + elif isinstance(category_fields, str): + fields = [category_fields] if category_fields else [] + else: + fields = list(category_fields) + + def get_category_from_entry(entry: Dict) -> str: + """Extract category from entry using configured fields.""" + values = [str(entry[field]) for field in fields if entry.get(field)] + return "|".join(values) + + # Aggregate by category + category_advantages: Dict[str, List[float]] = defaultdict(list) + category_nonzero_counts: Dict[str, int] = defaultdict(int) # For logging + category_rewards: Dict[str, List[float]] = defaultdict(list) + category_problem_ids: Dict[str, set] = defaultdict(set) + + # Per-problem rewards for estimated mode + problem_rewards: Dict[str, List[float]] = defaultdict(list) + + for entry in dataset: + # Get category from metadata (set by bandit iterator) or entry itself + metadata = entry.get("metadata", {}) + category = metadata.get("_selected_category") + problem_id = metadata.get("id") # For estimated mode + + if category is None and fields: + # Fallback to category_fields in entry + category = get_category_from_entry(entry) + + # Track rewards for success rate estimation (do this even without category) + rewards = entry.get("rewards", []) + reward = None + if rewards: + # Use the first reward value (they're typically all the same for a rollout) + reward = rewards[0] if isinstance(rewards, list) else rewards + if np.isfinite(reward): + # Track per-problem for estimated mode (independent of category) + if problem_id is not None: + problem_rewards[str(problem_id)].append(reward) + # Track per-category if we have one + if category is not None: + category_rewards[category].append(reward) + category_problem_ids[category].add(str(problem_id)) + + # Skip category-based stats if no category available + if category is None: + # This shouldn't happen in 
estimated mode if iterator is working correctly + logger.debug(f"Entry missing _selected_category, problem_id={problem_id}") + continue + + # Get advantages (per-token, use mean of absolute values like SEC) + advantages = entry.get("advantages", []) + labels = entry.get("labels", []) + if advantages and labels: + adv_arr = np.abs(np.array(advantages)) + mask = np.array(labels) != -100 # Response mask (like SEC) + if mask.any(): + # Masked mean of absolute advantages (like SEC's masked_mean) + mean_abs_adv = float(adv_arr[mask].mean()) + if np.isfinite(mean_abs_adv): + category_advantages[category].append(mean_abs_adv) + if mean_abs_adv > 0: + category_nonzero_counts[category] += 1 + + # Compute category aggregates + result_advantages = {} + result_counts = {} + result_total_rollouts = {} + result_problem_ids = {} + result_success_rates = {} + + all_categories = set(category_advantages.keys()) | set(category_rewards.keys()) + + for category in all_categories: + advs = category_advantages.get(category, []) + rewards = category_rewards.get(category, []) + problem_ids = category_problem_ids.get(category, set()) + + if advs: + # Mean includes zero-advantage samples (like SEC) + result_advantages[category] = float(np.mean(advs)) + # Count of non-zero advantage samples (for logging/tracking) + result_counts[category] = category_nonzero_counts.get(category, 0) + result_total_rollouts[category] = len(rewards) + result_problem_ids[category] = list(problem_ids) + if rewards: + # Success = positive reward + success_rate = sum(1 for r in rewards if r > 0) / len(rewards) + result_success_rates[category] = float(success_rate) + + # Compute per-problem success rates (for estimated mode) + result_problem_success_rates = {} + for problem_id, rewards in problem_rewards.items(): + if rewards: + success_rate = sum(1 for r in rewards if r > 0) / len(rewards) + result_problem_success_rates[problem_id] = float(success_rate) + + model_version = 0 + if dataset: + # Get model version from 
class BanditIterator:
    """Infinite problem sampler driven by a bandit over problem categories.

    Each draw asks the bandit for a category and then picks a problem
    uniformly at random from that category's index.

    Supports two modes:
    - difficulty_source="field": Categories from data fields (e.g., "level")
    - difficulty_source="estimated": Categories from estimated success rates

    The category index (``category_to_problems``) is never mutated in place
    after construction: re-indexing builds a fresh mapping and publishes it
    with a single assignment, so a consumer that is iterating while a
    re-index is triggered elsewhere (e.g. from a feedback-listener thread)
    never observes a cleared or half-built index.
    """

    def __init__(
        self,
        dataset: List[Dict[str, Any]],
        bandit: BanditCurriculum,
        config: BanditConfig,
        success_tracker: Optional[SuccessRateTracker] = None,
    ):
        """Index ``dataset`` by category and register the categories with ``bandit``.

        Args:
            dataset: Problems to sample from; each problem is a dict.
            bandit: Category selector; must provide ``select_category`` and
                ``_ensure_category``.
            config: Curriculum configuration (difficulty source, fields, buckets).
            success_tracker: Per-problem success-rate estimates; only consulted
                when ``config.difficulty_source == "estimated"``.
        """
        self.dataset = dataset
        self.bandit = bandit
        self.config = config
        self.success_tracker = success_tracker

        # Index problems by category for efficient sampling
        self.category_to_problems: Dict[str, List[int]] = defaultdict(list)  # category -> list of indices
        self._index_by_category()

        # Initialize bandit with discovered categories
        for category in self.category_to_problems:
            self.bandit._ensure_category(category)

        logger.info(
            f"BanditIterator initialized with {len(self.category_to_problems)} categories"
        )

    def _get_problem_id(self, problem: Dict[str, Any], index: int) -> str:
        """Return the problem's ``"id"`` as a string, or ``"idx_<index>"`` if absent."""
        return str(problem.get("id", f"idx_{index}"))

    def _get_category(self, problem: Dict[str, Any], index: int = -1) -> Optional[str]:
        """Return the category a problem currently belongs to.

        Args:
            problem: The problem dict.
            index: Position of the problem in the dataset; used only to derive
                a fallback problem id in estimated mode.

        Returns:
            In estimated mode, the bucket for the problem's success rate
            (the bucket for 0.5 when no tracker is available). In field mode,
            the configured field values joined with ``"|"`` (missing values
            become ``"unknown"``), ``"unknown"`` when every field is missing,
            or ``None`` when no category fields are configured.
        """
        if self.config.difficulty_source == "estimated":
            # Use estimated success rate to determine bucket
            if self.success_tracker is not None:
                problem_id = self._get_problem_id(problem, index)
                success_rate = self.success_tracker.get_success_rate(problem_id)
                return self.config.success_rate_to_bucket(success_rate)
            # No tracker yet, use default bucket (middle)
            return self.config.success_rate_to_bucket(0.5)

        # Field-based mode
        fields = self.config.get_all_category_fields()
        if not fields:
            return None

        values = []
        for field in fields:
            value = problem.get(field)
            values.append("unknown" if value is None else str(value))

        # Collapse "unknown|unknown|..." into a single "unknown" category.
        if any(v != "unknown" for v in values):
            return "|".join(values)
        return "unknown"

    def _build_estimated_index(self) -> Dict[str, List[int]]:
        """Build a fresh bucket -> problem-indices mapping from current estimates."""
        index: Dict[str, List[int]] = defaultdict(list)
        # Every configured bucket is present, even when it holds no problems.
        for bucket_name in self.config.get_bucket_names():
            index[bucket_name] = []
        for i, problem in enumerate(self.dataset):
            category = self._get_category(problem, i)
            if category:
                index[category].append(i)
        return index

    def _index_by_category(self) -> None:
        """Build index mapping categories to problem indices."""
        if self.config.difficulty_source == "estimated":
            self.category_to_problems = self._build_estimated_index()

            logger.info(
                f"Indexed {len(self.dataset)} problems into {len(self.category_to_problems)} "
                f"difficulty buckets (estimated mode)"
            )
            for cat in sorted(self.category_to_problems.keys()):
                problems = self.category_to_problems[cat]
                logger.info(f"  {cat}: {len(problems)} problems")
            return

        # Field-based mode
        fields = self.config.get_all_category_fields()
        if not fields:
            logger.info("No category fields configured")
            return

        for i, problem in enumerate(self.dataset):
            category = self._get_category(problem, i)
            if category is not None:
                self.category_to_problems[category].append(i)

        logger.info(
            f"Indexed {len(self.dataset)} problems into {len(self.category_to_problems)} categories"
        )
        for cat, indices in self.category_to_problems.items():
            logger.info(f"  Category '{cat}': {len(indices)} problems")

    def reindex_by_estimated_difficulty(self) -> None:
        """Re-index problems by their current estimated success rates.

        Should be called periodically when success rate estimates have changed.
        No-op unless ``difficulty_source == "estimated"``.

        The new index is built on the side and then swapped in with one
        assignment; the previous mapping object is left untouched so that a
        draw already in progress keeps a consistent view.
        """
        if self.config.difficulty_source != "estimated":
            return

        fresh_index = self._build_estimated_index()

        # Ensure bandit knows about all buckets before they become selectable.
        for category in fresh_index:
            self.bandit._ensure_category(category)

        self.category_to_problems = fresh_index

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Infinite iterator that samples problems based on bandit selection.

        Yields:
            A shallow copy of the sampled problem with ``"_selected_category"``
            set to the category it was drawn from. When the bandit returns no
            usable category (``None``, unknown, or an empty bucket) a problem
            is drawn uniformly from the whole dataset and labelled with its
            own current category.
        """
        while True:
            # Snapshot once per draw: a re-index replaces the whole mapping.
            index = self.category_to_problems

            # Pass the index so the bandit can exclude empty categories
            category = self.bandit.select_category(index)
            logger.debug("Bandit selected category: %s", category)

            indices = index.get(category) if category is not None else None
            if indices:
                problem_idx = random.choice(indices)
            else:
                # Fallback to uniform random if the category is unusable;
                # random.choice on an empty bucket would raise IndexError.
                problem_idx = random.randint(0, len(self.dataset) - 1)
                category = self._get_category(self.dataset[problem_idx], problem_idx)

            problem = dict(self.dataset[problem_idx])
            problem["_selected_category"] = category

            yield problem
class CurriculumState:
    """
    Implements Self-Evolving Curriculum (SEC): https://arxiv.org/abs/2505.14970

    Manages all curriculum learning state: bandit, success tracker, and feedback listener.

    Feedback messages are consumed on a background daemon thread (see
    ``start_listening``) and applied to the bandit, the success tracker and,
    periodically, the iterator's difficulty index.
    """

    def __init__(
        self,
        config: BanditConfig,
        feedback_stream: SingleStreamSpec,
    ):
        """Create the bandit (and, in estimated mode, the success tracker).

        Args:
            config: Curriculum configuration.
            feedback_stream: Stream the feedback listener reads messages from.
        """
        self.config = config
        self.bandit = BanditCurriculum(config)

        # Create success tracker for estimated mode
        self.success_tracker: Optional[SuccessRateTracker] = None
        if config.difficulty_source == "estimated":
            self.success_tracker = SuccessRateTracker(learning_rate=config.learning_rate)
            logger.info("Created success rate tracker for estimated difficulty mode")

        self.feedback_stream = feedback_stream
        self._iterator: Optional[BanditIterator] = None
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._message_adapter = TypeAdapter(CurriculumMessage)
        self._updates_since_reindex = 0

    def create_iterator(self, dataset: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
        """Create a bandit-based iterator for the dataset.

        Args:
            dataset: List of problem dicts

        Returns:
            Iterator that yields problems based on bandit selection
        """
        self._iterator = BanditIterator(
            dataset, self.bandit, self.config, self.success_tracker
        )
        return iter(self._iterator)

    def start_listening(self):
        """Start background thread to listen for feedback.

        Calling this while a listener is already running is a no-op. Calling
        it after ``stop_listening`` starts a new listener.
        """
        if self._thread is not None and self._thread.is_alive():
            logger.warning("Curriculum feedback listener already running")
            return

        # A previous stop_listening() leaves the flag set; without clearing it
        # a restarted listener would exit on its first message.
        self._stop_event.clear()

        def listen():
            logger.info(
                f"Starting curriculum feedback listener on stream {self.feedback_stream.topic}"
            )
            try:
                with read_stream(self.feedback_stream) as reader:
                    for line in reader.read():
                        if self._stop_event.is_set():
                            break
                        try:
                            message = self._message_adapter.validate_python(line)
                        except Exception as e:
                            # Best effort: skip malformed messages, keep listening.
                            logger.warning(f"Failed to parse curriculum message: {e}")
                            continue
                        try:
                            self._process_curriculum_feedback(message)
                        except Exception:
                            # Report processing failures as such (with traceback)
                            # instead of mislabelling them as parse errors.
                            logger.exception("Failed to apply curriculum feedback")
            except Exception:
                logger.exception("Curriculum feedback listener error")

        self._thread = threading.Thread(target=listen, daemon=True)
        self._thread.start()
        logger.info("Started curriculum feedback listener")

    def stop_listening(self):
        """Stop the feedback listener thread.

        Sets the stop flag and waits up to 5 seconds for the thread. The flag
        is only checked when a message arrives, so the thread may outlive this
        call; it is a daemon thread and will not block interpreter exit.
        """
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=5.0)
        logger.info("Stopped curriculum feedback listener")

    def _process_curriculum_feedback(self, message: CurriculumMessage):
        """Process incoming curriculum feedback.

        ``CategoryFeedback`` updates Q-values and, in estimated mode,
        per-problem success rates plus a periodic re-index.
        ``BanditStateUpdate`` replaces the bandit state wholesale.
        """
        if isinstance(message, CategoryFeedback):
            # Update Q-values based on advantages
            for category, advantage in message.category_advantages.items():
                self.bandit.update_q_value(category, advantage)

            # The tracker only exists in estimated mode; re-indexing is
            # meaningless (a no-op in the iterator) otherwise.
            if self.success_tracker is not None:
                for problem_id, success_rate in message.problem_success_rates.items():
                    self.success_tracker.update(problem_id, success_rate)

                self._updates_since_reindex += 1
                # Periodically trigger re-indexing when success rates change
                if self._updates_since_reindex >= self.config.reindex_interval:
                    if self._iterator is not None:
                        self._iterator.reindex_by_estimated_difficulty()
                        logger.debug("Re-indexed problems by estimated difficulty")
                    self._updates_since_reindex = 0

            logger.debug(
                f"Processed feedback for {len(message.category_advantages)} categories, "
                f"{len(message.problem_success_rates)} problems "
                f"(model_version={message.model_version})"
            )

        elif isinstance(message, BanditStateUpdate):
            # Full state synchronization
            state = BanditState(
                q_values=message.q_values,
                sample_counts=message.sample_counts,
            )
            self.bandit.set_state(state)
            logger.info("Synchronized bandit state from external update")

    def track_rollout_results(self, rollout_results):
        """Extract curriculum metadata from rollout results.

        Args:
            rollout_results: List of RolloutResult objects; each must expose
                ``training_texts`` whose items carry a ``metadata`` mapping
                (or ``None``).

        Returns:
            Tuple of (categories, feature_values, feature_categories).
            ``feature_values`` only contains fields for which at least one
            value could be converted to ``float``.
        """
        categories = []
        feature_values: Dict[str, list] = {}
        feature_categories: Dict[str, list] = {}

        category_fields = self.config.get_all_category_fields()

        for result in rollout_results:
            for text in result.training_texts:
                metadata = text.metadata or {}
                # Track selected category
                if "_selected_category" in metadata:
                    categories.append(metadata["_selected_category"])
                # Track each configured feature field
                for field in category_fields:
                    key = f"_curriculum_{field}"
                    if key not in metadata:
                        continue
                    value = metadata[key]
                    # Try numeric first (for averages). Convert before creating
                    # the entry so non-numeric fields leave no empty list behind.
                    try:
                        numeric = float(value)
                    except (ValueError, TypeError):
                        pass
                    else:
                        feature_values.setdefault(field, []).append(numeric)
                    # Always track as categorical (for distribution)
                    feature_categories.setdefault(field, []).append(str(value))

        return categories, feature_values, feature_categories

    def compute_batch_stats(
        self,
        categories: list,
        feature_values: Dict[str, list],
        feature_categories: Dict[str, list],
    ) -> Dict[str, float]:
        """Compute curriculum statistics for logging.

        Args:
            categories: List of selected categories in this batch
            feature_values: Dict of field -> list of numeric values
            feature_categories: Dict of field -> list of categorical values

        Returns:
            Dict of stat_name -> value
        """
        from collections import Counter

        stats: Dict[str, float] = {}

        # Add bandit Q-value stats
        stats.update(self.bandit.get_stats())

        # Add success tracker stats (for estimated mode)
        if self.success_tracker is not None:
            stats.update(self.success_tracker.get_stats(self.config))

        # Average feature values (like SEC's loader/average_difficulty)
        for field, values in feature_values.items():
            if values:
                stats[f"curriculum/average_{field}"] = sum(values) / len(values)

        # Categorical distribution for each feature field (fraction and count)
        for field, values in feature_categories.items():
            if values:
                total = len(values)
                for value, count in Counter(values).items():
                    stats[f"curriculum/{field}/{value}"] = count / total
                    stats[f"curriculum/{field}_count/{value}"] = count

        # For estimated mode, track bucket distribution from selected categories
        # (categories contains bucket names like "bucket_0", "bucket_1", etc.)
        if self.config.difficulty_source == "estimated" and categories:
            total = len(categories)
            for bucket_name, count in Counter(categories).items():
                # Extract bucket number from "bucket_0" -> "0"
                bucket_num = bucket_name.replace("bucket_", "")
                stats[f"curriculum/bucket/{bucket_num}"] = count / total
                stats[f"curriculum/bucket_count/{bucket_num}"] = count

        return stats
total_filtered_out += num_filtered_out if num_filtered_out > 0: logger.info(f"Filtered out {num_filtered_out} samples from groups with zero advantage.") + + # Compute and publish curriculum feedback if enabled + if curriculum_stream is not None and dataset: + feedback = compute_category_feedback( + dataset=dataset, + category_fields=curriculum_category_fields, + ) + with write_to_streams(curriculum_stream) as curriculum_writer: + curriculum_writer.write(feedback.model_dump()) + curriculum_tracker.update(feedback) + if curriculum_tracker.time_to_log(): + logger.info( + f"Published curriculum feedback to actor: {curriculum_tracker.get_summary_and_reset()}" + ) + fetching_took += time.time() - start_fetching except Empty: pass From 201b30585da019bcb790bcc538d239ab3f22de35 Mon Sep 17 00:00:00 2001 From: ehsk Date: Fri, 6 Feb 2026 21:34:45 +0000 Subject: [PATCH 2/3] manual domain mix commented out --- conf/sec.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conf/sec.yaml b/conf/sec.yaml index 08fe20d4..258daf2c 100644 --- a/conf/sec.yaml +++ b/conf/sec.yaml @@ -17,9 +17,9 @@ actor: math: ${domain_rollouts.math} coding: ${domain_rollouts.coding} - domain_mix: - math: 0.5 - coding: 0.5 + # domain_mix: + # math: 0.5 + # coding: 0.5 # SandboxFusion verification settings sandbox_endpoint: ${oc.env:SANDBOX_ENDPOINT,http://127.0.0.1:8080} From 48c14b21911d7b1b1fa380012ac61d865cff9bb6 Mon Sep 17 00:00:00 2001 From: ehsk Date: Thu, 9 Apr 2026 15:36:29 +0000 Subject: [PATCH 3/3] redundant code removed --- pipelinerl/actor.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 1d1d642f..18cab32f 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -423,16 +423,6 @@ def _stop_rollout_processes(self): logger.info(f"Stopped {len(self.rollout_processes)} rollout processes") self.rollout_processes = [] - # Start curriculum feedback listener AFTER forking scheduler processes (fork safety) - # 
Starting it before fork can cause multiprocessing.Queue corruption - if self._curriculum_config is not None and hasattr(self, '_curriculum_feedback_stream'): - self.curriculum_state = CurriculumState( - self._curriculum_config, - self._curriculum_feedback_stream, - ) - self.curriculum_state.start_listening() - logger.info("Started curriculum feedback listener (after fork)") - def init_stats(self): self.stats = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) self.latency_list = []