Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,27 @@ With the above, agentic_grpo_learner will by default start an async trajectory
logger which logs the trajectories including prompts, responses, etc. to the
specified `log_dir`.

### Trackio Trajectory Backend

RL algorithms can also log rollout prompt/completion pairs as Trackio traces
through the trajectory logger. Set `trackio_project` and a positive
`trackio_max_traces_per_step` on the algorithm config. Tunix logs standard
PPO/GRPO rollouts under `rollout/traces` and agentic rollouts under
`agentic/trajectories`.

```python
from tunix.rl.grpo import grpo_learner

algo_config = grpo_learner.GRPOConfig(
trackio_project="my-rl-project",
trackio_run_name="experiment-1",
trackio_max_traces_per_step=16,
)
```

Trackio is optional. Install it in the training environment before enabling
trace logging.

### Enabling Metrics in Jobs

Once you have your `MetricsLoggerOptions` configured, you can pass it to your
Expand Down
197 changes: 197 additions & 0 deletions tests/test_trajectory_logger_trackio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import pathlib
import sys
import types
import unittest
from unittest import mock

import numpy as np

_MODULE_PATH = (
pathlib.Path(__file__).resolve().parents[1]
/ "tunix"
/ "utils"
/ "trajectory_logger.py"
)
_SPEC = importlib.util.spec_from_file_location(
"trajectory_logger", _MODULE_PATH
)
trajectory_logger = importlib.util.module_from_spec(_SPEC)
assert _SPEC.loader is not None
sys.modules[_SPEC.name] = trajectory_logger
_SPEC.loader.exec_module(trajectory_logger)


class TrackioTrajectoryLoggerTest(unittest.TestCase):

def test_log_rollouts_creates_trackio_traces(self):
run = mock.MagicMock()
trace = mock.MagicMock(
side_effect=lambda messages, metadata: {
"messages": messages,
"metadata": metadata,
}
)
fake_trackio = types.SimpleNamespace(
init=mock.MagicMock(return_value=run),
Trace=trace,
)

with mock.patch.dict(sys.modules, {"trackio": fake_trackio}):
logger = trajectory_logger.AsyncTrajectoryLogger(
backends=[
trajectory_logger.TrackioTrajectoryLogBackend(
project="tunix-smoke",
run_name="run-1",
max_traces_per_step=2,
)
],
)
logger.log_rollouts(
prompts=["What is 2 + 2?"],
completions=["4", "Four."],
rewards=np.array([1.0, 0.8]),
advantages=np.array([[0.1, 0.2], [0.3, 0.4]]),
mode="train",
step=7,
)
logger.stop()

fake_trackio.init.assert_called_once_with(
project="tunix-smoke", name="run-1"
)
self.assertEqual(trace.call_count, 2)
self.assertEqual(
trace.call_args_list[0].kwargs["messages"],
[
{"role": "user", "content": "What is 2 + 2?"},
{"role": "assistant", "content": "4"},
],
)
self.assertEqual(trace.call_args_list[0].kwargs["metadata"]["step"], 7)
self.assertEqual(
trace.call_args_list[1].kwargs["metadata"]["advantages"], [0.3, 0.4]
)
run.log.assert_called_once_with(
{"rollout/traces": [mock.ANY, mock.ANY]}, step=7
)

def test_from_config_adds_trackio_backend(self):
config = types.SimpleNamespace(
trackio_project="tunix-smoke",
trackio_run_name="run-1",
trackio_trace_key="rollout/traces",
trackio_max_traces_per_step=2,
trackio_init_kwargs={},
)

logger = trajectory_logger.AsyncTrajectoryLogger.from_config(config=config)

self.assertTrue(logger.has_backends)
logger.stop()

def test_log_messages_uses_existing_chat_messages(self):
run = mock.MagicMock()
trace = mock.MagicMock(return_value="trace")
fake_trackio = types.SimpleNamespace(
init=mock.MagicMock(return_value=run),
Trace=trace,
)
messages = [
{"role": "system", "content": "Be concise."},
{"role": "user", "content": "Name a color."},
{"role": "assistant", "content": "Blue."},
]

with mock.patch.dict(sys.modules, {"trackio": fake_trackio}):
logger = trajectory_logger.AsyncTrajectoryLogger(
backends=[
trajectory_logger.TrackioTrajectoryLogBackend(
project="tunix-smoke",
max_traces_per_step=1,
)
],
)
logger.log_messages(
messages_list=[messages],
metadata_list=[{"trajectory_reward": 0.9}],
trace_key="agentic/trajectories",
step=3,
)
logger.stop()

trace.assert_called_once_with(
messages=messages,
metadata={
"mode": "train",
"sample_index": 0,
"step": 3,
"trajectory_reward": 0.9,
},
)
run.log.assert_called_once_with({"agentic/trajectories": ["trace"]}, step=3)

def test_direct_trackio_backend_still_logs_rollouts(self):
run = mock.MagicMock()
trace = mock.MagicMock(
side_effect=lambda messages, metadata: {
"messages": messages,
"metadata": metadata,
}
)
fake_trackio = types.SimpleNamespace(
init=mock.MagicMock(return_value=run),
Trace=trace,
)

with mock.patch.dict(sys.modules, {"trackio": fake_trackio}):
logger = trajectory_logger.TrackioTrajectoryLogBackend(
project="tunix-smoke",
run_name="run-1",
max_traces_per_step=2,
)
logger.log_rollouts(
prompts=["What is 2 + 2?"],
completions=["4", "Four."],
rewards=np.array([1.0, 0.8]),
advantages=np.array([[0.1, 0.2], [0.3, 0.4]]),
mode="train",
step=7,
)

fake_trackio.init.assert_called_once_with(
project="tunix-smoke", name="run-1"
)
self.assertEqual(trace.call_count, 2)
self.assertEqual(
trace.call_args_list[0].kwargs["messages"],
[
{"role": "user", "content": "What is 2 + 2?"},
{"role": "assistant", "content": "4"},
],
)
self.assertEqual(trace.call_args_list[0].kwargs["metadata"]["step"], 7)
self.assertEqual(
trace.call_args_list[1].kwargs["metadata"]["advantages"], [0.3, 0.4]
)
run.log.assert_called_once_with(
{"rollout/traces": [mock.ANY, mock.ANY]}, step=7
)


if __name__ == "__main__":
unittest.main()
36 changes: 31 additions & 5 deletions tunix/rl/agentic/agentic_grpo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,13 @@ def __init__(
metrics_logger_options.log_dir if metrics_logger_options else None
)

if metrics_log_dir:
self._trajectory_logger = trajectory_logger.AsyncTrajectoryLogger(
metrics_log_dir
)
else:
self._trajectory_logger = (
trajectory_logger.AsyncTrajectoryLogger.from_config(
log_dir=metrics_log_dir,
config=algo_config,
)
)
if not self._trajectory_logger.has_backends:
logging.warning("Metrics log dir is None, skipping trajectory logging.")

self.algo_config.temperature = self.rl_cluster.get_rollout_config(
Expand Down Expand Up @@ -334,6 +336,30 @@ def _process_results(
if self._trajectory_logger and trajectories_to_log:
for traj in trajectories_to_log:
self._trajectory_logger.log_item_async(traj)
messages_list = []
metadata_list = []
for trajectory_index, traj in enumerate(trajectories_to_log):
messages = traj.get("conversation_text") or []
if not messages:
continue
messages_list.append(messages)
metadata_list.append({
"trajectory_index": trajectory_index,
"trajectory_reward": traj.get("trajectory_reward"),
"status": traj.get("status"),
"policy_version": traj.get("policy_version"),
})
self._trajectory_logger.log_messages(
messages_list=messages_list,
mode=mode,
step=self.rl_cluster.global_steps,
metadata_list=metadata_list,
metadata={
"algorithm": self.algo_config.algo_variant,
"num_generations": self.algo_config.num_generations,
},
trace_key="agentic/trajectories",
)

# Pad all prompts and completions to consistent lengths.
rollout_config = self.rl_cluster.cluster_config.rollout_config
Expand Down
15 changes: 15 additions & 0 deletions tunix/rl/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# limitations under the License.

import dataclasses
from typing import Any

from absl import logging


@dataclasses.dataclass(slots=True, kw_only=True)
class AlgorithmConfig:
"""Configuration for RL algorithms.
Expand All @@ -23,12 +26,24 @@ class AlgorithmConfig:
algo_variant: The core algorithm variant to use.
advantage_estimator: The advantage estimator to use.
policy_loss_fn: The policy loss function to use.
trackio_project: Trackio project name for rollout trace logging. If unset,
Trackio trace logging is disabled.
trackio_run_name: Optional Trackio run name for rollout trace logging.
trackio_trace_key: Metric key used for Trackio trace records.
trackio_max_traces_per_step: Maximum rollout traces to log per step. Set to
0 to disable trace logging.
trackio_init_kwargs: Extra keyword arguments forwarded to `trackio.init`.
"""

algo_variant: str = "grpo"
advantage_estimator: str = "grpo"
policy_loss_fn: str = "grpo"
reward_manager: str = "sequence-level"
trackio_project: str | None = None
trackio_run_name: str | None = None
trackio_trace_key: str = "rollout/traces"
trackio_max_traces_per_step: int = 0
trackio_init_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
valid_algo_variants = [
Expand Down
8 changes: 8 additions & 0 deletions tunix/rl/grpo/grpo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,14 @@ def _generate_and_compute_advantage(
)
self.rl_cluster.buffer_metrics(user_defined_metric, mode=mode)

self._log_trackio_rollout_traces(
prompts=training_input["prompts"],
completions=rollout_output.text,
rewards=rewards,
advantages=advantages,
mode=mode,
)

return TrainExample(
prompt_ids=prompt_ids,
prompt_mask=prompt_mask,
Expand Down
8 changes: 8 additions & 0 deletions tunix/rl/ppo/ppo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,14 @@ def _generate_and_compute_advantage(
)
self.rl_cluster.buffer_metrics(user_defined_metric, mode=mode)

self._log_trackio_rollout_traces(
prompts=training_input["prompts"],
completions=rollout_output.text,
rewards=last_token_scores,
advantages=advantages,
mode=mode,
)

return TrainExample(
prompt_ids=prompt_ids,
prompt_mask=prompt_mask,
Expand Down
27 changes: 26 additions & 1 deletion tunix/rl/rl_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from tunix.rl import utils as rl_utils
from tunix.rl.queue import data_queue as queue_lib
from tunix.sft import utils as sft_utils

from tunix.utils import trajectory_logger

ABC = abc.ABC
abstractmethod = abc.abstractmethod
Expand Down Expand Up @@ -130,6 +130,9 @@ def __init__(
)
self.executor = futures.ThreadPoolExecutor(max_workers=1)
self._last_iter_step = self.rl_cluster.actor_trainer.iter_steps
self._trajectory_logger = (
trajectory_logger.AsyncTrajectoryLogger.from_config(config=algo_config)
)

self._rollout_micro_batch_size = (
self._training_config.rollout_micro_batch_size
Expand All @@ -139,6 +142,28 @@ def __init__(
)
sft_utils.show_hbm_usage(title="RLLearner init")

def _log_trackio_rollout_traces(
self,
*,
prompts: Sequence[str | list[dict[str, str]]],
completions: Sequence[str],
rewards: Any = None,
advantages: Any = None,
mode: rl_cluster_lib.Mode = rl_cluster_lib.Mode.TRAIN,
) -> None:
self._trajectory_logger.log_rollouts(
prompts=prompts,
completions=completions,
rewards=rewards,
advantages=advantages,
mode=mode,
step=self.rl_cluster.global_steps,
metadata={
"algorithm": self.algo_config.algo_variant,
"num_generations": self._num_generations(),
},
)

@abstractmethod
def _generate_and_compute_advantage(
self,
Expand Down
Loading
Loading