diff --git a/examples/agentic/qwen3_grpo_gsm8k_vtc_demo.py.py b/examples/agentic/qwen3_grpo_gsm8k_vtc_demo.py.py
new file mode 100644
index 000000000..a1c242275
--- /dev/null
+++ b/examples/agentic/qwen3_grpo_gsm8k_vtc_demo.py.py
@@ -0,0 +1,778 @@
+"""Agentic GSM8K VTC GRPO recipe for Qwen3-1.7B.
+
+This script keeps the current VTC recipe content intact while following the
+same pipeline layout as the FrozenLake Qwen3 recipe:
+
+1. logging / runtime setup
+2. argparse + recipe defaults
+3. shared mesh construction
+4. dataset loading
+5. tokenizer / model loading
+6. checkpoint + metrics + optimizer
+7. rollout + RL cluster
+8. GRPO trainer
+9. training
+"""
+
+from __future__ import annotations
+
+import argparse
+import gc
+import logging
+import math
+import os
+import re
+import sys
+import time
+from typing import Any
+
+from absl import logging as absl_logging
+
+# Disable pathways subslice check by appending it to sys.argv before JAX/absl
+# parse it.
+if "--pathways_enforce_subset_devices_form_subslice=false" not in sys.argv:
+ sys.argv.append("--pathways_enforce_subset_devices_form_subslice=false")
+
+os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
+os.environ["VLLM_TPU_RPA_VERSION"] = "2"
+os.environ["DISABLE_MOSAIC_ATTN"] = "1"
+
+# ====== Logging Configuration ======
+absl_logging.use_python_logging()
+logging.basicConfig(
+ stream=sys.stdout,
+ level=logging.INFO,
+ format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ force=True,
+)
+logging.getLogger().setLevel(logging.INFO)
+logging.getLogger("absl").setLevel(logging.INFO)
+absl_logging.set_verbosity(absl_logging.INFO)
+absl_logging.set_stderrthreshold("info")
+print("Logging configured at INFO level.")
+
+import grain
+from flax import nnx
+import jax
+from jax import numpy as jnp
+from jax.sharding import Mesh
+import numpy as np
+import optax
+from orbax import checkpoint as ocp
+import tensorflow_datasets as tfds
+# For OSS usage
+import tensorflow_datasets.text.gsm8k
+from transformers import AutoTokenizer
+
+REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+WORKDIR = os.getcwd()
+if os.path.exists(os.path.join(WORKDIR, "tunix")):
+ WORKSPACE_ROOT = WORKDIR
+else:
+ WORKSPACE_ROOT = os.path.dirname(REPO_ROOT)
+
+for root in [
+ REPO_ROOT,
+ WORKSPACE_ROOT,
+ os.path.join(WORKSPACE_ROOT, "tunix"),
+ os.path.join(WORKSPACE_ROOT, "pathways-utils"),
+ os.path.join(WORKSPACE_ROOT, "r2egym"),
+]:
+ if root not in sys.path:
+ sys.path.insert(0, root)
+
+_DISTRIBUTED_INITIALIZED = False
+try:
+ import tunix # pytype: disable=import-error # noqa: F401
+except Exception:
+ pass
+
+try:
+ import r2egym # pytype: disable=import-error # noqa: F401
+except Exception:
+ pass
+
+try:
+ import pathwaysutils # pytype: disable=import-error
+
+ pathwaysutils.initialize()
+ _DISTRIBUTED_INITIALIZED = True
+except Exception:
+ pass
+
+if not _DISTRIBUTED_INITIALIZED:
+ try:
+ jax.distributed.initialize()
+ except Exception as exc:
+ print(f"jax.distributed.initialize() skipped: {exc}")
+
+print("jax devices: ", jax.devices())
+
+from tunix.cli.utils import model as model_utils
+from tunix.models.qwen3 import model as qwen3_model_lib
+from tunix.models.qwen3 import params as qwen3_params_lib
+from tunix.oss import utils as oss_utils
+from tunix.rl import rl_cluster as rl_cluster_lib
+from tunix.rl import utils as rl_utils
+from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
+from tunix.rl.agentic.parser.chat_template_parser import parser as chat_parser_lib
+from tunix.rl.rollout import base_rollout
+from tunix.sft import metrics_logger
+from tunix.sft import utils as sft_utils
+
+# ====== Argparse ======
+arg_parser = argparse.ArgumentParser(
+ description="Train Qwen3-1.7B on GSM8K with the VTC GRPO recipe."
+)
+arg_parser.add_argument("--batch_size", type=int, default=4)
+arg_parser.add_argument("--mini_batch_size", type=int, default=2)
+arg_parser.add_argument("--train_micro_batch_size", type=int, default=1)
+arg_parser.add_argument("--compute_logps_micro_batch_size", type=int, default=1)
+arg_parser.add_argument("--max_steps", type=int, default=200)
+arg_parser.add_argument("--max_response_length", type=int, default=1024)
+arg_parser.add_argument("--max_concurrency", type=int, default=None)
+arg_parser.add_argument("--mesh_fsdp", type=int, default=None)
+arg_parser.add_argument("--mesh_tp", type=int, default=None)
+arg_parser.add_argument(
+ "--rollout_vllm_hbm_utilization", type=float, default=0.6
+)
+arg_parser.add_argument("--rollout_vllm_max_num_seqs", type=int, default=None)
+arg_parser.add_argument(
+ "--rollout_vllm_max_num_batched_tokens", type=int, default=None
+)
+args, _ = arg_parser.parse_known_args()
+
+
+# ====== Recipe Defaults ======
+MODEL_NAME = "Qwen3-1.7B"
+MODEL_ID = f"Qwen/{MODEL_NAME}"
+SEED = 42
+
+NUM_PROMPTS_PER_STEP = args.batch_size
+NUM_GENERATIONS = 8
+MINI_BATCH_SIZE = args.mini_batch_size
+TRAIN_MICRO_BATCH_SIZE = args.train_micro_batch_size
+COMPUTE_LOGPS_MICRO_BATCH_SIZE = args.compute_logps_micro_batch_size
+
+MAX_STEPS = args.max_steps
+NUM_EPOCHS = 1000
+EVAL_EVERY_N_STEPS = 50
+EVAL_BATCH_SIZE = 128
+EVAL_AT_START = True
+EVAL_AT_END = True
+
+BETA = 0.04
+EPSILON = 0.2
+# NeMo's reference_policy_kl_type="k2" is exactly 0.5 * (logp-ref_logp)^2,
+# which matches Tunix's "mse_kl" implementation.
+KL_LOSS_MODE = "mse_kl"
+LEARNING_RATE = 2.0e-7
+WEIGHT_DECAY = 0.01
+ADAM_B1 = 0.9
+ADAM_B2 = 0.999
+ADAM_EPS = 1.0e-8
+MAX_GRAD_NORM = 1.0
+WARMUP_STEPS = 50
+LR_DECAY_STEPS = 500
+
+MAX_PROMPT_LENGTH = 1024
+MAX_RESPONSE_LENGTH = args.max_response_length
+MAX_TOTAL_SEQUENCE_LENGTH = 1024
+KV_CACHE_SIZE = MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256
+
+TRAIN_TEMPERATURE = 1.0
+TRAIN_TOP_P = 1.0
+TRAIN_TOP_K = None
+EVAL_TEMPERATURE = 0.0
+EVAL_TOP_P = 1.0
+EVAL_TOP_K = 1
+MAX_CONCURRENCY = args.max_concurrency or (
+ NUM_PROMPTS_PER_STEP * NUM_GENERATIONS
+)
+
+ROLLOUT_ENGINE = os.getenv("ROLLOUT_ENGINE", "vllm")
+USE_LORA = False
+LORA_RANK = 64
+LORA_ALPHA = 64.0
+ENABLE_CHECKPOINTING = False
+ENABLE_REMAT = False
+ENABLE_FLASH_ATTENTION = True
+MODEL_DTYPE = jnp.bfloat16
+
+ARTIFACT_ROOT = os.path.join(REPO_ROOT, "artifacts", "qwen3_grpo_gsm8k_vtc")
+TFDS_DATA_DIR = os.path.join(ARTIFACT_ROOT, "data")
+MODEL_DOWNLOAD_DIR = os.path.join(ARTIFACT_ROOT, "models")
+INTERMEDIATE_CKPT_DIR = os.path.join(ARTIFACT_ROOT, "intermediate_ckpt")
+CHECKPOINT_ROOT = os.path.join(
+ ARTIFACT_ROOT, "checkpoints", str(int(time.time()))
+)
+TB_LOG_DIR = os.path.join(ARTIFACT_ROOT, "logs")
+
+for path in [
+ TFDS_DATA_DIR,
+ MODEL_DOWNLOAD_DIR,
+ INTERMEDIATE_CKPT_DIR,
+ TB_LOG_DIR,
+]:
+ os.makedirs(path, exist_ok=True)
+if ENABLE_CHECKPOINTING:
+ os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
+
+show_hbm_usage = sft_utils.show_hbm_usage
+
+VTC_PROMPT_TEMPLATE = """Solve the following math problem.
+First, put your detailed step-by-step reasoning process inside ... tags.
+Then, put your final numerical answer inside \\boxed{{}} tags. Do not put anything else in the answer tags.
+
+Problem: {}
+
+"""
+
+_metric_call_idx = 0
+
+
+# ====== Shared Mesh ======
+MESH_FSDP = args.mesh_fsdp or 1
+MESH_TP = args.mesh_tp or (jax.device_count() // MESH_FSDP)
+SHARED_MESH_SHAPE = (MESH_FSDP, MESH_TP)
+SHARED_MESH_AXIS_NAMES = ("fsdp", "tp")
+
+if math.prod(SHARED_MESH_SHAPE) != jax.device_count():
+ raise ValueError(
+ "Shared mesh dimensions must multiply to device_count. "
+ f"Got mesh={SHARED_MESH_SHAPE}, devices={jax.device_count()}."
+ )
+
+shared_device_list = jax._src.mesh_utils.create_device_mesh(
+ SHARED_MESH_SHAPE, jax.devices()[: math.prod(SHARED_MESH_SHAPE)]
+)
+shared_mesh = jax.sharding.Mesh(
+ shared_device_list,
+ axis_names=SHARED_MESH_AXIS_NAMES,
+ axis_types=(jax.sharding.AxisType.Auto,) * len(SHARED_MESH_SHAPE),
+)
+print(f"shared_mesh.devices.shape={shared_mesh.devices.shape}")
+
+
+# ====== Data ======
+def _as_text(value: Any) -> str:
+ return value if isinstance(value, str) else value.decode("utf-8")
+
+
+def extract_hash_answer(text: str) -> str | None:
+ if "####" not in text:
+ return None
+ return text.split("####", 1)[1].strip()
+
+
+def build_prompt(question: str) -> str:
+ return VTC_PROMPT_TEMPLATE.format(question)
+
+
+def build_gsm8k_dataset(
+ *,
+ split: str,
+ seed: int,
+ batch_size: int,
+ data_dir: str,
+ shuffle: bool,
+) -> grain.MapDataset:
+ data = tfds.data_source(
+ "gsm8k",
+ split=split,
+ data_dir=data_dir,
+ builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
+ download=True,
+ )
+
+ dataset = grain.MapDataset.source(data)
+ if shuffle:
+ dataset = dataset.shuffle(seed=seed)
+
+ dataset = dataset.map(
+ lambda x: {
+ "prompts": build_prompt(_as_text(x["question"])),
+ "question": _as_text(x["question"]),
+ "answer": extract_hash_answer(_as_text(x["answer"])),
+ }
+ )
+ return dataset.batch(batch_size)
+
+
+def create_datasets() -> tuple[grain.MapDataset, grain.MapDataset]:
+ train_dataset = build_gsm8k_dataset(
+ split="train",
+ seed=SEED,
+ batch_size=NUM_PROMPTS_PER_STEP,
+ data_dir=TFDS_DATA_DIR,
+ shuffle=True,
+ ).repeat(NUM_EPOCHS)
+ eval_dataset = build_gsm8k_dataset(
+ split="test",
+ seed=SEED,
+ batch_size=EVAL_BATCH_SIZE,
+ data_dir=TFDS_DATA_DIR,
+ shuffle=False,
+ )
+ return train_dataset, eval_dataset
+
+
+def _normalize_example_value(value: Any) -> Any:
+ if isinstance(value, np.ndarray):
+ flat = value.reshape(-1).tolist()
+ if len(flat) == 1:
+ return _normalize_example_value(flat[0])
+ return [_normalize_example_value(v) for v in flat]
+ if isinstance(value, np.bytes_):
+ return value.tobytes().decode("utf-8")
+ if isinstance(value, bytes):
+ return value.decode("utf-8")
+ return value
+
+
+def normalize_single_example(example: dict[str, Any]) -> dict[str, Any]:
+ return {
+ key: _normalize_example_value(value) for key, value in example.items()
+ }
+
+
+# ====== Reward + Metrics ======
+def extract_boxed_answer(text: str) -> str | None:
+ answer_blocks = re.findall(r"(.*?)", text, re.DOTALL)
+ content = answer_blocks[-1] if answer_blocks else text
+
+ boxed = []
+ stack = []
+ for i, ch in enumerate(content):
+ if ch == "{":
+ stack.append(i)
+ elif ch == "}":
+ if not stack:
+ continue
+ open_idx = stack.pop()
+ if content[:open_idx].endswith(r"\boxed"):
+ boxed.append(content[open_idx + 1 : i].strip())
+ if boxed:
+ return boxed[-1]
+
+ fallback = re.search(r"\\boxed\s*\{?\s*([a-zA-Z0-9\.,\-]+)\s*\}?", content)
+ if fallback:
+ return fallback.group(1).strip()
+ return None
+
+
+def is_vtc_format_correct(text: str) -> bool:
+ has_reasoning = text.count("") == 1
+ has_answer = text.count("") == 1 and text.count("") == 1
+ reasoning_end = text.find("")
+ answer_open = text.find("")
+ answer_close = text.find("")
+ return (
+ has_reasoning
+ and has_answer
+ and reasoning_end != -1
+ and answer_open != -1
+ and answer_close != -1
+ and reasoning_end < answer_open < answer_close
+ )
+
+
+def normalize_answer(text: str | None) -> str | None:
+ if text is None:
+ return None
+ return str(text).replace(",", "").strip()
+
+
+def _vtc_completion_outcome(
+ completion: str, gold: Any
+) -> tuple[float, bool, bool, bool]:
+ format_ok = is_vtc_format_correct(completion)
+ pred = normalize_answer(extract_boxed_answer(completion))
+ true = normalize_answer(_normalize_example_value(gold))
+ answer_ok = pred is not None and true is not None and pred == true
+ extracted_ok = pred is not None
+
+ if format_ok and answer_ok:
+ score = 1.0
+ elif format_ok and not answer_ok:
+ score = 0.1
+ elif not format_ok and answer_ok:
+ score = 0.5
+ else:
+ score = 0.0
+ return score, format_ok, answer_ok, extracted_ok
+
+
+def vtc_env_reward(task, action):
+ gold = task.get("answer")
+ completion = action.action if hasattr(action, "action") else action
+ score, _, _, _ = _vtc_completion_outcome(completion, gold)
+ return score
+
+
+def vtc_metric_fn(prompts, completions, rewards, advantages, answer, **kwargs):
+ del prompts, completions, advantages, answer, kwargs
+ global _metric_call_idx
+ _metric_call_idx += 1
+
+ rewards = np.asarray(rewards, dtype=np.float32)
+ solve_all = bool(np.all(rewards > 0.1))
+ solve_none = bool(np.all(np.isclose(rewards, 0.0)))
+ solve_partial = (not solve_all) and (not solve_none)
+ solve_ratio = float(np.mean(rewards > 0.1))
+ reward_mean = float(rewards.mean())
+ reward_max = float(rewards.max())
+
+ absl_logging.info(
+ "[rollout-metric] call=%d n=%d solve_ratio=%.3f reward_mean=%.3f"
+ " reward_max=%.3f solve_all=%d solve_none=%d",
+ _metric_call_idx,
+ len(rewards),
+ solve_ratio,
+ reward_mean,
+ reward_max,
+ int(solve_all),
+ int(solve_none),
+ )
+ return {
+ "rewards/solve_all": (1 if solve_all else 0, np.mean),
+ "rewards/solve_none": (1 if solve_none else 0, np.mean),
+ "rewards/solve_partial": (1 if solve_partial else 0, np.mean),
+ "rewards/solve_ratio": (solve_ratio, np.mean),
+ }
+
+
+# ====== Tokenizer / Model ======
+class VTCRawTextParser:
+ """Raw-text prompt parser matching NeMo's vtc_raw_text_processor style."""
+
+ def parse(
+ self,
+ messages,
+ add_generation_prompt: bool = False,
+ is_first_msg: bool = False,
+ ) -> str:
+ del add_generation_prompt, is_first_msg
+ parts = []
+ for message in messages:
+ role = message.get("role")
+ content = message.get("content", "")
+ if role == "system" and content:
+ parts.append(content)
+ elif role == "user":
+ parts.append(content)
+ elif role == "assistant" and content:
+ parts.append(content)
+ return "\n".join(parts)
+
+
+class VTCGRPOLearner(GRPOLearner):
+ """Demo-local learner that normalizes TFDS string payloads to Python str."""
+
+ def _create_agent_env_pair(
+ self, single_example, group_id: int, pair_index: int
+ ):
+ normalized_example = normalize_single_example(single_example)
+ return super()._create_agent_env_pair(
+ normalized_example, group_id=group_id, pair_index=pair_index
+ )
+
+
+def ensure_model_downloaded() -> None:
+ if not os.path.isdir(MODEL_DOWNLOAD_DIR) or not any(
+ filename.endswith(".safetensors")
+ for filename in os.listdir(MODEL_DOWNLOAD_DIR)
+ ):
+ os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
+ oss_utils.hf_pipeline(MODEL_ID, MODEL_DOWNLOAD_DIR)
+
+
+def maybe_apply_lora(model: nnx.Module, mesh: Mesh) -> nnx.Module:
+ lora_config = {
+ "module_path": (
+ ".*q_proj|.*k_proj|.*v_proj|.*o_proj|"
+ ".*gate_proj|.*down_proj|.*up_proj"
+ ),
+ "rank": LORA_RANK,
+ "alpha": LORA_ALPHA,
+ }
+ return model_utils.apply_lora_to_model(
+ model, mesh=mesh, lora_config=lora_config
+ )
+
+
+def put_model_on_device(model: nnx.Module) -> nnx.Module:
+ graph_def, state = nnx.split(model)
+ state = rl_utils.put_params_on_memory_kind(state, "device")
+ return nnx.merge(graph_def, state)
+
+
+def create_reference_and_actor(mesh: Mesh) -> tuple[nnx.Module, nnx.Module]:
+ ensure_model_downloaded()
+
+ config = qwen3_model_lib.ModelConfig.qwen3_1p7b()
+ if ENABLE_REMAT:
+ config.remat_config = qwen3_model_lib.RematConfig.DECODER
+ else:
+ config.remat_config = qwen3_model_lib.RematConfig.NONE
+ if ENABLE_FLASH_ATTENTION:
+ config.use_flash_attention = True
+ config.flash_attention_block_size = 256
+ config.dtype = jnp.bfloat16
+ config.param_dtype = jnp.float32
+
+ reference = qwen3_params_lib.create_model_from_safe_tensors(
+ MODEL_DOWNLOAD_DIR, config, mesh, dtype=MODEL_DTYPE
+ )
+ actor_base = qwen3_params_lib.create_model_from_safe_tensors(
+ MODEL_DOWNLOAD_DIR, config, mesh, dtype=jnp.float32
+ )
+
+ reference = put_model_on_device(reference)
+ actor = maybe_apply_lora(actor_base, mesh) if USE_LORA else actor_base
+ actor = put_model_on_device(actor)
+ return reference, actor
+
+
+# ====== Checkpoint + Metrics + Optimizer ======
+if ENABLE_CHECKPOINTING:
+ checkpointing_options = ocp.CheckpointManagerOptions(
+ save_interval_steps=MAX_STEPS,
+ max_to_keep=1,
+ )
+else:
+ checkpointing_options = None
+
+wandb_config = vars(args).copy()
+wandb_config.update({
+ "model_id": MODEL_ID,
+ "mesh_shape": SHARED_MESH_SHAPE,
+ "num_steps": MAX_STEPS,
+ "num_generations": NUM_GENERATIONS,
+ "kl_loss_mode": KL_LOSS_MODE,
+ "train_temperature": TRAIN_TEMPERATURE,
+})
+metrics_logging_options = metrics_logger.MetricsLoggerOptions(
+ log_dir=TB_LOG_DIR,
+ project_name="tunix-gsm8k-vtc",
+ flush_every_n_steps=1,
+ backend_kwargs={"wandb": {"config": wandb_config}},
+)
+
+
+def create_optimizer() -> optax.GradientTransformation:
+ optimizer = optax.adamw(
+ learning_rate=optax.warmup_cosine_decay_schedule(
+ init_value=0.0,
+ peak_value=LEARNING_RATE,
+ warmup_steps=WARMUP_STEPS,
+ decay_steps=LR_DECAY_STEPS,
+ end_value=0.0,
+ ),
+ b1=ADAM_B1,
+ b2=ADAM_B2,
+ eps=ADAM_EPS,
+ weight_decay=WEIGHT_DECAY,
+ )
+ return optax.chain(optax.clip_by_global_norm(MAX_GRAD_NORM), optimizer)
+
+
+def _shutdown_rollout_runtime(rl_cluster) -> None:
+ rollout = getattr(rl_cluster, "rollout", None)
+ if rollout is not None:
+ for method_name in ("close", "stop", "shutdown"):
+ method = getattr(rollout, method_name, None)
+ if callable(method):
+ try:
+ method()
+ except Exception:
+ absl_logging.exception(
+ "Failed to %s rollout runtime during demo teardown.",
+ method_name,
+ )
+ break
+ gc.collect()
+ try:
+ jax.clear_caches()
+ except Exception:
+ absl_logging.exception("Failed to clear JAX caches during demo teardown.")
+
+
+def main() -> None:
+ # ====== Data ======
+ train_dataset, eval_dataset = create_datasets()
+ show_hbm_usage("Done with loading datasets")
+
+ # ====== Tokenizer / Model ======
+ tokenizer = AutoTokenizer.from_pretrained(
+ MODEL_ID,
+ token=os.getenv("HF_TOKEN"),
+ trust_remote_code=True,
+ )
+ chat_parser = VTCRawTextParser()
+ qwen_eos_tokens = tokenizer.encode("<|im_end|>", add_special_tokens=False)
+
+ reference, actor = create_reference_and_actor(shared_mesh)
+ show_hbm_usage("after loading qwen_ref / qwen_actor")
+
+ # ====== Rollout + RL cluster ======
+ base_rollout_dict = {
+ "max_prompt_length": MAX_PROMPT_LENGTH,
+ "kv_cache_size": KV_CACHE_SIZE,
+ "max_tokens_to_generate": MAX_RESPONSE_LENGTH,
+ "eos_tokens": qwen_eos_tokens,
+ "return_logprobs": True,
+ }
+ train_rollout_dict = {
+ "temperature": TRAIN_TEMPERATURE,
+ "top_p": TRAIN_TOP_P,
+ "top_k": TRAIN_TOP_K,
+ }
+ eval_rollout_dict = {
+ "temperature": EVAL_TEMPERATURE,
+ "top_p": EVAL_TOP_P,
+ "top_k": EVAL_TOP_K,
+ }
+
+ vllm_max_num_seqs = (
+ args.rollout_vllm_max_num_seqs
+ if args.rollout_vllm_max_num_seqs is not None
+ else NUM_PROMPTS_PER_STEP * NUM_GENERATIONS
+ )
+ vllm_max_batched_tokens = (
+ args.rollout_vllm_max_num_batched_tokens
+ if args.rollout_vllm_max_num_batched_tokens is not None
+ else (vllm_max_num_seqs * KV_CACHE_SIZE) // 8
+ )
+ vllm_rollout_dict = {
+ "rollout_vllm_model_version": MODEL_ID,
+ "rollout_vllm_hbm_utilization": args.rollout_vllm_hbm_utilization,
+ "rollout_vllm_server_mode": True,
+ "rollout_vllm_async_scheduling": False,
+ "tensor_parallel_size": SHARED_MESH_SHAPE[1],
+ "data_parallel_size": SHARED_MESH_SHAPE[0],
+ "rollout_vllm_max_num_seqs": vllm_max_num_seqs,
+ "rollout_vllm_max_num_batched_tokens": vllm_max_batched_tokens,
+ "rollout_vllm_kwargs": {
+ "kv_cache_metrics": True,
+ "disable_log_stats": False,
+ "enable_prefix_caching": False,
+ "dtype": "bfloat16",
+ },
+ }
+ if jax.default_backend() == "tpu":
+ vllm_rollout_dict["rollout_vllm_tpu_backend_type"] = "jax"
+
+ if ROLLOUT_ENGINE == "vllm":
+ train_rollout_config = base_rollout.RolloutConfig(
+ **base_rollout_dict, **train_rollout_dict, **vllm_rollout_dict
+ )
+ eval_rollout_config = base_rollout.RolloutConfig(
+ **base_rollout_dict, **eval_rollout_dict, **vllm_rollout_dict
+ )
+ elif ROLLOUT_ENGINE == "vanilla":
+ train_rollout_config = base_rollout.RolloutConfig(
+ **base_rollout_dict, **train_rollout_dict
+ )
+ eval_rollout_config = base_rollout.RolloutConfig(
+ **base_rollout_dict, **eval_rollout_dict
+ )
+ else:
+ raise ValueError(f"Unsupported rollout engine: {ROLLOUT_ENGINE}")
+
+ cluster_config = rl_cluster_lib.ClusterConfig(
+ role_to_mesh={
+ rl_cluster_lib.Role.ACTOR: shared_mesh,
+ rl_cluster_lib.Role.REFERENCE: shared_mesh,
+ rl_cluster_lib.Role.ROLLOUT: shared_mesh,
+ },
+ rollout_engine=ROLLOUT_ENGINE,
+ offload_to_cpu=False,
+ training_config=rl_cluster_lib.RLTrainingConfig(
+ actor_optimizer=create_optimizer(),
+ eval_every_n_steps=EVAL_EVERY_N_STEPS,
+ max_steps=MAX_STEPS,
+ max_inflight_computations=1,
+ mini_batch_size=MINI_BATCH_SIZE,
+ train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
+ compute_logps_micro_batch_size=COMPUTE_LOGPS_MICRO_BATCH_SIZE,
+ metrics_logging_options=metrics_logging_options,
+ checkpoint_root_directory=(
+ CHECKPOINT_ROOT if ENABLE_CHECKPOINTING else None
+ ),
+ checkpointing_options=checkpointing_options,
+ ),
+ rollout_config={
+ rl_cluster_lib.Mode.TRAIN: train_rollout_config,
+ rl_cluster_lib.Mode.EVAL: eval_rollout_config,
+ },
+ )
+
+ grpo_config = GRPOConfig(
+ num_generations=NUM_GENERATIONS,
+ num_iterations=1,
+ beta=BETA,
+ kl_loss_mode=KL_LOSS_MODE,
+ epsilon=EPSILON,
+ epsilon_high=EPSILON,
+ advantage_estimator="grpo",
+ degenerate_group_masking=False,
+ use_rollout_logps=False,
+ system_prompt="",
+ max_response_length=MAX_RESPONSE_LENGTH,
+ max_concurrency=MAX_CONCURRENCY,
+ loss_agg_mode="sequence-mean-token-mean",
+ )
+
+ rl_cluster = rl_cluster_lib.RLCluster(
+ actor=actor,
+ reference=reference,
+ tokenizer=tokenizer,
+ cluster_config=cluster_config,
+ )
+ show_hbm_usage("after RLCluster creation")
+
+ # ====== Trainer ======
+ grpo_trainer = VTCGRPOLearner(
+ rl_cluster=rl_cluster,
+ algo_config=grpo_config,
+ chat_parser=chat_parser,
+ metric_fns=[vtc_metric_fn],
+ env_kwargs={"reward_fn": vtc_env_reward},
+ )
+ show_hbm_usage("after GRPOLearner creation")
+
+ print("Shared mesh:", shared_mesh)
+ print(
+ "Config summary:",
+ {
+ "model_id": MODEL_ID,
+ "mesh_shape": SHARED_MESH_SHAPE,
+ "rollout_engine": ROLLOUT_ENGINE,
+ "prompts_per_step": NUM_PROMPTS_PER_STEP,
+ "num_generations": NUM_GENERATIONS,
+ "mini_batch_size": MINI_BATCH_SIZE,
+ "train_micro_batch_size": TRAIN_MICRO_BATCH_SIZE,
+ "compute_logps_micro_batch_size": COMPUTE_LOGPS_MICRO_BATCH_SIZE,
+ "max_steps": MAX_STEPS,
+ "max_response_length": MAX_RESPONSE_LENGTH,
+ "max_concurrency": MAX_CONCURRENCY,
+ "rollout_vllm_hbm_utilization": args.rollout_vllm_hbm_utilization,
+ "rollout_vllm_max_num_seqs": vllm_max_num_seqs,
+ "rollout_vllm_max_num_batched_tokens": vllm_max_batched_tokens,
+ },
+ )
+
+ # ====== Training ======
+ try:
+ grpo_trainer.train(train_dataset, eval_dataset=eval_dataset)
+ except Exception:
+ rl_cluster.close()
+ raise
+ finally:
+ _shutdown_rollout_runtime(rl_cluster)
+
+
+if __name__ == "__main__":
+ main()