From 18314deba5797d0ad55afeb225543650f0957e18 Mon Sep 17 00:00:00 2001 From: Shadi Noghabi Date: Wed, 13 May 2026 16:50:01 -0700 Subject: [PATCH] update qwen3 scripts to use yaml configs PiperOrigin-RevId: 915129002 --- examples/rl/grpo/gsm8k/configs/qwen3.yaml | 89 ++++++++++++++++ .../rl/grpo/gsm8k/configs/qwen3_disagg.yaml | 100 ++++++++++++++++++ examples/rl/grpo/gsm8k/run_qwen3.sh | 56 ++-------- examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh | 79 +------------- .../grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh | 81 +------------- .../rl/grpo/gsm8k/run_qwen3_simplereward.sh | 53 +--------- .../rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh | 56 ++-------- 7 files changed, 215 insertions(+), 299 deletions(-) create mode 100644 examples/rl/grpo/gsm8k/configs/qwen3.yaml create mode 100644 examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml diff --git a/examples/rl/grpo/gsm8k/configs/qwen3.yaml b/examples/rl/grpo/gsm8k/configs/qwen3.yaml new file mode 100644 index 000000000..71d84854c --- /dev/null +++ b/examples/rl/grpo/gsm8k/configs/qwen3.yaml @@ -0,0 +1,89 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_config: + model_name: "Qwen3-1.7B-base" + model_id: "Qwen/Qwen3-1.7B-base" + model_source: "huggingface" + use_flash_attention: true + flash_attention_block_size: 256 + mesh: + shape: "(2,4)" + axis_names: "('fsdp','tp')" + rng_seed: 42 + +actor_model_config: + lora_config: + rank: 64 + alpha: 64.0 + module_path: ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" + mesh: + shape: "(2,4)" + axis_names: "('fsdp','tp')" + +reference_model_config: + mesh: null + same_mesh_as: "actor" + +rollout_model_config: + mesh: null + same_mesh_as: "actor" + +tokenizer_config: + tokenizer_type: "huggingface" + add_bos: false + +dataset_name: "gsm8k" +batch_size: 8 +num_test_batches: 100 +num_train_epochs: 1 + +rl_training_config: + actor_optimizer_config: + opt_type: "adamw" + peak_value: 3e-6 + schedule_type: "warmup_cosine_decay_schedule" + init_value: 0.0 + end_value: 0.0 + warmup_ratio: 0.1 + b1: 0.9 + b2: 0.99 + weight_decay: 0.1 + max_grad_norm: 0.1 + eval_every_n_steps: 10 + metrics_logging_options: + flush_every_n_steps: 20 + checkpointing_options: + save_interval_steps: 500 + max_to_keep: 4 + profiler_options: {} + +rollout_config: + total_generation_steps: 768 + max_prompt_length: 256 + temperature: 0.9 + top_p: 1.0 + top_k: 50 + +rollout_engine: "vanilla" +offload_to_cpu: false + +grpo_config: + num_generations: 4 + num_iterations: 1 + beta: 0.08 + epsilon: 0.2 + +reward_functions: + - "tunix/cli/reward_fn/gsm8k.py" diff --git a/examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml b/examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml new file mode 100644 index 000000000..773790ff6 --- /dev/null +++ b/examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml @@ -0,0 +1,100 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_config: + rng_seed: 42 + model_display: false + remat_config: 3 + +actor_model_config: + mesh: + shape: "(8,1)" + axis_names: "('fsdp','tp')" + +rollout_model_config: + mesh: + shape: "(1,8)" + axis_names: "('fsdp','tp')" + +reference_model_config: + mesh: null + same_mesh_as: "actor" + +data_source: "huggingface" +dataset_name: "openai/gsm8k:main" +prompt_key: "question" + +training_mode: "agentic_grpo" +num_test_batches: 100 +reward_functions: + - "tunix/cli/reward_fn/gsm8k.py" +verl_compatible: false + +rollout_engine: "vllm" +offload_to_cpu: false + +rollout_config: + max_prompt_length: 256 + total_generation_steps: 768 + max_tokens_to_generate: 768 + temperature: 0.9 + top_p: 1.0 + top_k: 50 + return_logprobs: true + +vllm_config: + hbm_utilization: 0.4 + tpu_backend_type: "jax" + server_mode: true + async_scheduling: true + kwargs: + kv_cache_metrics: true + disable_log_stats: false + enable_prefix_caching: true + +chat_parser_config: + type: "qwen" + +tokenizer_config: + tokenizer_type: "huggingface" + add_bos: false + add_eos: false + +agentic_grpo_config: + num_iterations: 1 + beta: 0.08 + epsilon: 0.2 + system_prompt: "You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." + max_concurrency: 128 + max_response_length: 768 + max_turns: 1 + +rl_training_config: + actor_optimizer_config: + opt_type: "adamw" + learning_rate: 3e-6 + schedule_type: "warmup_cosine_decay_schedule" + init_value: 0.0 + peak_value: 3e-6 + end_value: 0.0 + b1: 0.9 + b2: 0.99 + weight_decay: 0.1 + max_grad_norm: 0.1 + eval_every_n_steps: 10 + checkpointing_options: + save_interval_steps: 250 + max_to_keep: 4 + metrics_logging_options: + flush_every_n_steps: 20 diff --git a/examples/rl/grpo/gsm8k/run_qwen3.sh b/examples/rl/grpo/gsm8k/run_qwen3.sh index 05fb5077e..befad2749 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3.sh @@ -31,60 +31,16 @@ echo " Train Fraction: $train_fraction" echo " Checkpoint Directory: $checkpoint_dir" python3 -m tunix.cli.grpo_main \ - base_config.yaml \ - model_config.model_name=${model_name} \ - model_config.model_id=Qwen/${model_name} \ - model_config.model_source=huggingface \ - model_config.use_flash_attention=true \ - model_config.flash_attention_block_size=256 \ + tunix/cli/base_config.yaml \ + override_config_file=examples/rl/grpo/gsm8k/configs/qwen3.yaml \ + model_config.model_name="${model_name}" \ + model_config.model_id="Qwen/${model_name}" \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \ model_config.model_download_path="/tmp/models/${model_name}" \ - model_config.mesh.shape="(2,4)" \ - model_config.mesh.axis_names="('fsdp','tp')" \ - model_config.rng_seed=42 \ - actor_model_config.lora_config.rank=64 \ - actor_model_config.lora_config.alpha=64.0 \ - actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ - actor_model_config.mesh.shape="(2,4)" \ - actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ - rollout_model_config.mesh=null \ - rollout_model_config.same_mesh_as="actor" \ - tokenizer_config.tokenizer_path=Qwen/${model_name} \ - tokenizer_config.tokenizer_type=huggingface \ - tokenizer_config.add_bos=false \ - dataset_name="gsm8k" \ + tokenizer_config.tokenizer_path="Qwen/${model_name}" \ batch_size=$batch_size \ - num_test_batches=100 \ num_train_epochs=$num_train_epochs \ train_fraction=$train_fraction \ - rl_training_config.actor_optimizer_config.opt_type="adamw" \ - rl_training_config.actor_optimizer_config.peak_value=3e-6 \ - rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ - rl_training_config.actor_optimizer_config.init_value=0.0 \ - rl_training_config.actor_optimizer_config.end_value=0.0 \ rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \ - rl_training_config.actor_optimizer_config.b1=0.9 \ - rl_training_config.actor_optimizer_config.b2=0.99 \ - rl_training_config.actor_optimizer_config.weight_decay=0.1 \ - rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ - rl_training_config.eval_every_n_steps=10 \ rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \ - rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ - rl_training_config.checkpoint_root_directory="$checkpoint_dir" \ - rl_training_config.checkpointing_options.save_interval_steps=500 \ - rl_training_config.checkpointing_options.max_to_keep=4 \ - rl_training_config.profiler_options={} \ - rollout_config.total_generation_steps=768 \ - rollout_config.max_prompt_length=256 \ - rollout_config.temperature=0.9 \ - rollout_config.top_p=1.0 \ - rollout_config.top_k=50 \ - rollout_engine="vanilla" \ - offload_to_cpu=false \ - grpo_config.num_generations=4 \ - grpo_config.num_iterations=1 \ - grpo_config.beta=0.08 \ - grpo_config.epsilon=0.2 \ - reward_functions="['tunix/cli/reward_fn/gsm8k.py']" + rl_training_config.checkpoint_root_directory="$checkpoint_dir" diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh index fa343b222..72315bbec 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh @@ -72,103 +72,28 @@ vllm_max_num_seqs=$(awk "BEGIN { python -m tunix.cli.grpo_main \ tunix/cli/base_agentic_config.yaml \ - \ - `# -- Model ------------------------------------------------------------` \ + override_config_file=examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml \ model_config.model_name="$model_name" \ model_config.model_id="$model_id" \ model_config.model_source="huggingface" \ model_config.model_download_path="/tmp/models/${model_name}" \ - model_config.rng_seed=42 \ - model_config.model_display=false \ - model_config.remat_config=3 \ + tokenizer_config.tokenizer_path="$tokenizer_path" \ actor_model_config.mesh.shape="$train_mesh" \ - actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ rollout_model_config.mesh.shape="$rollout_mesh" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ - \ - `# -- Data -------------------------------------------------------------` \ - data_source="huggingface" \ - dataset_name="openai/gsm8k:main" \ - prompt_key="question" \ - \ - `# -- Training loop ----------------------------------------------------` \ - training_mode="agentic_grpo" \ batch_size="$batch_size" \ num_batches="$num_batches" \ - num_test_batches=100 \ num_train_epochs="$num_train_epochs" \ train_fraction="$train_fraction" \ - reward_functions=["tunix/cli/reward_fn/gsm8k.py"] \ - verl_compatible=false \ - \ - `# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \ - rollout_engine="vllm" \ - offload_to_cpu=false \ - \ - `# -- Rollout config ---------------------------------------------------` \ - rollout_config.max_prompt_length=256 \ - rollout_config.total_generation_steps=768 \ - rollout_config.max_tokens_to_generate=768 \ - rollout_config.temperature=0.9 \ - rollout_config.top_p=1.0 \ - rollout_config.top_k=50 \ - rollout_config.return_logprobs=true \ - \ - `# -- vLLM (used when rollout_engine=vllm) -----------------------------` \ - vllm_config.hbm_utilization=0.4 \ - vllm_config.tpu_backend_type="jax" \ - vllm_config.server_mode=true \ - vllm_config.async_scheduling=true \ vllm_config.max_num_seqs="$vllm_max_num_seqs" \ - vllm_config.kwargs.kv_cache_metrics=true \ - vllm_config.kwargs.disable_log_stats=false \ - vllm_config.kwargs.enable_prefix_caching=true \ - \ - `# -- Tokenizer / chat parsing ----------------------------------------` \ - chat_parser_config.type="qwen" \ - tokenizer_config.tokenizer_type="huggingface" \ - tokenizer_config.tokenizer_path="$tokenizer_path" \ - tokenizer_config.add_bos=false \ - tokenizer_config.add_eos=false \ - \ - `# -- GRPO algorithm ---------------------------------------------------` \ agentic_grpo_config.num_generations="$num_generations" \ - agentic_grpo_config.num_iterations=1 \ - agentic_grpo_config.beta=0.08 \ - agentic_grpo_config.epsilon=0.2 \ - agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." \ - agentic_grpo_config.max_concurrency=128 \ - agentic_grpo_config.max_response_length=768 \ - agentic_grpo_config.max_turns=1 \ - \ - `# -- Optimizer --------------------------------------------------------` \ - rl_training_config.actor_optimizer_config.opt_type="adamw" \ - rl_training_config.actor_optimizer_config.learning_rate=3e-6 \ - rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ - rl_training_config.actor_optimizer_config.init_value=0.0 \ - rl_training_config.actor_optimizer_config.peak_value=3e-6 \ - rl_training_config.actor_optimizer_config.end_value=0.0 \ rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \ rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \ rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \ - rl_training_config.actor_optimizer_config.b1=0.9 \ - rl_training_config.actor_optimizer_config.b2=0.99 \ - rl_training_config.actor_optimizer_config.weight_decay=0.1 \ - rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ - \ - `# -- RL training ------------------------------------------------------` \ - rl_training_config.eval_every_n_steps=10 \ rl_training_config.max_steps="$max_steps" \ rl_training_config.mini_batch_size="$mini_batch_size" \ rl_training_config.train_micro_batch_size="$train_micro_batch_size" \ rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \ rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \ rl_training_config.checkpoint_root_directory="$checkpoint_dir" \ - rl_training_config.checkpointing_options.save_interval_steps=250 \ - rl_training_config.checkpointing_options.max_to_keep=4 \ rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b" \ - rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ - \ "$@" diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh index 33af305c0..4e39cbc2a 100755 --- a/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_8b_disagg_maxtext.sh @@ -70,108 +70,31 @@ vllm_max_num_seqs=$(awk "BEGIN { python -m tunix.cli.grpo_main \ tunix/cli/base_agentic_config.yaml \ - \ - `# -- Model ------------------------------------------------------------` \ + override_config_file=examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml \ model_config.model_name="$model_name" \ model_config.model_id="$model_id" \ model_config.model_source="maxtext" \ - model_config.rng_seed=42 \ - model_config.model_display=false \ - model_config.remat_config=3 \ model_config.model_download_path="/tmp/models/${model_name}" \ - `# -- Maxtext specific configs mapping ---------------------------------` \ model_config.kwargs.base_emb_dim=4096 \ model_config.kwargs.sparse_matmul=true \ model_config.kwargs.remat_policy="minimal" \ - `# -- Mesh configurations ----------------------------------------------` \ + tokenizer_config.tokenizer_path="$tokenizer_path" \ actor_model_config.mesh.shape="$train_mesh" \ - actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ rollout_model_config.mesh.shape="$rollout_mesh" \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ - \ - `# -- Data -------------------------------------------------------------` \ - data_source="huggingface" \ - dataset_name="openai/gsm8k:main" \ - prompt_key="question" \ - \ - `# -- Training loop ----------------------------------------------------` \ - training_mode="agentic_grpo" \ batch_size="$batch_size" \ num_batches="$num_batches" \ - num_test_batches=100 \ num_train_epochs="$num_train_epochs" \ train_fraction="$train_fraction" \ - reward_functions=["tunix/cli/reward_fn/gsm8k.py"] \ - verl_compatible=false \ - \ - `# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \ - rollout_engine="vllm" \ - offload_to_cpu=false \ - \ - `# -- Rollout config ---------------------------------------------------` \ - rollout_config.max_prompt_length=256 \ - rollout_config.total_generation_steps=768 \ - rollout_config.max_tokens_to_generate=768 \ - rollout_config.temperature=0.9 \ - rollout_config.top_p=1.0 \ - rollout_config.top_k=50 \ - rollout_config.return_logprobs=true \ - \ - `# -- vLLM (used when rollout_engine=vllm) -----------------------------` \ - vllm_config.hbm_utilization=0.4 \ - vllm_config.tpu_backend_type="jax" \ - vllm_config.server_mode=true \ - vllm_config.async_scheduling=true \ vllm_config.max_num_seqs="$vllm_max_num_seqs" \ - vllm_config.kwargs.kv_cache_metrics=true \ - vllm_config.kwargs.disable_log_stats=false \ - vllm_config.kwargs.enable_prefix_caching=true \ - \ - `# -- Tokenizer / chat parsing ----------------------------------------` \ - chat_parser_config.type="qwen" \ - tokenizer_config.tokenizer_type="huggingface" \ - tokenizer_config.tokenizer_path="$tokenizer_path" \ - tokenizer_config.add_bos=false \ - tokenizer_config.add_eos=false \ - \ - `# -- GRPO algorithm ---------------------------------------------------` \ agentic_grpo_config.num_generations="$num_generations" \ - agentic_grpo_config.num_iterations=1 \ - agentic_grpo_config.beta=0.08 \ - agentic_grpo_config.epsilon=0.2 \ - agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." \ - agentic_grpo_config.max_concurrency=128 \ - agentic_grpo_config.max_response_length=768 \ - agentic_grpo_config.max_turns=1 \ - \ - `# -- Optimizer --------------------------------------------------------` \ - rl_training_config.actor_optimizer_config.opt_type="adamw" \ - rl_training_config.actor_optimizer_config.learning_rate=3e-6 \ - rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ - rl_training_config.actor_optimizer_config.init_value=0.0 \ - rl_training_config.actor_optimizer_config.peak_value=3e-6 \ - rl_training_config.actor_optimizer_config.end_value=0.0 \ rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \ rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \ rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \ - rl_training_config.actor_optimizer_config.b1=0.9 \ - rl_training_config.actor_optimizer_config.b2=0.99 \ - rl_training_config.actor_optimizer_config.weight_decay=0.1 \ - rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ - \ - `# -- RL training ------------------------------------------------------` \ - rl_training_config.eval_every_n_steps=10 \ rl_training_config.max_steps="$max_steps" \ rl_training_config.mini_batch_size="$mini_batch_size" \ rl_training_config.train_micro_batch_size="$train_micro_batch_size" \ rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \ rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \ rl_training_config.checkpoint_root_directory="$checkpoint_dir" \ - rl_training_config.checkpointing_options.save_interval_steps=250 \ - rl_training_config.checkpointing_options.max_to_keep=4 \ rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b_maxtext" \ - rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ - \ "$@" diff --git a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh index b28f4aad3..753cd9088 100644 --- a/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh @@ -38,63 +38,20 @@ echo "Max steps: $max_steps" echo "Rounded warmup steps: $warmup_steps" python3 -m tunix.cli.grpo_main \ - base_config.yaml \ - model_config.model_name=${model_name} \ - model_config.model_id=Qwen/${model_name} \ - model_config.model_source=huggingface \ - model_config.use_flash_attention=true \ - model_config.flash_attention_block_size=256 \ + tunix/cli/base_config.yaml \ + override_config_file=examples/rl/grpo/gsm8k/configs/qwen3.yaml \ + model_config.model_name="${model_name}" \ + model_config.model_id="Qwen/${model_name}" \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \ model_config.model_download_path="/tmp/models/${model_name}" \ - model_config.mesh.shape="(2,4)" \ - model_config.mesh.axis_names="('fsdp','tp')" \ - model_config.rng_seed=42 \ - actor_model_config.lora_config.rank=64 \ - actor_model_config.lora_config.alpha=64.0 \ - actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \ - actor_model_config.mesh.shape="(2,4)" \ - actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ - rollout_model_config.mesh=null \ - rollout_model_config.same_mesh_as="actor" \ - tokenizer_config.tokenizer_path=Qwen/${model_name} \ - tokenizer_config.tokenizer_type=huggingface \ - tokenizer_config.add_bos=false \ - dataset_name="gsm8k" \ + tokenizer_config.tokenizer_path="Qwen/${model_name}" \ batch_size=$batch_size \ num_batches=$num_batches \ - num_test_batches=100 \ num_train_epochs=$num_train_epochs \ train_fraction=$train_fraction \ - rl_training_config.actor_optimizer_config.opt_type="adamw" \ - rl_training_config.actor_optimizer_config.peak_value=3e-6 \ - rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ - rl_training_config.actor_optimizer_config.init_value=0.0 \ - rl_training_config.actor_optimizer_config.end_value=0.0 \ rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \ rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \ rl_training_config.actor_optimizer_config.decay_steps=$max_steps \ - rl_training_config.actor_optimizer_config.b1=0.9 \ - rl_training_config.actor_optimizer_config.b2=0.99 \ - rl_training_config.actor_optimizer_config.weight_decay=0.1 \ - rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ - rl_training_config.eval_every_n_steps=10 \ rl_training_config.max_steps=$max_steps \ rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \ - rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ - rl_training_config.checkpointing_options.save_interval_steps=500 \ - rl_training_config.checkpointing_options.max_to_keep=4 \ - rl_training_config.profiler_options={} \ - rollout_config.total_generation_steps=768 \ - rollout_config.max_prompt_length=256 \ - rollout_config.temperature=0.9 \ - rollout_config.top_p=1.0 \ - rollout_config.top_k=50 \ - rollout_engine="vanilla" \ - offload_to_cpu=false \ - grpo_config.num_generations=4 \ - grpo_config.num_iterations=1 \ - grpo_config.beta=0.08 \ - grpo_config.epsilon=0.2 \ reward_functions="['tunix/cli/reward_fn/simple_math.py']" diff --git a/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh b/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh index d3e5563bf..cd8f0fae1 100644 --- a/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh +++ b/examples/rl/grpo/gsm8k/run_qwen3_vllm_disagg.sh @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,54 +35,20 @@ echo " Rollout Mesh Shape: $rollout_mesh_shape" echo " Checkpoint Directory: $checkpoint_dir" python3 -m tunix.cli.grpo_main \ - base_config.yaml \ - model_config.model_name=${model_name} \ - model_config.model_id=Qwen/${model_name} \ - model_config.model_source=huggingface \ - model_config.use_flash_attention=true \ - model_config.flash_attention_block_size=256 \ + tunix/cli/base_config.yaml \ + override_config_file=examples/rl/grpo/gsm8k/configs/qwen3.yaml \ + model_config.model_name="${model_name}" \ + model_config.model_id="Qwen/${model_name}" \ model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \ - model_config.rng_seed=42 \ - actor_model_config.mesh.shape=${actor_mesh_shape} \ - actor_model_config.mesh.axis_names="('fsdp','tp')" \ - reference_model_config.mesh=null \ - reference_model_config.same_mesh_as="actor" \ - rollout_model_config.mesh.shape=${rollout_mesh_shape} \ - rollout_model_config.mesh.axis_names="('fsdp','tp')" \ - tokenizer_config.tokenizer_path=Qwen/${model_name} \ - tokenizer_config.tokenizer_type=huggingface \ - tokenizer_config.add_bos=false \ - dataset_name="gsm8k" \ + actor_model_config.lora_config={} \ + actor_model_config.mesh.shape="${actor_mesh_shape}" \ + rollout_model_config.mesh.shape="${rollout_mesh_shape}" \ + tokenizer_config.tokenizer_path="Qwen/${model_name}" \ batch_size=$batch_size \ - num_test_batches=100 \ num_train_epochs=$num_train_epochs \ - rl_training_config.actor_optimizer_config.opt_type="adamw" \ - rl_training_config.actor_optimizer_config.peak_value=3e-6 \ - rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \ - rl_training_config.actor_optimizer_config.init_value=0.0 \ - rl_training_config.actor_optimizer_config.end_value=0.0 \ + train_fraction=$train_fraction \ rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \ - rl_training_config.actor_optimizer_config.b1=0.9 \ - rl_training_config.actor_optimizer_config.b2=0.99 \ - rl_training_config.actor_optimizer_config.weight_decay=0.1 \ - rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \ - rl_training_config.eval_every_n_steps=10 \ rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \ - rl_training_config.metrics_logging_options.flush_every_n_steps=20 \ rl_training_config.checkpoint_root_directory="$checkpoint_dir" \ - rl_training_config.checkpointing_options.save_interval_steps=500 \ - rl_training_config.checkpointing_options.max_to_keep=4 \ - rl_training_config.profiler_options={} \ - rollout_config.total_generation_steps=768 \ - rollout_config.max_prompt_length=256 \ - rollout_config.temperature=0.9 \ - rollout_config.top_p=1.0 \ - rollout_config.top_k=50 \ rollout_engine="vllm" \ - vllm_config.async_scheduling=false \ - offload_to_cpu=false \ - grpo_config.num_generations=4 \ - grpo_config.num_iterations=1 \ - grpo_config.beta=0.08 \ - grpo_config.epsilon=0.2 \ - reward_functions="['tunix/cli/reward_fn/gsm8k.py']" + vllm_config.async_scheduling=false