diff --git a/examples/rl/grpo/gsm8k/configs/qwen3.yaml b/examples/rl/grpo/gsm8k/configs/qwen3.yaml
new file mode 100644
index 000000000..71d84854c
--- /dev/null
+++ b/examples/rl/grpo/gsm8k/configs/qwen3.yaml
@@ -0,0 +1,89 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+model_config:
+ model_name: "Qwen3-1.7B-base"
+ model_id: "Qwen/Qwen3-1.7B-base"
+ model_source: "huggingface"
+ use_flash_attention: true
+ flash_attention_block_size: 256
+ mesh:
+ shape: "(2,4)"
+ axis_names: "('fsdp','tp')"
+ rng_seed: 42
+
+actor_model_config:
+ lora_config:
+ rank: 64
+ alpha: 64.0
+ module_path: ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj"
+ mesh:
+ shape: "(2,4)"
+ axis_names: "('fsdp','tp')"
+
+reference_model_config:
+ mesh: null
+ same_mesh_as: "actor"
+
+rollout_model_config:
+ mesh: null
+ same_mesh_as: "actor"
+
+tokenizer_config:
+ tokenizer_type: "huggingface"
+ add_bos: false
+
+dataset_name: "gsm8k"
+batch_size: 8
+num_test_batches: 100
+num_train_epochs: 1
+
+rl_training_config:
+ actor_optimizer_config:
+ opt_type: "adamw"
+ peak_value: 3e-6
+ schedule_type: "warmup_cosine_decay_schedule"
+ init_value: 0.0
+ end_value: 0.0
+ warmup_ratio: 0.1
+ b1: 0.9
+ b2: 0.99
+ weight_decay: 0.1
+ max_grad_norm: 0.1
+ eval_every_n_steps: 10
+ metrics_logging_options:
+ flush_every_n_steps: 20
+ checkpointing_options:
+ save_interval_steps: 500
+ max_to_keep: 4
+ profiler_options: {}
+
+rollout_config:
+ total_generation_steps: 768
+ max_prompt_length: 256
+ temperature: 0.9
+ top_p: 1.0
+ top_k: 50
+
+rollout_engine: "vanilla"
+offload_to_cpu: false
+
+grpo_config:
+ num_generations: 4
+ num_iterations: 1
+ beta: 0.08
+ epsilon: 0.2
+
+reward_functions:
+ - "tunix/cli/reward_fn/gsm8k.py"
diff --git a/examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml b/examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml
new file mode 100644
index 000000000..773790ff6
--- /dev/null
+++ b/examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml
@@ -0,0 +1,100 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+model_config:
+ rng_seed: 42
+ model_display: false
+ remat_config: 3
+
+actor_model_config:
+ mesh:
+ shape: "(8,1)"
+ axis_names: "('fsdp','tp')"
+
+rollout_model_config:
+ mesh:
+ shape: "(1,8)"
+ axis_names: "('fsdp','tp')"
+
+reference_model_config:
+ mesh: null
+ same_mesh_as: "actor"
+
+data_source: "huggingface"
+dataset_name: "openai/gsm8k:main"
+prompt_key: "question"
+
+training_mode: "agentic_grpo"
+num_test_batches: 100
+reward_functions:
+ - "tunix/cli/reward_fn/gsm8k.py"
+verl_compatible: false
+
+rollout_engine: "vllm"
+offload_to_cpu: false
+
+rollout_config:
+ max_prompt_length: 256
+ total_generation_steps: 768
+ max_tokens_to_generate: 768
+ temperature: 0.9
+ top_p: 1.0
+ top_k: 50
+ return_logprobs: true
+
+vllm_config:
+ hbm_utilization: 0.4
+ tpu_backend_type: "jax"
+ server_mode: true
+ async_scheduling: true
+ kwargs:
+ kv_cache_metrics: true
+ disable_log_stats: false
+ enable_prefix_caching: true
+
+chat_parser_config:
+ type: "qwen"
+
+tokenizer_config:
+ tokenizer_type: "huggingface"
+ add_bos: false
+ add_eos: false
+
+agentic_grpo_config:
+ num_iterations: 1
+ beta: 0.08
+ epsilon: 0.2
+ system_prompt: "You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ."
+ max_concurrency: 128
+ max_response_length: 768
+ max_turns: 1
+
+rl_training_config:
+ actor_optimizer_config:
+ opt_type: "adamw"
+ learning_rate: 3e-6
+ schedule_type: "warmup_cosine_decay_schedule"
+ init_value: 0.0
+ peak_value: 3e-6
+ end_value: 0.0
+ b1: 0.9
+ b2: 0.99
+ weight_decay: 0.1
+ max_grad_norm: 0.1
+ eval_every_n_steps: 10
+ checkpointing_options:
+ save_interval_steps: 250
+ max_to_keep: 4
+ metrics_logging_options:
+ flush_every_n_steps: 20
diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3.sh
index 05fb5077e..befad2749 100755
--- a/examples/rl/grpo/gsm8k/run_qwen3.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3.sh
@@ -31,60 +31,16 @@ echo " Train Fraction: $train_fraction"
echo " Checkpoint Directory: $checkpoint_dir"
python3 -m tunix.cli.grpo_main \
- base_config.yaml \
- model_config.model_name=${model_name} \
- model_config.model_id=Qwen/${model_name} \
- model_config.model_source=huggingface \
- model_config.use_flash_attention=true \
- model_config.flash_attention_block_size=256 \
+ tunix/cli/base_config.yaml \
+ override_config_file=examples/rl/grpo/gsm8k/configs/qwen3.yaml \
+ model_config.model_name="${model_name}" \
+ model_config.model_id="Qwen/${model_name}" \
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
model_config.model_download_path="/tmp/models/${model_name}" \
- model_config.mesh.shape="(2,4)" \
- model_config.mesh.axis_names="('fsdp','tp')" \
- model_config.rng_seed=42 \
- actor_model_config.lora_config.rank=64 \
- actor_model_config.lora_config.alpha=64.0 \
- actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
- actor_model_config.mesh.shape="(2,4)" \
- actor_model_config.mesh.axis_names="('fsdp','tp')" \
- reference_model_config.mesh=null \
- reference_model_config.same_mesh_as="actor" \
- rollout_model_config.mesh=null \
- rollout_model_config.same_mesh_as="actor" \
- tokenizer_config.tokenizer_path=Qwen/${model_name} \
- tokenizer_config.tokenizer_type=huggingface \
- tokenizer_config.add_bos=false \
- dataset_name="gsm8k" \
+ tokenizer_config.tokenizer_path="Qwen/${model_name}" \
batch_size=$batch_size \
- num_test_batches=100 \
num_train_epochs=$num_train_epochs \
train_fraction=$train_fraction \
- rl_training_config.actor_optimizer_config.opt_type="adamw" \
- rl_training_config.actor_optimizer_config.peak_value=3e-6 \
- rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
- rl_training_config.actor_optimizer_config.init_value=0.0 \
- rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
- rl_training_config.actor_optimizer_config.b1=0.9 \
- rl_training_config.actor_optimizer_config.b2=0.99 \
- rl_training_config.actor_optimizer_config.weight_decay=0.1 \
- rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
- rl_training_config.eval_every_n_steps=10 \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
- rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
- rl_training_config.checkpoint_root_directory="$checkpoint_dir" \
- rl_training_config.checkpointing_options.save_interval_steps=500 \
- rl_training_config.checkpointing_options.max_to_keep=4 \
- rl_training_config.profiler_options={} \
- rollout_config.total_generation_steps=768 \
- rollout_config.max_prompt_length=256 \
- rollout_config.temperature=0.9 \
- rollout_config.top_p=1.0 \
- rollout_config.top_k=50 \
- rollout_engine="vanilla" \
- offload_to_cpu=false \
- grpo_config.num_generations=4 \
- grpo_config.num_iterations=1 \
- grpo_config.beta=0.08 \
- grpo_config.epsilon=0.2 \
- reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
+ rl_training_config.checkpoint_root_directory="$checkpoint_dir"
diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh
index fa343b222..72315bbec 100755
--- a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh
@@ -72,103 +72,28 @@ vllm_max_num_seqs=$(awk "BEGIN {
python -m tunix.cli.grpo_main \
tunix/cli/base_agentic_config.yaml \
- \
- `# -- Model ------------------------------------------------------------` \
+ override_config_file=examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml \
model_config.model_name="$model_name" \
model_config.model_id="$model_id" \
model_config.model_source="huggingface" \
model_config.model_download_path="/tmp/models/${model_name}" \
- model_config.rng_seed=42 \
- model_config.model_display=false \
- model_config.remat_config=3 \
+ tokenizer_config.tokenizer_path="$tokenizer_path" \
actor_model_config.mesh.shape="$train_mesh" \
- actor_model_config.mesh.axis_names="('fsdp','tp')" \
- reference_model_config.mesh=null \
- reference_model_config.same_mesh_as="actor" \
rollout_model_config.mesh.shape="$rollout_mesh" \
- rollout_model_config.mesh.axis_names="('fsdp','tp')" \
- \
- `# -- Data -------------------------------------------------------------` \
- data_source="huggingface" \
- dataset_name="openai/gsm8k:main" \
- prompt_key="question" \
- \
- `# -- Training loop ----------------------------------------------------` \
- training_mode="agentic_grpo" \
batch_size="$batch_size" \
num_batches="$num_batches" \
- num_test_batches=100 \
num_train_epochs="$num_train_epochs" \
train_fraction="$train_fraction" \
- reward_functions=["tunix/cli/reward_fn/gsm8k.py"] \
- verl_compatible=false \
- \
- `# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \
- rollout_engine="vllm" \
- offload_to_cpu=false \
- \
- `# -- Rollout config ---------------------------------------------------` \
- rollout_config.max_prompt_length=256 \
- rollout_config.total_generation_steps=768 \
- rollout_config.max_tokens_to_generate=768 \
- rollout_config.temperature=0.9 \
- rollout_config.top_p=1.0 \
- rollout_config.top_k=50 \
- rollout_config.return_logprobs=true \
- \
- `# -- vLLM (used when rollout_engine=vllm) -----------------------------` \
- vllm_config.hbm_utilization=0.4 \
- vllm_config.tpu_backend_type="jax" \
- vllm_config.server_mode=true \
- vllm_config.async_scheduling=true \
vllm_config.max_num_seqs="$vllm_max_num_seqs" \
- vllm_config.kwargs.kv_cache_metrics=true \
- vllm_config.kwargs.disable_log_stats=false \
- vllm_config.kwargs.enable_prefix_caching=true \
- \
- `# -- Tokenizer / chat parsing ----------------------------------------` \
- chat_parser_config.type="qwen" \
- tokenizer_config.tokenizer_type="huggingface" \
- tokenizer_config.tokenizer_path="$tokenizer_path" \
- tokenizer_config.add_bos=false \
- tokenizer_config.add_eos=false \
- \
- `# -- GRPO algorithm ---------------------------------------------------` \
agentic_grpo_config.num_generations="$num_generations" \
- agentic_grpo_config.num_iterations=1 \
- agentic_grpo_config.beta=0.08 \
- agentic_grpo_config.epsilon=0.2 \
- agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." \
- agentic_grpo_config.max_concurrency=128 \
- agentic_grpo_config.max_response_length=768 \
- agentic_grpo_config.max_turns=1 \
- \
- `# -- Optimizer --------------------------------------------------------` \
- rl_training_config.actor_optimizer_config.opt_type="adamw" \
- rl_training_config.actor_optimizer_config.learning_rate=3e-6 \
- rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
- rl_training_config.actor_optimizer_config.init_value=0.0 \
- rl_training_config.actor_optimizer_config.peak_value=3e-6 \
- rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \
rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \
rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \
- rl_training_config.actor_optimizer_config.b1=0.9 \
- rl_training_config.actor_optimizer_config.b2=0.99 \
- rl_training_config.actor_optimizer_config.weight_decay=0.1 \
- rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
- \
- `# -- RL training ------------------------------------------------------` \
- rl_training_config.eval_every_n_steps=10 \
rl_training_config.max_steps="$max_steps" \
rl_training_config.mini_batch_size="$mini_batch_size" \
rl_training_config.train_micro_batch_size="$train_micro_batch_size" \
rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \
rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \
rl_training_config.checkpoint_root_directory="$checkpoint_dir" \
- rl_training_config.checkpointing_options.save_interval_steps=250 \
- rl_training_config.checkpointing_options.max_to_keep=4 \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b" \
- rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
- \
"$@"
diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh
index 33af305c0..4e39cbc2a 100755
--- a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh
@@ -70,108 +70,31 @@ vllm_max_num_seqs=$(awk "BEGIN {
python -m tunix.cli.grpo_main \
tunix/cli/base_agentic_config.yaml \
- \
- `# -- Model ------------------------------------------------------------` \
+ override_config_file=examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml \
model_config.model_name="$model_name" \
model_config.model_id="$model_id" \
model_config.model_source="maxtext" \
- model_config.rng_seed=42 \
- model_config.model_display=false \
- model_config.remat_config=3 \
model_config.model_download_path="/tmp/models/${model_name}" \
- `# -- Maxtext specific configs mapping ---------------------------------` \
model_config.kwargs.base_emb_dim=4096 \
model_config.kwargs.sparse_matmul=true \
model_config.kwargs.remat_policy="minimal" \
- `# -- Mesh configurations ----------------------------------------------` \
+ tokenizer_config.tokenizer_path="$tokenizer_path" \
actor_model_config.mesh.shape="$train_mesh" \
- actor_model_config.mesh.axis_names="('fsdp','tp')" \
- reference_model_config.mesh=null \
- reference_model_config.same_mesh_as="actor" \
rollout_model_config.mesh.shape="$rollout_mesh" \
- rollout_model_config.mesh.axis_names="('fsdp','tp')" \
- \
- `# -- Data -------------------------------------------------------------` \
- data_source="huggingface" \
- dataset_name="openai/gsm8k:main" \
- prompt_key="question" \
- \
- `# -- Training loop ----------------------------------------------------` \
- training_mode="agentic_grpo" \
batch_size="$batch_size" \
num_batches="$num_batches" \
- num_test_batches=100 \
num_train_epochs="$num_train_epochs" \
train_fraction="$train_fraction" \
- reward_functions=["tunix/cli/reward_fn/gsm8k.py"] \
- verl_compatible=false \
- \
- `# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \
- rollout_engine="vllm" \
- offload_to_cpu=false \
- \
- `# -- Rollout config ---------------------------------------------------` \
- rollout_config.max_prompt_length=256 \
- rollout_config.total_generation_steps=768 \
- rollout_config.max_tokens_to_generate=768 \
- rollout_config.temperature=0.9 \
- rollout_config.top_p=1.0 \
- rollout_config.top_k=50 \
- rollout_config.return_logprobs=true \
- \
- `# -- vLLM (used when rollout_engine=vllm) -----------------------------` \
- vllm_config.hbm_utilization=0.4 \
- vllm_config.tpu_backend_type="jax" \
- vllm_config.server_mode=true \
- vllm_config.async_scheduling=true \
vllm_config.max_num_seqs="$vllm_max_num_seqs" \
- vllm_config.kwargs.kv_cache_metrics=true \
- vllm_config.kwargs.disable_log_stats=false \
- vllm_config.kwargs.enable_prefix_caching=true \
- \
- `# -- Tokenizer / chat parsing ----------------------------------------` \
- chat_parser_config.type="qwen" \
- tokenizer_config.tokenizer_type="huggingface" \
- tokenizer_config.tokenizer_path="$tokenizer_path" \
- tokenizer_config.add_bos=false \
- tokenizer_config.add_eos=false \
- \
- `# -- GRPO algorithm ---------------------------------------------------` \
agentic_grpo_config.num_generations="$num_generations" \
- agentic_grpo_config.num_iterations=1 \
- agentic_grpo_config.beta=0.08 \
- agentic_grpo_config.epsilon=0.2 \
- agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." \
- agentic_grpo_config.max_concurrency=128 \
- agentic_grpo_config.max_response_length=768 \
- agentic_grpo_config.max_turns=1 \
- \
- `# -- Optimizer --------------------------------------------------------` \
- rl_training_config.actor_optimizer_config.opt_type="adamw" \
- rl_training_config.actor_optimizer_config.learning_rate=3e-6 \
- rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
- rl_training_config.actor_optimizer_config.init_value=0.0 \
- rl_training_config.actor_optimizer_config.peak_value=3e-6 \
- rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \
rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \
rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \
- rl_training_config.actor_optimizer_config.b1=0.9 \
- rl_training_config.actor_optimizer_config.b2=0.99 \
- rl_training_config.actor_optimizer_config.weight_decay=0.1 \
- rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
- \
- `# -- RL training ------------------------------------------------------` \
- rl_training_config.eval_every_n_steps=10 \
rl_training_config.max_steps="$max_steps" \
rl_training_config.mini_batch_size="$mini_batch_size" \
rl_training_config.train_micro_batch_size="$train_micro_batch_size" \
rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \
rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \
rl_training_config.checkpoint_root_directory="$checkpoint_dir" \
- rl_training_config.checkpointing_options.save_interval_steps=250 \
- rl_training_config.checkpointing_options.max_to_keep=4 \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b_maxtext" \
- rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
- \
"$@"
diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
index b28f4aad3..753cd9088 100644
--- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh
@@ -38,63 +38,20 @@ echo "Max steps: $max_steps"
echo "Rounded warmup steps: $warmup_steps"
python3 -m tunix.cli.grpo_main \
- base_config.yaml \
- model_config.model_name=${model_name} \
- model_config.model_id=Qwen/${model_name} \
- model_config.model_source=huggingface \
- model_config.use_flash_attention=true \
- model_config.flash_attention_block_size=256 \
+ tunix/cli/base_config.yaml \
+ override_config_file=examples/rl/grpo/gsm8k/configs/qwen3.yaml \
+ model_config.model_name="${model_name}" \
+ model_config.model_id="Qwen/${model_name}" \
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
model_config.model_download_path="/tmp/models/${model_name}" \
- model_config.mesh.shape="(2,4)" \
- model_config.mesh.axis_names="('fsdp','tp')" \
- model_config.rng_seed=42 \
- actor_model_config.lora_config.rank=64 \
- actor_model_config.lora_config.alpha=64.0 \
- actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
- actor_model_config.mesh.shape="(2,4)" \
- actor_model_config.mesh.axis_names="('fsdp','tp')" \
- reference_model_config.mesh=null \
- reference_model_config.same_mesh_as="actor" \
- rollout_model_config.mesh=null \
- rollout_model_config.same_mesh_as="actor" \
- tokenizer_config.tokenizer_path=Qwen/${model_name} \
- tokenizer_config.tokenizer_type=huggingface \
- tokenizer_config.add_bos=false \
- dataset_name="gsm8k" \
+ tokenizer_config.tokenizer_path="Qwen/${model_name}" \
batch_size=$batch_size \
num_batches=$num_batches \
- num_test_batches=100 \
num_train_epochs=$num_train_epochs \
train_fraction=$train_fraction \
- rl_training_config.actor_optimizer_config.opt_type="adamw" \
- rl_training_config.actor_optimizer_config.peak_value=3e-6 \
- rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
- rl_training_config.actor_optimizer_config.init_value=0.0 \
- rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
- rl_training_config.actor_optimizer_config.b1=0.9 \
- rl_training_config.actor_optimizer_config.b2=0.99 \
- rl_training_config.actor_optimizer_config.weight_decay=0.1 \
- rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
- rl_training_config.eval_every_n_steps=10 \
rl_training_config.max_steps=$max_steps \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
- rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
- rl_training_config.checkpointing_options.save_interval_steps=500 \
- rl_training_config.checkpointing_options.max_to_keep=4 \
- rl_training_config.profiler_options={} \
- rollout_config.total_generation_steps=768 \
- rollout_config.max_prompt_length=256 \
- rollout_config.temperature=0.9 \
- rollout_config.top_p=1.0 \
- rollout_config.top_k=50 \
- rollout_engine="vanilla" \
- offload_to_cpu=false \
- grpo_config.num_generations=4 \
- grpo_config.num_iterations=1 \
- grpo_config.beta=0.08 \
- grpo_config.epsilon=0.2 \
reward_functions="['tunix/cli/reward_fn/simple_math.py']"
diff --git a/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh b/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh
index d3e5563bf..cd8f0fae1 100644
--- a/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh
@@ -1,4 +1,4 @@
-# Copyright 2025 Google LLC
+# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,54 +35,20 @@ echo " Rollout Mesh Shape: $rollout_mesh_shape"
echo " Checkpoint Directory: $checkpoint_dir"
python3 -m tunix.cli.grpo_main \
- base_config.yaml \
- model_config.model_name=${model_name} \
- model_config.model_id=Qwen/${model_name} \
- model_config.model_source=huggingface \
- model_config.use_flash_attention=true \
- model_config.flash_attention_block_size=256 \
+ tunix/cli/base_config.yaml \
+ override_config_file=examples/rl/grpo/gsm8k/configs/qwen3.yaml \
+ model_config.model_name="${model_name}" \
+ model_config.model_id="Qwen/${model_name}" \
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
- model_config.rng_seed=42 \
- actor_model_config.mesh.shape=${actor_mesh_shape} \
- actor_model_config.mesh.axis_names="('fsdp','tp')" \
- reference_model_config.mesh=null \
- reference_model_config.same_mesh_as="actor" \
- rollout_model_config.mesh.shape=${rollout_mesh_shape} \
- rollout_model_config.mesh.axis_names="('fsdp','tp')" \
- tokenizer_config.tokenizer_path=Qwen/${model_name} \
- tokenizer_config.tokenizer_type=huggingface \
- tokenizer_config.add_bos=false \
- dataset_name="gsm8k" \
+ actor_model_config.lora_config={} \
+ actor_model_config.mesh.shape="${actor_mesh_shape}" \
+ rollout_model_config.mesh.shape="${rollout_mesh_shape}" \
+ tokenizer_config.tokenizer_path="Qwen/${model_name}" \
batch_size=$batch_size \
- num_test_batches=100 \
num_train_epochs=$num_train_epochs \
- rl_training_config.actor_optimizer_config.opt_type="adamw" \
- rl_training_config.actor_optimizer_config.peak_value=3e-6 \
- rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
- rl_training_config.actor_optimizer_config.init_value=0.0 \
- rl_training_config.actor_optimizer_config.end_value=0.0 \
+ train_fraction=$train_fraction \
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
- rl_training_config.actor_optimizer_config.b1=0.9 \
- rl_training_config.actor_optimizer_config.b2=0.99 \
- rl_training_config.actor_optimizer_config.weight_decay=0.1 \
- rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
- rl_training_config.eval_every_n_steps=10 \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
- rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
rl_training_config.checkpoint_root_directory="$checkpoint_dir" \
- rl_training_config.checkpointing_options.save_interval_steps=500 \
- rl_training_config.checkpointing_options.max_to_keep=4 \
- rl_training_config.profiler_options={} \
- rollout_config.total_generation_steps=768 \
- rollout_config.max_prompt_length=256 \
- rollout_config.temperature=0.9 \
- rollout_config.top_p=1.0 \
- rollout_config.top_k=50 \
rollout_engine="vllm" \
- vllm_config.async_scheduling=false \
- offload_to_cpu=false \
- grpo_config.num_generations=4 \
- grpo_config.num_iterations=1 \
- grpo_config.beta=0.08 \
- grpo_config.epsilon=0.2 \
- reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
+ vllm_config.async_scheduling=false