diff --git a/examples/retool/generate_with_retool.py b/examples/retool/generate_with_retool.py
index a090af63d7..339b79dd45 100644
--- a/examples/retool/generate_with_retool.py
+++ b/examples/retool/generate_with_retool.py
@@ -216,6 +216,15 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
     """Custom generation function supporting tool calls"""
     assert not args.partial_rollout, "Partial rollout is not supported for " "this function at the moment."

+    # Retried samples (previously aborted / partial) arrive here with stale
+    # rollout state from the first attempt. Clear it so this generation starts
+    # clean; otherwise the concatenation below appends new tokens to old ones
+    # and downstream `slice_log_prob_with_cp` sees a length mismatch.
+    sample.rollout_log_probs = None
+    sample.response = ""
+    sample.response_length = 0
+    sample.loss_mask = None
+
     state = GenerateState(args)

     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
@@ -229,22 +238,37 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
     loss_masks = []
     tool_call_count = 0  # Track actual tool call rounds

+    if args.rollout_max_context_len is not None:
+        max_context_length = args.rollout_max_context_len
+    else:
+        max_context_length = args.context_parallel_size * args.max_tokens_per_gpu
+
     for turn in range(TOOL_CONFIGS["max_turns"]):
         # Check if total length exceeds max context length
         total_length = len(prompt_tokens_ids) + len(response_token_ids)
-        if args.rollout_max_context_len is not None:
-            max_context_length = args.rollout_max_context_len
-        else:
-            max_context_length = args.context_parallel_size * args.max_tokens_per_gpu
         if total_length >= max_context_length:
             sample.status = Sample.Status.TRUNCATED
             break

+        # Clamp per-turn max_new_tokens to the remaining context budget so a
+        # single turn cannot push total_length past max_context_length. Without
+        # this, a turn can append up to rollout_max_response_len tokens on top
+        # of a total that was just barely under the cap, producing samples
+        # that exceed the training-side max_tokens_per_gpu * cp_size budget
+        # and crash the partition/batch code (asserts or OOMs on an oversized
+        # partition).
+        remaining_budget = max_context_length - total_length
+        per_turn_sampling_params = dict(sampling_params)
+        per_turn_sampling_params["max_new_tokens"] = min(
+            sampling_params.get("max_new_tokens", remaining_budget),
+            remaining_budget,
+        )
+
         # Use token IDs instead of text
         current_token_ids = prompt_tokens_ids + response_token_ids

         payload = {
             "input_ids": current_token_ids,
-            "sampling_params": sampling_params,
+            "sampling_params": per_turn_sampling_params,
             "return_logprob": True,  # Request log probabilities for training
         }
@@ -285,9 +309,14 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
             sample.rollout_log_probs += cur_log_probs
         else:
-            cur_response = output["text"]
-            cur_response = postprocess_responses(cur_response)
-            cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
+            # sglang returned text but no output_token_logprobs; we cannot
+            # recover per-token logprobs for this turn, which would desync
+            # rollout_log_probs from response_token_ids and blow up
+            # `slice_log_prob_with_cp` downstream. Abort the sample so the
+            # fully_async rollout manager returns the whole group to the
+            # buffer for retry instead of poisoning the trainer.
+            sample.status = Sample.Status.ABORTED
+            return sample

         response += cur_response
         response_token_ids += cur_response_token_ids

@@ -321,6 +350,26 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
             sample.rollout_log_probs
         ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps"

+        # Tool output is appended verbatim and can push total_length past
+        # max_context_length (the per-turn generation was clamped to the
+        # remaining budget, but tool output is unconstrained). Trim tail
+        # tokens so the final sample fits the training budget exactly.
+        overflow = len(prompt_tokens_ids) + len(response_token_ids) - max_context_length
+        if overflow > 0:
+            response_token_ids = response_token_ids[:-overflow]
+            loss_masks = loss_masks[:-overflow]
+            if sample.rollout_log_probs is not None:
+                sample.rollout_log_probs = sample.rollout_log_probs[:-overflow]
+            # Resync the text field from the trimmed token list so
+            # reward_func's `sample.prompt + sample.response` matches what
+            # the model was actually trained on. decode(tokenize(text)) can
+            # be lossy on some tokenizers (whitespace / special-token
+            # collapse), but reward_func's regex is whitespace-robust and
+            # the trainer sees tokens, not text, so the drift is safe.
+            response = state.tokenizer.decode(response_token_ids)
+            sample.status = Sample.Status.TRUNCATED
+            break
+
         if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]:
             break
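
Reviewer note, not part of the patch: the two length invariants this diff enforces (per-turn generation clamped to the remaining context budget; post-tool trimming that keeps tokens, loss mask, and logprobs aligned) can be exercised in isolation. Below is a minimal self-contained sketch. fake_generate, fake_tool_output, and rollout are made-up stand-ins for the sglang server, the tool sandbox, and the generate() loop; only the clamp/trim arithmetic mirrors the patch.

import random


def fake_generate(max_new_tokens: int) -> list[int]:
    """Pretend model turn: emits between 1 and max_new_tokens token ids."""
    return [random.randrange(32000) for _ in range(random.randint(1, max_new_tokens))]


def fake_tool_output() -> list[int]:
    """Pretend tool result: unconstrained length, like real sandbox output."""
    return [random.randrange(32000) for _ in range(random.randint(0, 64))]


def rollout(prompt_len: int, max_context_length: int, max_new_tokens: int, max_turns: int) -> list[int]:
    response_ids: list[int] = []
    loss_mask: list[int] = []
    log_probs: list[float] = []

    for _ in range(max_turns):
        total = prompt_len + len(response_ids)
        if total >= max_context_length:
            break  # mirrors the TRUNCATED early exit

        # Invariant 1: a single turn can never push total past the cap.
        per_turn_budget = min(max_new_tokens, max_context_length - total)
        gen = fake_generate(per_turn_budget)
        response_ids += gen
        loss_mask += [1] * len(gen)    # model tokens are trained on
        log_probs += [0.0] * len(gen)

        tool = fake_tool_output()      # appended verbatim, never clamped
        response_ids += tool
        loss_mask += [0] * len(tool)   # tool tokens are masked out
        log_probs += [0.0] * len(tool)

        # Invariant 2: trim the tail so the final sample fits exactly,
        # keeping all three parallel lists the same length.
        overflow = prompt_len + len(response_ids) - max_context_length
        if overflow > 0:
            response_ids = response_ids[:-overflow]
            loss_mask = loss_mask[:-overflow]
            log_probs = log_probs[:-overflow]
            break

    assert prompt_len + len(response_ids) <= max_context_length
    assert len(response_ids) == len(loss_mask) == len(log_probs)
    return response_ids


if __name__ == "__main__":
    for seed in range(1000):
        random.seed(seed)
        rollout(prompt_len=128, max_context_length=256, max_new_tokens=200, max_turns=8)
    print("budget invariants hold")

Trimming after the fact, rather than clamping the tool output itself, is the simpler design: tool results are opaque strings, so cutting tail tokens once keeps the token/mask/logprob triple consistent without touching the tool protocol.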