Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions scripts/run-glm5-744B-A40B-sft.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#!/bin/bash

# SFT launcher for GLM-5 744B-A40B on a multi-node Ray cluster.
# Requires: BASE_DIR (checkpoint root) and MASTER_ADDR (Ray head node IP)
# to be exported by the caller.

# Clean up any leftover processes from a previous run so GPUs, ports and
# Ray cluster state are free before relaunching.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

if [ -z "${BASE_DIR}" ]; then
    echo "BASE_DIR is not set. Please set it to the base directory of your checkpoints."
    exit 1
fi

if [ -z "${MASTER_ADDR}" ]; then
    echo "MASTER_ADDR is not set. Please set it to the master node address."
    exit 1
fi

# Prevent Python (and therefore ray workers) from buffering stdout/stderr.
# NOTE: the variable Python honors is PYTHONUNBUFFERED (any non-empty value);
# the previous PYTHONBUFFERED spelling was a typo and had no effect.
export PYTHONUNBUFFERED=1

# Detect NVLink by counting NVx entries in the GPU topology matrix; used to
# decide whether NCCL NVLS collectives can be enabled in the runtime env.
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

# Pull in MODEL_ARGS (and presumably EVAL_ARGS) from the shared model config.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/glm5-744B-A40B.sh"

CKPT_ARGS=(
    --hf-checkpoint $BASE_DIR/GLM-5
    --ref-load $BASE_DIR/GLM-5_torch_dist/
    --load $BASE_DIR/GLM-5_slime/
    --save $BASE_DIR/GLM-5_slime/
    --save-interval 1000
)

SFT_ARGS=(
    --rollout-function-path slime.rollout.sft_rollout.generate_rollout
    --prompt-data $BASE_DIR/openhermes2_5.parquet
    --input-key messages
    --rollout-shuffle
    --num-epoch 3
    --rollout-batch-size 128
    --global-batch-size 128

    --loss-type sft_loss
    --loss-mask-type glm5
    --calculate-per-token-loss
    --disable-compute-advantages-and-returns
    --debug-train-only
)

PERF_ARGS=(
    --tensor-model-parallel-size 4
    --sequence-parallel
    --pipeline-model-parallel-size 4
    --decoder-last-pipeline-num-layers 18
    --expert-model-parallel-size 32
    --expert-tensor-parallel-size 1
    --context-parallel-size 2

    --recompute-granularity full
    --recompute-method uniform
    --recompute-num-layers 1

    --use-dynamic-batch-size
    --max-tokens-per-gpu 16384
    --data-pad-size-multiplier 4096
)

OPTIMIZER_ARGS=(
    --optimizer adam
    --lr 1e-5
    --lr-decay-style cosine
    --min-lr 1e-6
    --lr-warmup-fraction 0.1
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.98

    --optimizer-cpu-offload
    --overlap-cpu-optimizer-d2h-h2d
    --use-precision-aware-optimizer
)

WANDB_ARGS=(
    # --use-wandb
    # --wandb-project slime-dev
    # --wandb-group glm5-sft
    # --wandb-key ${WANDB_KEY}
)

MISC_ARGS=(
    # default dropout in megatron is 0.1
    --attention-dropout 0.0
    --hidden-dropout 0.0
    # should be good for model performance
    --accumulate-allreduce-grads-in-fp32
    --attention-softmax-in-fp32
    # need to comment this when using model with MLA
    --attention-backend flash

    # use deepep for megatron
    --moe-enable-deepep
    --moe-token-dispatcher-type flex
)

# Launch the Ray head on this node, then start a worker on every other host
# listed in the MPI hostfile (skipping the head itself).
export no_proxy="127.0.0.1,${MASTER_ADDR}"
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
# NOTE(review): MLP_WORKER_0_HOST comes from the platform environment and is
# assumed to equal MASTER_ADDR's host entry — confirm against the scheduler.
for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do
    if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then
        continue
    fi
    echo "Starting Ray worker on ${WORKER_IP}"
    ssh root@"${WORKER_IP}" \
        "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" &
done
wait


# Build the runtime environment JSON with proper variable substitution
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
    \"no_proxy\": \"${no_proxy}\",
    \"MASTER_ADDR\": \"${MASTER_ADDR}\",
    \"INDEXER_ROPE_NEOX_STYLE\": \"0\",
    \"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK\": \"32\",
    \"NVSHMEM_DISABLE_NCCL\": \"1\",
    \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\"
  }
}"

# Quote the array expansions so individual arguments are preserved verbatim.
# NOTE(review): EVAL_ARGS is expected to be defined by the sourced model
# config above — verify, since it is not set in this script.
ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json="${RUNTIME_ENV_JSON}" \
    -- python3 train_async.py \
    --actor-num-nodes 32 \
    --actor-num-gpus-per-node 8 \
    "${MODEL_ARGS[@]}" \
    "${CKPT_ARGS[@]}" \
    "${SFT_ARGS[@]}" \
    "${OPTIMIZER_ARGS[@]}" \
    "${WANDB_ARGS[@]}" \
    "${PERF_ARGS[@]}" \
    "${EVAL_ARGS[@]}" \
    "${MISC_ARGS[@]}"
2 changes: 1 addition & 1 deletion slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,7 @@ def add_rollout_buffer_arguments(parser):
"--loss-mask-type",
type=str,
default="qwen",
choices=["qwen", "qwen3", "qwen3_5", "distill_qwen"],
choices=["qwen", "qwen3", "qwen3_5", "glm5", "distill_qwen"],
help="Loss mask type",
)
parser.add_argument(
Expand Down
96 changes: 96 additions & 0 deletions slime/utils/mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,100 @@ def gen_multi_turn_loss_mask_qwen3_5(

return token_ids, loss_mask

def gen_multi_turn_loss_mask_glm5(
self, messages: list[dict], tools: list[dict] = None
) -> tuple[list[int], list[int]]:
rendered_text = self.tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, return_dict=False)
template_token_ids = self.tokenizer.apply_chat_template(
messages, tokenize=True, tools=tools, return_dict=False
)

stop_words = ("<|endoftext|>", "<|user|>", "<|observation|>")
final_stop = "<|user|>"
append_final_stop = (
bool(messages)
and messages[-1]["role"] == "assistant"
and messages[-1].get("step_loss_mask", 1) == 1
and not rendered_text.endswith(stop_words)
)
if append_final_stop:
rendered_text += final_stop

tokenized = self.tokenizer(rendered_text, add_special_tokens=False, return_offsets_mapping=True)
token_ids = tokenized["input_ids"]
offset_mapping = tokenized.get("offset_mapping")

if offset_mapping is None:
raise ValueError(
"GLM5 loss mask generation requires a fast tokenizer with `return_offsets_mapping` support."
)

if token_ids[: len(template_token_ids)] != template_token_ids:
raise ValueError(
"GLM5 rendered text tokenization does not match "
"`apply_chat_template(..., tokenize=True)` output."
)

assistant_header = "<|assistant|>"
stop_markers = ("<|endoftext|>", "<|user|>", "<|observation|>")
think_prefix = "<think>"
no_think_prefix = "</think>"

char_mask = [0] * len(rendered_text)
cursor = 0

for message in messages:
if message["role"] != "assistant":
continue

header_pos = rendered_text.find(assistant_header, cursor)
if header_pos < 0:
raise ValueError("Failed to locate assistant message in rendered GLM5 chat template output.")

content_start = header_pos + len(assistant_header)
next_stop = min(
(
(marker_pos, marker)
for marker in stop_markers
if (marker_pos := rendered_text.find(marker, content_start)) >= 0
),
default=None,
)
span_end = next_stop[0] if next_stop is not None else len(rendered_text)
cursor = span_end

if message.get("step_loss_mask", 1) != 1:
continue

mask_start = content_start
while mask_start < span_end and rendered_text[mask_start].isspace():
mask_start += 1

if rendered_text.startswith(think_prefix, mask_start):
mask_start += len(think_prefix)
elif rendered_text.startswith(no_think_prefix, mask_start):
mask_start += len(no_think_prefix)

for pos in range(mask_start, span_end):
char_mask[pos] = 1
if next_stop is not None:
stop_start, stop_marker = next_stop
for pos in range(stop_start, stop_start + len(stop_marker)):
char_mask[pos] = 1

char_mask_prefix_sum = [0]
for value in char_mask:
char_mask_prefix_sum.append(char_mask_prefix_sum[-1] + value)

loss_mask = []
for start, end in offset_mapping:
if end <= start:
loss_mask.append(0)
else:
loss_mask.append(1 if char_mask_prefix_sum[end] - char_mask_prefix_sum[start] > 0 else 0)

return token_ids, loss_mask

def gen_multi_turn_loss_mask_distill_qwen(
self, messages: list[dict], tools: list[dict] = None
) -> tuple[list[int], list[int]]:
Expand Down Expand Up @@ -223,6 +317,8 @@ def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple
return self.gen_multi_turn_loss_mask_qwen3(messages, tools)
elif self.tokenizer_type == "qwen3_5":
return self.gen_multi_turn_loss_mask_qwen3_5(messages, tools)
elif self.tokenizer_type == "glm5":
return self.gen_multi_turn_loss_mask_glm5(messages, tools)
elif self.tokenizer_type == "distill_qwen":
return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools)
else:
Expand Down
Loading