diff --git a/scripts/run-qwen3.5-35B-A3B-sft.sh b/scripts/run-qwen3.5-35B-A3B-sft.sh
index 6893133924..b5910ae5b8 100644
--- a/scripts/run-qwen3.5-35B-A3B-sft.sh
+++ b/scripts/run-qwen3.5-35B-A3B-sft.sh
@@ -46,6 +46,18 @@ CKPT_ARGS=(
     --save-interval 20
 )
 
+SFT_MAX_PROMPT_TOKENS=${SFT_MAX_PROMPT_TOKENS:-8192}
+MAX_TOKENS_PER_GPU=${MAX_TOKENS_PER_GPU:-8192}
+CP_SIZE=${CP_SIZE:-1}
+echo "SFT_MAX_PROMPT_TOKENS: ${SFT_MAX_PROMPT_TOKENS}"
+echo "MAX_TOKENS_PER_GPU: ${MAX_TOKENS_PER_GPU}"
+echo "CP_SIZE: ${CP_SIZE}"
+
+if (( SFT_MAX_PROMPT_TOKENS > MAX_TOKENS_PER_GPU * CP_SIZE )); then
+    echo "Invalid config: SFT_MAX_PROMPT_TOKENS (${SFT_MAX_PROMPT_TOKENS}) must be <= MAX_TOKENS_PER_GPU * CP_SIZE (${MAX_TOKENS_PER_GPU} * ${CP_SIZE} = $((MAX_TOKENS_PER_GPU * CP_SIZE)))."
+    exit 1
+fi
+
 SFT_ARGS=(
     --rollout-function-path slime.rollout.sft_rollout.generate_rollout
     --prompt-data ${BASE_FOLDER}/openhermes2_5.parquet
@@ -54,6 +66,7 @@ SFT_ARGS=(
     --num-epoch 3
     --rollout-batch-size 128
     --global-batch-size 128
+    --rollout-max-prompt-len ${SFT_MAX_PROMPT_TOKENS}
 
     --loss-type sft_loss
     --loss-mask-type qwen3_5
@@ -66,7 +79,7 @@ PERF_ARGS=(
     --tensor-model-parallel-size 2
     --sequence-parallel
     --pipeline-model-parallel-size 1
-    --context-parallel-size 1
+    --context-parallel-size ${CP_SIZE}
     --expert-model-parallel-size 8
     --expert-tensor-parallel-size 1
 
@@ -76,7 +89,7 @@ PERF_ARGS=(
 
     # --micro-batch-size 1
     --use-dynamic-batch-size
-    --max-tokens-per-gpu 8192
+    --max-tokens-per-gpu ${MAX_TOKENS_PER_GPU}
 )
 
 OPTIMIZER_ARGS=(
@@ -88,7 +101,7 @@ OPTIMIZER_ARGS=(
     --weight-decay 0.1
     --adam-beta1 0.9
     --adam-beta2 0.98
-
+    --use-distributed-optimizer
     --optimizer-cpu-offload
     --overlap-cpu-optimizer-d2h-h2d
 
diff --git a/slime/utils/data.py b/slime/utils/data.py
index 4bb81e5677..74cdce0301 100644
--- a/slime/utils/data.py
+++ b/slime/utils/data.py
@@ -82,13 +82,19 @@ def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_l
     if max_length is None:
         return origin_samples
 
-    if not isinstance(origin_samples[0].prompt, str):
+    # Keep filtering compatible with both text datasets and chat-message datasets used by SFT.
+    first_prompt = origin_samples[0].prompt
+    supports_text_prompt = isinstance(first_prompt, str)
+    supports_message_prompt = isinstance(first_prompt, list)
+    if not (supports_text_prompt or supports_message_prompt):
+        # Unknown prompt shapes are skipped to avoid breaking custom data pipelines.
         logger.warning(
-            "Skipping max_length check for list prompt. Set apply_chat_template=True to enable length filtering."
+            "Skipping max_length check for unsupported prompt type %s. Set apply_chat_template=True to enable length filtering.",
+            type(first_prompt),
         )
         return origin_samples
 
-    if processor:
+    if processor and supports_text_prompt:
         # Use processor only for samples with actual multimodal content; use batched tokenizer for text-only.
         text_only = []
         multimodal = []
@@ -113,7 +119,8 @@ def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_l
             input_ids = processor_output["input_ids"][0]
             if len(input_ids) <= max_length:
                 filtered_samples.append(sample)
-    else:
+    elif supports_text_prompt:
+        # Fast path for text-only prompts: tokenize in batch for throughput.
         prompts = [sample.prompt for sample in origin_samples]
         input_ids_list = tokenizer(prompts, add_special_tokens=False)["input_ids"]
         filtered_samples = [
@@ -121,6 +128,21 @@ def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_l
             for sample, input_ids in zip(origin_samples, input_ids_list, strict=True)
             if len(input_ids) <= max_length
         ]
+    else:
+        # `messages` datasets used by SFT rollout pass a list[dict] conversation as prompt.
+        # Use chat template tokenization to keep filtering behavior aligned with training tokenization.
+        filtered_samples = []
+        for sample in origin_samples:
+            tools = sample.metadata.get("tools") if isinstance(sample.metadata, dict) else None
+            input_ids = tokenizer.apply_chat_template(
+                sample.prompt,
+                tokenize=True,
+                add_generation_prompt=False,
+                tools=tools,
+                return_dict=False,
+            )
+            if len(input_ids) <= max_length:
+                filtered_samples.append(sample)
     logger.info(f"Filtered {len(origin_samples) - len(filtered_samples)} samples longer than max_length={max_length}.")
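
Note for reviewers: the new `else` branch filters chat-message prompts by tokenizing the whole conversation with the tokenizer's chat template and comparing the token count to `max_length`. Below is a minimal stand-alone sketch of that behavior for local sanity-checking; the checkpoint name, the toy conversations, and the 32-token cutoff are illustrative assumptions, not part of the patch, and the sketch omits the optional `tools` argument that the patch pulls from `sample.metadata`.

# sketch_message_prompt_filtering.py (not part of the patch)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint
max_length = 32  # placeholder cutoff; the real value comes from the rollout config

conversations = [
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
    [{"role": "user", "content": "please summarize this very long document " * 20}],
]

kept = []
for messages in conversations:
    # Same call shape as the patch: tokenize the full conversation with the chat template,
    # without appending a generation prompt, then compare the token count to max_length.
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=False,
        return_dict=False,
    )
    if len(input_ids) <= max_length:
        kept.append(messages)

print(f"kept {len(kept)} of {len(conversations)} conversations")

With these placeholder inputs the long second conversation is dropped, which mirrors how over-length SFT samples are filtered before rollout.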