65 changes: 57 additions & 8 deletions examples/retool/generate_with_retool.py
@@ -216,6 +216,15 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
"""Custom generation function supporting tool calls"""
assert not args.partial_rollout, "Partial rollout is not supported for " "this function at the moment."

# Retried samples (previously aborted / partial) arrive here with stale
# rollout state from the first attempt. Clear it so this generation starts
# clean; otherwise the concatenation below appends new tokens to old ones
# and downstream `slice_log_prob_with_cp` sees a length mismatch.
sample.rollout_log_probs = None
sample.response = ""
sample.response_length = 0
sample.loss_mask = None

state = GenerateState(args)
url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

@@ -229,22 +238,37 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
loss_masks = []
tool_call_count = 0 # Track actual tool call rounds

if args.rollout_max_context_len is not None:
max_context_length = args.rollout_max_context_len
else:
max_context_length = args.context_parallel_size * args.max_tokens_per_gpu

for turn in range(TOOL_CONFIGS["max_turns"]):
# Check if total length exceeds max context length
total_length = len(prompt_tokens_ids) + len(response_token_ids)
if args.rollout_max_context_len is not None:
max_context_length = args.rollout_max_context_len
else:
max_context_length = args.context_parallel_size * args.max_tokens_per_gpu
if total_length >= max_context_length:
sample.status = Sample.Status.TRUNCATED
break

# Clamp per-turn max_new_tokens to the remaining context budget so a
# single turn cannot push total_length past max_context_length. Without
# this, a turn can append up to rollout_max_response_len tokens on top
# of a total that was just barely under the cap, producing samples
# that exceed the training-side max_tokens_per_gpu * cp_size budget
# and crash the partition/batch code (asserts or OOMs on an oversized
# partition).
remaining_budget = max_context_length - total_length
per_turn_sampling_params = dict(sampling_params)
per_turn_sampling_params["max_new_tokens"] = min(
sampling_params.get("max_new_tokens", remaining_budget),
remaining_budget,
)

# Use token IDs instead of text
current_token_ids = prompt_tokens_ids + response_token_ids
payload = {
"input_ids": current_token_ids,
"sampling_params": sampling_params,
"sampling_params": per_turn_sampling_params,
"return_logprob": True, # Request log probabilities for training
}

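As a standalone illustration of the per-turn clamp added in this hunk, here is a minimal sketch outside the diff context. The helper name `clamp_sampling_params` is hypothetical; it assumes `sampling_params` is a plain dict as in the diff, and only mirrors the `min(max_new_tokens, remaining_budget)` bookkeeping shown above.

```python
# Hypothetical helper mirroring the per-turn budget clamp from the diff above.
# Assumes sampling_params is a plain dict, e.g. {"max_new_tokens": 1024, ...}.
def clamp_sampling_params(sampling_params: dict, max_context_length: int, total_length: int) -> dict:
    remaining_budget = max_context_length - total_length
    per_turn = dict(sampling_params)  # copy so the shared dict is not mutated across turns
    per_turn["max_new_tokens"] = min(
        sampling_params.get("max_new_tokens", remaining_budget),
        remaining_budget,
    )
    return per_turn

# Example: an 8192-token budget with 8000 tokens already used
# leaves at most 192 new tokens for this turn.
assert clamp_sampling_params({"max_new_tokens": 1024}, 8192, 8000)["max_new_tokens"] == 192
```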
@@ -285,9 +309,14 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
sample.rollout_log_probs += cur_log_probs

else:
cur_response = output["text"]
cur_response = postprocess_responses(cur_response)
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
# sglang returned text but no output_token_logprobs — we cannot
# recover per-token logprobs for this turn, which would desync
# rollout_log_probs from response_token_ids and blow up
# `slice_log_prob_with_cp` downstream. Abort the sample so the
# fully_async rollout manager returns the whole group to the
# buffer for retry instead of poisoning the trainer.
sample.status = Sample.Status.ABORTED
return sample

response += cur_response
response_token_ids += cur_response_token_ids
@@ -321,6 +350,26 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
sample.rollout_log_probs
), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps"

# Tool output is appended verbatim and can push total_length past
# max_context_length (the per-turn generation was clamped to the
# remaining budget, but tool output is unconstrained). Trim tail
# tokens so the final sample fits the training budget exactly.
overflow = len(prompt_tokens_ids) + len(response_token_ids) - max_context_length
if overflow > 0:
response_token_ids = response_token_ids[:-overflow]
loss_masks = loss_masks[:-overflow]
if sample.rollout_log_probs is not None:
sample.rollout_log_probs = sample.rollout_log_probs[:-overflow]
# Resync the text field from the trimmed token list so
# reward_func's `sample.prompt + sample.response` matches what
# the model was actually trained on. decode(tokenize(text)) can
# be lossy on some tokenizers (whitespace / special-token
# collapse), but reward_func's regex is whitespace-robust and
# the trainer sees tokens, not text — so the drift is safe.
response = state.tokenizer.decode(response_token_ids)
sample.status = Sample.Status.TRUNCATED
break

if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]:
break

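For readers outside the diff context, a self-contained sketch of the tail-trim step in the last hunk. The function name `trim_overflow` and the list-of-floats logprob representation are assumptions for illustration; the real code operates in place on the `Sample` object shown above.

```python
# Hypothetical standalone version of the overflow trim: drop tail tokens (and the
# matching loss-mask / logprob entries) so prompt + response fits max_context_length exactly.
def trim_overflow(prompt_ids, response_ids, loss_masks, logprobs, max_context_length):
    overflow = len(prompt_ids) + len(response_ids) - max_context_length
    if overflow <= 0:
        return response_ids, loss_masks, logprobs
    return (
        response_ids[:-overflow],
        loss_masks[:-overflow],
        logprobs[:-overflow] if logprobs is not None else None,
    )

# Example: a 4-token prompt plus a 6-token response against an 8-token budget
# trims 2 tail tokens and keeps all three sequences the same length.
ids, masks, lps = trim_overflow([1, 2, 3, 4], [5, 6, 7, 8, 9, 10], [1] * 6, [0.0] * 6, 8)
assert len(ids) == len(masks) == len(lps) == 4
```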