19 changes: 16 additions & 3 deletions scripts/run-qwen3.5-35B-A3B-sft.sh
@@ -46,6 +46,18 @@ CKPT_ARGS=(
--save-interval 20
)

SFT_MAX_PROMPT_TOKENS=${SFT_MAX_PROMPT_TOKENS:-8192}
MAX_TOKENS_PER_GPU=${MAX_TOKENS_PER_GPU:-8192}
CP_SIZE=${CP_SIZE:-1}
echo "SFT_MAX_PROMPT_TOKENS: ${SFT_MAX_PROMPT_TOKENS}"
echo "MAX_TOKENS_PER_GPU: ${MAX_TOKENS_PER_GPU}"
echo "CP_SIZE: ${CP_SIZE}"

if (( SFT_MAX_PROMPT_TOKENS > MAX_TOKENS_PER_GPU * CP_SIZE )); then
echo "Invalid config: SFT_MAX_PROMPT_TOKENS (${SFT_MAX_PROMPT_TOKENS}) must be <= MAX_TOKENS_PER_GPU * CP_SIZE (${MAX_TOKENS_PER_GPU} * ${CP_SIZE} = $((MAX_TOKENS_PER_GPU * CP_SIZE)))."
exit 1
fi

SFT_ARGS=(
--rollout-function-path slime.rollout.sft_rollout.generate_rollout
--prompt-data ${BASE_FOLDER}/openhermes2_5.parquet
@@ -54,6 +66,7 @@ SFT_ARGS=(
--num-epoch 3
--rollout-batch-size 128
--global-batch-size 128
--rollout-max-prompt-len ${SFT_MAX_PROMPT_TOKENS}

--loss-type sft_loss
--loss-mask-type qwen3_5
@@ -66,7 +79,7 @@ PERF_ARGS=(
--tensor-model-parallel-size 2
--sequence-parallel
--pipeline-model-parallel-size 1
--context-parallel-size 1
--context-parallel-size ${CP_SIZE}
--expert-model-parallel-size 8
--expert-tensor-parallel-size 1

@@ -76,7 +89,7 @@ PERF_ARGS=(

# --micro-batch-size 1
--use-dynamic-batch-size
--max-tokens-per-gpu 8192
--max-tokens-per-gpu ${MAX_TOKENS_PER_GPU}
)

OPTIMIZER_ARGS=(
@@ -88,7 +101,7 @@ OPTIMIZER_ARGS=(
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98

--use-distributed-optimizer
--optimizer-cpu-offload
--overlap-cpu-optimizer-d2h-h2d
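Usage note: the new guard ties the prompt-length cap to the per-GPU token budget multiplied by the context-parallel degree, since context parallelism shards a single sequence across CP_SIZE ranks. A minimal invocation sketch using the environment variables introduced above (the numbers are illustrative, not recommendations):

# Accepted: 16384 <= 8192 * 2
SFT_MAX_PROMPT_TOKENS=16384 MAX_TOKENS_PER_GPU=8192 CP_SIZE=2 \
  bash scripts/run-qwen3.5-35B-A3B-sft.sh

# Rejected: 16384 > 8192 * 1, so the guard prints the config error and exits 1
SFT_MAX_PROMPT_TOKENS=16384 MAX_TOKENS_PER_GPU=8192 CP_SIZE=1 \
  bash scripts/run-qwen3.5-35B-A3B-sft.sh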
30 changes: 26 additions & 4 deletions slime/utils/data.py
@@ -82,13 +82,19 @@ def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_l
if max_length is None:
return origin_samples

if not isinstance(origin_samples[0].prompt, str):
# Keep filtering compatible with both text datasets and chat-message datasets used by SFT.
first_prompt = origin_samples[0].prompt
supports_text_prompt = isinstance(first_prompt, str)
supports_message_prompt = isinstance(first_prompt, list)
if not (supports_text_prompt or supports_message_prompt):
# Skip the max_length check for unknown prompt shapes to avoid breaking custom data pipelines.
logger.warning(
"Skipping max_length check for list prompt. Set apply_chat_template=True to enable length filtering."
"Skipping max_length check for unsupported prompt type %s. Set apply_chat_template=True to enable length filtering.",
type(first_prompt),
)
return origin_samples

if processor:
if processor and supports_text_prompt:
# Use processor only for samples with actual multimodal content; use batched tokenizer for text-only.
text_only = []
multimodal = []
@@ -113,14 +119,30 @@
input_ids = processor_output["input_ids"][0]
if len(input_ids) <= max_length:
filtered_samples.append(sample)
else:
elif supports_text_prompt:
# Fast path for text-only prompts: tokenize in batch for throughput.
prompts = [sample.prompt for sample in origin_samples]
input_ids_list = tokenizer(prompts, add_special_tokens=False)["input_ids"]
filtered_samples = [
sample
for sample, input_ids in zip(origin_samples, input_ids_list, strict=True)
if len(input_ids) <= max_length
]
else:
# `messages` datasets used by SFT rollout pass a list[dict] conversation as prompt.
# Use chat template tokenization to keep filtering behavior aligned with training tokenization.
filtered_samples = []
for sample in origin_samples:
tools = sample.metadata.get("tools") if isinstance(sample.metadata, dict) else None
input_ids = tokenizer.apply_chat_template(
sample.prompt,
tokenize=True,
add_generation_prompt=False,
tools=tools,
return_dict=False,
)
if len(input_ids) <= max_length:
filtered_samples.append(sample)

logger.info(f"Filtered {len(origin_samples) - len(filtered_samples)} samples longer than max_length={max_length}.")

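For context, the new else branch measures chat-message prompts with the tokenizer's chat template so that filtering matches what training will actually tokenize. A minimal, self-contained sketch of that call shape (the checkpoint name is only an example; any Hugging Face chat tokenizer behaves the same):

from transformers import AutoTokenizer

# Example checkpoint; an assumption for illustration only.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# A `messages`-style prompt: the list[dict] shape the SFT rollout passes in.
messages = [
    {"role": "user", "content": "What does context parallelism do?"},
    {"role": "assistant", "content": "It shards one long sequence across GPUs."},
]

# Mirrors the new filtering branch: tokenize the templated conversation
# (no generation prompt) and compare its length against the cap.
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,
    tools=None,  # filter_long_prompt reads this from sample.metadata["tools"]
    return_dict=False,
)
max_length = 8192
print(len(input_ids) <= max_length)  # keep the sample iff True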