diff --git a/docs/metrics.md b/docs/metrics.md index d9e74a04c..5e522e140 100644 --- a/docs/metrics.md +++ b/docs/metrics.md @@ -103,6 +103,27 @@ With the above, agentic_grpo_learner will by default start an async trajectory logger which logs the trajectories including prompts, responses, etc. to the specified `log_dir`. +### Trackio Trajectory Backend + +RL algorithms can also log rollout prompt/completion pairs as Trackio traces +through the trajectory logger. Set `trackio_project` and a positive +`trackio_max_traces_per_step` on the algorithm config. Standard PPO/GRPO +rollouts are logged under `trackio_trace_key` (default `rollout/traces`); +agentic rollouts are logged under `agentic/trajectories`. + +```python +from tunix.rl.grpo import grpo_learner + +algo_config = grpo_learner.GRPOConfig( + trackio_project="my-rl-project", + trackio_run_name="experiment-1", + trackio_max_traces_per_step=16, +) +``` + +Trackio is optional: install it in the training environment before enabling +trace logging; if it is missing, Tunix logs a warning and skips the traces. + ### Enabling Metrics in Jobs Once you have your `MetricsLoggerOptions` configured, you can pass it to your diff --git a/tests/test_trajectory_logger_trackio.py b/tests/test_trajectory_logger_trackio.py new file mode 100644 index 000000000..d58906c41 --- /dev/null +++ b/tests/test_trajectory_logger_trackio.py @@ -0,0 +1,197 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib.util +import pathlib +import sys +import types +import unittest +from unittest import mock + +import numpy as np + +_MODULE_PATH = ( + pathlib.Path(__file__).resolve().parents[1] + / "tunix" + / "utils" + / "trajectory_logger.py" +) +_SPEC = importlib.util.spec_from_file_location( + "trajectory_logger", _MODULE_PATH +) +trajectory_logger = importlib.util.module_from_spec(_SPEC) +assert _SPEC.loader is not None +sys.modules[_SPEC.name] = trajectory_logger +_SPEC.loader.exec_module(trajectory_logger) + + +class TrackioTrajectoryLoggerTest(unittest.TestCase): + + def test_log_rollouts_creates_trackio_traces(self): + run = mock.MagicMock() + trace = mock.MagicMock( + side_effect=lambda messages, metadata: { + "messages": messages, + "metadata": metadata, + } + ) + fake_trackio = types.SimpleNamespace( + init=mock.MagicMock(return_value=run), + Trace=trace, + ) + + with mock.patch.dict(sys.modules, {"trackio": fake_trackio}): + logger = trajectory_logger.AsyncTrajectoryLogger( + backends=[ + trajectory_logger.TrackioTrajectoryLogBackend( + project="tunix-smoke", + run_name="run-1", + max_traces_per_step=2, + ) + ], + ) + logger.log_rollouts( + prompts=["What is 2 + 2?"], + completions=["4", "Four."], + rewards=np.array([1.0, 0.8]), + advantages=np.array([[0.1, 0.2], [0.3, 0.4]]), + mode="train", + step=7, + ) + logger.stop() + + fake_trackio.init.assert_called_once_with( + project="tunix-smoke", name="run-1" + ) + self.assertEqual(trace.call_count, 2) + self.assertEqual( + trace.call_args_list[0].kwargs["messages"], + [ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "4"}, + ], + ) + self.assertEqual(trace.call_args_list[0].kwargs["metadata"]["step"], 7) + self.assertEqual( + trace.call_args_list[1].kwargs["metadata"]["advantages"], [0.3, 0.4] + ) + run.log.assert_called_once_with( + {"rollout/traces": [mock.ANY, mock.ANY]}, step=7 + ) + + def test_from_config_adds_trackio_backend(self): + config = 
types.SimpleNamespace( + trackio_project="tunix-smoke", + trackio_run_name="run-1", + trackio_trace_key="rollout/traces", + trackio_max_traces_per_step=2, + trackio_init_kwargs={}, + ) + + logger = trajectory_logger.AsyncTrajectoryLogger.from_config(config=config) + + self.assertTrue(logger.has_backends) + logger.stop() + + def test_log_messages_uses_existing_chat_messages(self): + run = mock.MagicMock() + trace = mock.MagicMock(return_value="trace") + fake_trackio = types.SimpleNamespace( + init=mock.MagicMock(return_value=run), + Trace=trace, + ) + messages = [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "Name a color."}, + {"role": "assistant", "content": "Blue."}, + ] + + with mock.patch.dict(sys.modules, {"trackio": fake_trackio}): + logger = trajectory_logger.AsyncTrajectoryLogger( + backends=[ + trajectory_logger.TrackioTrajectoryLogBackend( + project="tunix-smoke", + max_traces_per_step=1, + ) + ], + ) + logger.log_messages( + messages_list=[messages], + metadata_list=[{"trajectory_reward": 0.9}], + trace_key="agentic/trajectories", + step=3, + ) + logger.stop() + + trace.assert_called_once_with( + messages=messages, + metadata={ + "mode": "train", + "sample_index": 0, + "step": 3, + "trajectory_reward": 0.9, + }, + ) + run.log.assert_called_once_with({"agentic/trajectories": ["trace"]}, step=3) + + def test_direct_trackio_backend_still_logs_rollouts(self): + run = mock.MagicMock() + trace = mock.MagicMock( + side_effect=lambda messages, metadata: { + "messages": messages, + "metadata": metadata, + } + ) + fake_trackio = types.SimpleNamespace( + init=mock.MagicMock(return_value=run), + Trace=trace, + ) + + with mock.patch.dict(sys.modules, {"trackio": fake_trackio}): + logger = trajectory_logger.TrackioTrajectoryLogBackend( + project="tunix-smoke", + run_name="run-1", + max_traces_per_step=2, + ) + logger.log_rollouts( + prompts=["What is 2 + 2?"], + completions=["4", "Four."], + rewards=np.array([1.0, 0.8]), + 
advantages=np.array([[0.1, 0.2], [0.3, 0.4]]), + mode="train", + step=7, + ) + + fake_trackio.init.assert_called_once_with( + project="tunix-smoke", name="run-1" + ) + self.assertEqual(trace.call_count, 2) + self.assertEqual( + trace.call_args_list[0].kwargs["messages"], + [ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "4"}, + ], + ) + self.assertEqual(trace.call_args_list[0].kwargs["metadata"]["step"], 7) + self.assertEqual( + trace.call_args_list[1].kwargs["metadata"]["advantages"], [0.3, 0.4] + ) + run.log.assert_called_once_with( + {"rollout/traces": [mock.ANY, mock.ANY]}, step=7 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py index e11d31760..3ebe97309 100644 --- a/tunix/rl/agentic/agentic_grpo_learner.py +++ b/tunix/rl/agentic/agentic_grpo_learner.py @@ -210,11 +210,13 @@ def __init__( metrics_logger_options.log_dir if metrics_logger_options else None ) - if metrics_log_dir: - self._trajectory_logger = trajectory_logger.AsyncTrajectoryLogger( - metrics_log_dir - ) - else: + self._trajectory_logger = ( + trajectory_logger.AsyncTrajectoryLogger.from_config( + log_dir=metrics_log_dir, + config=algo_config, + ) + ) + if not self._trajectory_logger.has_backends: logging.warning("Metrics log dir is None, skipping trajectory logging.") self.algo_config.temperature = self.rl_cluster.get_rollout_config( @@ -334,6 +336,30 @@ def _process_results( if self._trajectory_logger and trajectories_to_log: for traj in trajectories_to_log: self._trajectory_logger.log_item_async(traj) + messages_list = [] + metadata_list = [] + for trajectory_index, traj in enumerate(trajectories_to_log): + messages = traj.get("conversation_text") or [] + if not messages: + continue + messages_list.append(messages) + metadata_list.append({ + "trajectory_index": trajectory_index, + "trajectory_reward": traj.get("trajectory_reward"), + "status": 
traj.get("status"), + "policy_version": traj.get("policy_version"), + }) + self._trajectory_logger.log_messages( + messages_list=messages_list, + mode=mode, + step=self.rl_cluster.global_steps, + metadata_list=metadata_list, + metadata={ + "algorithm": self.algo_config.algo_variant, + "num_generations": self.algo_config.num_generations, + }, + trace_key="agentic/trajectories", + ) # Pad all prompts and completions to consistent lengths. rollout_config = self.rl_cluster.cluster_config.rollout_config diff --git a/tunix/rl/algorithm_config.py b/tunix/rl/algorithm_config.py index 0a4b178cf..cc9086f96 100644 --- a/tunix/rl/algorithm_config.py +++ b/tunix/rl/algorithm_config.py @@ -13,8 +13,11 @@ # limitations under the License. import dataclasses +from typing import Any + from absl import logging + @dataclasses.dataclass(slots=True, kw_only=True) class AlgorithmConfig: """Configuration for RL algorithms. @@ -23,12 +26,24 @@ class AlgorithmConfig: algo_variant: The core algorithm variant to use. advantage_estimator: The advantage estimator to use. policy_loss_fn: The policy loss function to use. + trackio_project: Trackio project name for rollout trace logging. If unset, + Trackio trace logging is disabled. + trackio_run_name: Optional Trackio run name for rollout trace logging. + trackio_trace_key: Metric key used for Trackio trace records. + trackio_max_traces_per_step: Maximum rollout traces to log per step. Set to + 0 to disable trace logging. + trackio_init_kwargs: Extra keyword arguments forwarded to `trackio.init`. 
""" algo_variant: str = "grpo" advantage_estimator: str = "grpo" policy_loss_fn: str = "grpo" reward_manager: str = "sequence-level" + trackio_project: str | None = None + trackio_run_name: str | None = None + trackio_trace_key: str = "rollout/traces" + trackio_max_traces_per_step: int = 0 + trackio_init_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): valid_algo_variants = [ diff --git a/tunix/rl/grpo/grpo_learner.py b/tunix/rl/grpo/grpo_learner.py index 1a20fa877..b4903f594 100644 --- a/tunix/rl/grpo/grpo_learner.py +++ b/tunix/rl/grpo/grpo_learner.py @@ -361,6 +361,14 @@ def _generate_and_compute_advantage( ) self.rl_cluster.buffer_metrics(user_defined_metric, mode=mode) + self._log_trackio_rollout_traces( + prompts=training_input["prompts"], + completions=rollout_output.text, + rewards=rewards, + advantages=advantages, + mode=mode, + ) + return TrainExample( prompt_ids=prompt_ids, prompt_mask=prompt_mask, diff --git a/tunix/rl/ppo/ppo_learner.py b/tunix/rl/ppo/ppo_learner.py index c16aed90d..4520c31a8 100644 --- a/tunix/rl/ppo/ppo_learner.py +++ b/tunix/rl/ppo/ppo_learner.py @@ -483,6 +483,14 @@ def _generate_and_compute_advantage( ) self.rl_cluster.buffer_metrics(user_defined_metric, mode=mode) + self._log_trackio_rollout_traces( + prompts=training_input["prompts"], + completions=rollout_output.text, + rewards=last_token_scores, + advantages=advantages, + mode=mode, + ) + return TrainExample( prompt_ids=prompt_ids, prompt_mask=prompt_mask, diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py index 0c0bfc7f8..353c8f4e6 100644 --- a/tunix/rl/rl_learner.py +++ b/tunix/rl/rl_learner.py @@ -35,7 +35,7 @@ from tunix.rl import utils as rl_utils from tunix.rl.queue import data_queue as queue_lib from tunix.sft import utils as sft_utils - +from tunix.utils import trajectory_logger ABC = abc.ABC abstractmethod = abc.abstractmethod @@ -130,6 +130,9 @@ def __init__( ) self.executor = 
futures.ThreadPoolExecutor(max_workers=1) self._last_iter_step = self.rl_cluster.actor_trainer.iter_steps + self._trajectory_logger = ( + trajectory_logger.AsyncTrajectoryLogger.from_config(config=algo_config) + ) self._rollout_micro_batch_size = ( self._training_config.rollout_micro_batch_size @@ -139,6 +142,28 @@ def __init__( ) sft_utils.show_hbm_usage(title="RLLearner init") + def _log_trackio_rollout_traces( + self, + *, + prompts: Sequence[str | list[dict[str, str]]], + completions: Sequence[str], + rewards: Any = None, + advantages: Any = None, + mode: rl_cluster_lib.Mode = rl_cluster_lib.Mode.TRAIN, + ) -> None: + self._trajectory_logger.log_rollouts( + prompts=prompts, + completions=completions, + rewards=rewards, + advantages=advantages, + mode=mode, + step=self.rl_cluster.global_steps, + metadata={ + "algorithm": self.algo_config.algo_variant, + "num_generations": self._num_generations(), + }, + ) + @abstractmethod def _generate_and_compute_advantage( self, diff --git a/tunix/utils/trajectory_logger.py b/tunix/utils/trajectory_logger.py index 0dfe4e04d..9a360b816 100644 --- a/tunix/utils/trajectory_logger.py +++ b/tunix/utils/trajectory_logger.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Logging utilities for trajectory data, saving as CSV.""" +"""Logging utilities for trajectory data.""" import atexit import dataclasses @@ -23,7 +23,7 @@ import threading import time import types -from typing import Any +from typing import Any, Sequence from absl import logging from etils import epath @@ -32,6 +32,7 @@ import numpy as np import pandas as pd + def _make_serializable(item: Any) -> Any: """Makes an object serializable.""" if isinstance(item, dict): @@ -151,14 +152,241 @@ def log_item( df.to_csv(f, header=write_header, index=False) +def _is_main_process() -> bool: + jax = sys.modules.get('jax') + if jax is None: + return True + return jax.process_index() == 0 + + +def _to_metadata_value(value: Any) -> Any: + """Converts array-like values to JSON-compatible metadata.""" + if value is None: + return None + if isinstance(value, np.ndarray): + if value.size == 1: + return value.item() + return [_to_metadata_value(item) for item in value.tolist()] + if isinstance(value, np.generic): + return value.item() + if isinstance(value, dict): + return {key: _to_metadata_value(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_to_metadata_value(item) for item in value] + if hasattr(value, '__array__'): + return _to_metadata_value(np.asarray(value)) + return value + + +def _sequence_item(values: Any, index: int) -> Any: + if values is None: + return None + try: + return values[index] + except (IndexError, TypeError): + return None + + +def _prompt_for_completion( + prompts: Sequence[str | list[dict[str, str]]], + completion_index: int, + num_completions: int, +) -> str | list[dict[str, str]]: + if len(prompts) == num_completions: + return prompts[completion_index] + if len(prompts) > 0 and num_completions % len(prompts) == 0: + completions_per_prompt = num_completions // len(prompts) + return prompts[completion_index // completions_per_prompt] + return prompts[completion_index % len(prompts)] + + +def _messages_for_trace( + 
prompt: str | list[dict[str, str]], completion: str +) -> list[dict[str, str]]: + if isinstance(prompt, list): + messages = [dict(message) for message in prompt] + else: + messages = [{'role': 'user', 'content': prompt}] + messages.append({'role': 'assistant', 'content': completion}) + return messages + + +@dataclasses.dataclass(slots=True) +class CsvTrajectoryLogBackend: + """Logs trajectory records as CSV files.""" + + log_dir: str + file_suffix: str + + def log_items(self, items: list[dict[str, Any] | Any]) -> None: + log_item(self.log_dir, items, self.file_suffix) + + def close(self) -> None: + pass + + +@dataclasses.dataclass(slots=True) +class TrackioTrajectoryLogBackend: + """Logs rollout and trajectory records as Trackio traces.""" + + project: str | None = None + run_name: str | None = None + trace_key: str = 'rollout/traces' + max_traces_per_step: int = 0 + init_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + + _trackio: Any = dataclasses.field(init=False, default=None) + _run: Any = dataclasses.field(init=False, default=None) + + @classmethod + def from_config(cls, config: Any) -> 'TrackioTrajectoryLogBackend': + return cls( + project=getattr(config, 'trackio_project', None), + run_name=getattr(config, 'trackio_run_name', None), + trace_key=getattr(config, 'trackio_trace_key', 'rollout/traces'), + max_traces_per_step=getattr(config, 'trackio_max_traces_per_step', 0), + init_kwargs=dict(getattr(config, 'trackio_init_kwargs', {}) or {}), + ) + + @property + def enabled(self) -> bool: + return bool(self.project) and self.max_traces_per_step > 0 + + def _ensure_run(self) -> Any | None: + if not self.enabled or not _is_main_process(): + return None + if self._run is None: + try: + import trackio # pylint: disable=g-import-not-at-top + except ImportError: + logging.warning( + 'Trackio trace logging requested, but `trackio` is not installed.' 
+ ) + self.max_traces_per_step = 0 + return None + self._trackio = trackio + self._run = trackio.init( + project=self.project, + name=self.run_name, + **self.init_kwargs, + ) + return self._run + + def log_rollouts( + self, + *, + prompts: Sequence[str | list[dict[str, str]]], + completions: Sequence[str], + rewards: Any = None, + advantages: Any = None, + mode: Any = 'train', + step: int | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """Logs rollout completions as Trackio traces.""" + if len(prompts) == 0 or len(completions) == 0: + return + run = self._ensure_run() + if run is None: + return + + traces = [] + max_traces = min(self.max_traces_per_step, len(completions)) + for sample_index in range(max_traces): + prompt = _prompt_for_completion(prompts, sample_index, len(completions)) + trace_metadata = { + 'mode': str(mode), + 'sample_index': sample_index, + } + if step is not None: + trace_metadata['step'] = int(step) + reward = _sequence_item(rewards, sample_index) + if reward is not None: + trace_metadata['reward'] = _to_metadata_value(reward) + advantage = _sequence_item(advantages, sample_index) + if advantage is not None: + trace_metadata['advantages'] = _to_metadata_value(advantage) + if metadata: + trace_metadata.update(_to_metadata_value(metadata)) + + traces.append( + self._trackio.Trace( + messages=_messages_for_trace(prompt, completions[sample_index]), + metadata=trace_metadata, + ) + ) + + if traces: + run.log({self.trace_key: traces}, step=step) + + def log_messages( + self, + *, + messages_list: Sequence[list[dict[str, str]]], + mode: Any = 'train', + step: int | None = None, + metadata_list: Sequence[dict[str, Any]] | None = None, + metadata: dict[str, Any] | None = None, + trace_key: str | None = None, + ) -> None: + """Logs fully assembled chat message traces.""" + if len(messages_list) == 0: + return + run = self._ensure_run() + if run is None: + return + + traces = [] + max_traces = min(self.max_traces_per_step, 
len(messages_list)) + for sample_index in range(max_traces): + trace_metadata = { + 'mode': str(mode), + 'sample_index': sample_index, + } + if step is not None: + trace_metadata['step'] = int(step) + if metadata_list is not None and sample_index < len(metadata_list): + trace_metadata.update(_to_metadata_value(metadata_list[sample_index])) + if metadata: + trace_metadata.update(_to_metadata_value(metadata)) + + traces.append( + self._trackio.Trace( + messages=[ + dict(message) for message in messages_list[sample_index] + ], + metadata=trace_metadata, + ) + ) + + if traces: + run.log({trace_key or self.trace_key: traces}, step=step) + + def close(self) -> None: + if self._run is not None: + self._run.finish() + self._run = None + + class AsyncTrajectoryLogger: """A logger that logs trajectories asynchronously in a background thread.""" - def __init__(self, log_dir: str): - self._log_dir = log_dir + def __init__( + self, + log_dir: str | None = None, + backends: list[Any] | None = None, + ): self._file_suffix = str(int(time.time())) self._logging_queue = queue.Queue() self._stopped = False + self._backends = list(backends or []) + if log_dir: + self._backends.insert( + 0, CsvTrajectoryLogBackend(log_dir, self._file_suffix) + ) + self._logging_thread = None + if not self._backends: + return def _worker(): while True: @@ -181,7 +409,10 @@ def _worker(): break try: - log_item(self._log_dir, items, self._file_suffix) + for backend in self._backends: + log_items = getattr(backend, 'log_items', None) + if log_items is not None: + log_items(items) except Exception: # pylint: disable=broad-except logging.exception('Failed to log trajectories.') finally: @@ -205,6 +436,21 @@ def _worker(): logging.info('Started trajectory logging thread.') + @classmethod + def from_config( + cls, log_dir: str | None = None, config: Any | None = None + ) -> 'AsyncTrajectoryLogger': + backends = [] + if config is not None: + trackio_backend = TrackioTrajectoryLogBackend.from_config(config) + if 
trackio_backend.enabled: + backends.append(trackio_backend) + return cls(log_dir=log_dir, backends=backends) + + @property + def has_backends(self) -> bool: + return bool(self._backends) + def _handle_signal(self, signum: int, frame: types.FrameType): """Gracefully stops the logger and exits.""" del frame # Unused. @@ -222,16 +468,71 @@ def stop(self): """Stops the background logging thread gracefully.""" if self._stopped: return + if self._logging_thread is None: + self._stopped = True + return logging.info('Stopping trajectory logging thread...') self._logging_queue.put(None) self._logging_queue.join() self._logging_thread.join(timeout=10) + for backend in self._backends: + close = getattr(backend, 'close', None) + if close is not None: + close() self._stopped = True logging.info('Stopped trajectory logging thread.') def log_item_async(self, item: dict[str, Any] | Any): """Adds an item to the logging queue to be logged asynchronously.""" + if not self._backends: + return if self._stopped: logging.warning('Trajectory logger already stopped.') return self._logging_queue.put(item) + + def log_rollouts( + self, + *, + prompts: Sequence[str | list[dict[str, str]]], + completions: Sequence[str], + rewards: Any = None, + advantages: Any = None, + mode: Any = 'train', + step: int | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + for backend in self._backends: + log_rollouts = getattr(backend, 'log_rollouts', None) + if log_rollouts is not None: + log_rollouts( + prompts=prompts, + completions=completions, + rewards=rewards, + advantages=advantages, + mode=mode, + step=step, + metadata=metadata, + ) + + def log_messages( + self, + *, + messages_list: Sequence[list[dict[str, str]]], + mode: Any = 'train', + step: int | None = None, + metadata_list: Sequence[dict[str, Any]] | None = None, + metadata: dict[str, Any] | None = None, + trace_key: str | None = None, + ) -> None: + for backend in self._backends: + log_messages = getattr(backend, 
'log_messages', None) + if log_messages is not None: + log_messages( + messages_list=messages_list, + mode=mode, + step=step, + metadata_list=metadata_list, + metadata=metadata, + trace_key=trace_key, + )