From ebfadec925b346fc250ff2dbe2182158a1379aae Mon Sep 17 00:00:00 2001 From: qaijuang Date: Wed, 29 Apr 2026 07:48:41 -0400 Subject: [PATCH] Normalize long tool call IDs in agent requests --- app/src/ai/agent/api/impl.rs | 110 +++++++++++-- app/src/ai/agent/api/impl_tests.rs | 246 +++++++++++++++++++++++++++++ 2 files changed, 346 insertions(+), 10 deletions(-) diff --git a/app/src/ai/agent/api/impl.rs b/app/src/ai/agent/api/impl.rs index d5390b8b2..27a513582 100644 --- a/app/src/ai/agent/api/impl.rs +++ b/app/src/ai/agent/api/impl.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ai::agent::redaction, terminal::model::session::SessionType}; use futures_util::StreamExt; +use sha2::{Digest as _, Sha256}; use warp_core::features::FeatureFlag; use warp_multi_agent_api as api; @@ -9,6 +10,10 @@ use crate::server::server_api::ServerApi; use super::{convert_to::convert_input, ConvertToAPITypeError, RequestParams, ResponseStream}; +const MAX_EXTERNAL_TOOL_CALL_ID_LEN: usize = 64; +const NORMALIZED_TOOL_CALL_ID_PREFIX: &str = "tc_"; +const NORMALIZED_TOOL_CALL_ID_HASH_HEX_LEN: usize = 60; + pub async fn generate_multi_agent_output( server_api: Arc, mut params: RequestParams, @@ -19,6 +24,27 @@ pub async fn generate_multi_agent_output( .take() .unwrap_or_else(|| get_supported_tools(¶ms)); let supported_cli_agent_tools = get_supported_cli_agent_tools(¶ms); + let request = build_request(params, supported_tools, supported_cli_agent_tools)?; + + let response_stream = server_api.generate_multi_agent_output(&request).await; + match response_stream { + Ok(stream) => { + let output_stream = stream.take_until(cancellation_rx); + Ok(Box::pin(output_stream)) + } + Err(e) => { + let (tx, rx) = async_channel::unbounded(); + let _ = tx.send(Err(e)).await; + Ok(Box::pin(rx)) + } + } +} + +fn build_request( + mut params: RequestParams, + supported_tools: Vec, + supported_cli_agent_tools: Vec, +) -> Result { let mut logging_metadata = HashMap::new(); if let Some(metadata) = params.metadata { logging_metadata.insert( @@ -56,7 +82,7 @@ pub async fn generate_multi_agent_output( api_keys.allow_use_of_warp_credits = params.allow_use_of_warp_credits_with_byok; } - let request = api::Request { + let mut request = api::Request { task_context: Some(api::request::TaskContext { tasks: params.tasks, }), @@ -129,20 +155,84 @@ pub async fn generate_multi_agent_output( mcp_context: params.mcp_context.map(Into::into), }; - let response_stream = server_api.generate_multi_agent_output(&request).await; - match response_stream { - Ok(stream) => { - let output_stream = stream.take_until(cancellation_rx); - Ok(Box::pin(output_stream)) + normalize_external_tool_call_ids(&mut request); + Ok(request) +} + +// Some providers emit opaque tool call IDs that exceed downstream API limits. +// We normalize those IDs on the fully assembled outbound request so every +// provider-facing reference stays in sync without mutating internal client state. +fn normalize_external_tool_call_ids(request: &mut api::Request) { + if let Some(task_context) = &mut request.task_context { + for task in &mut task_context.tasks { + for message in &mut task.messages { + match &mut message.message { + Some(api::message::Message::ToolCall(tool_call)) => { + normalize_external_tool_call_id_in_place(&mut tool_call.tool_call_id); + } + Some(api::message::Message::ToolCallResult(tool_call_result)) => { + normalize_external_tool_call_id_in_place( + &mut tool_call_result.tool_call_id, + ); + } + _ => {} + } + } } - Err(e) => { - let (tx, rx) = async_channel::unbounded(); - let _ = tx.send(Err(e)).await; - Ok(Box::pin(rx)) + } + + if let Some(input) = &mut request.input { + normalize_external_tool_call_ids_in_input(input); + } +} + +fn normalize_external_tool_call_ids_in_input(input: &mut api::request::Input) { + let Some(api::request::input::Type::UserInputs(user_inputs)) = input.r#type.as_mut() else { + return; + }; + + for user_input in &mut user_inputs.inputs { + match user_input.input.as_mut() { + Some(api::request::input::user_inputs::user_input::Input::ToolCallResult( + tool_call_result, + )) => { + normalize_external_tool_call_id_in_place(&mut tool_call_result.tool_call_id); + } + Some(api::request::input::user_inputs::user_input::Input::CliAgentUserQuery( + cli_agent_user_query, + )) => { + normalize_external_tool_call_id_in_place( + &mut cli_agent_user_query.run_shell_command_tool_call_id, + ); + } + _ => {} } } } +fn normalize_external_tool_call_id_in_place(tool_call_id: &mut String) { + if tool_call_id.len() > MAX_EXTERNAL_TOOL_CALL_ID_LEN { + *tool_call_id = normalize_external_tool_call_id(tool_call_id); + } +} + +fn normalize_external_tool_call_id(tool_call_id: &str) -> String { + if tool_call_id.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN { + return tool_call_id.to_owned(); + } + + debug_assert!( + NORMALIZED_TOOL_CALL_ID_PREFIX.len() + NORMALIZED_TOOL_CALL_ID_HASH_HEX_LEN + <= MAX_EXTERNAL_TOOL_CALL_ID_LEN + ); + + let hash = hex::encode(Sha256::digest(tool_call_id.as_bytes())); + format!( + "{NORMALIZED_TOOL_CALL_ID_PREFIX}{}", + &hash[..NORMALIZED_TOOL_CALL_ID_HASH_HEX_LEN] + ) +} + fn get_supported_tools(params: &RequestParams) -> Vec { let mut supported_tools = vec![ api::ToolType::Grep, diff --git a/app/src/ai/agent/api/impl_tests.rs b/app/src/ai/agent/api/impl_tests.rs index c219ff5cc..2b71b53f2 100644 --- a/app/src/ai/agent/api/impl_tests.rs +++ b/app/src/ai/agent/api/impl_tests.rs @@ -79,3 +79,249 @@ fn supported_tools_omit_upload_artifact_when_feature_flag_is_disabled() { assert!(!supported_tools.contains(&api::ToolType::UploadFileArtifact)); } + +mod normalize_external_tool_call_ids { + use std::{collections::HashMap, sync::Arc}; + + use warp_core::command::ExitCode; + use warp_multi_agent_api as api; + + use crate::ai::agent::api::RequestParams; + use crate::ai::agent::task::TaskId; + use crate::ai::agent::{ + AIAgentActionResult, AIAgentActionResultType, AIAgentInput, RunningCommand, + TransferShellCommandControlToUserResult, UserQueryMode, + }; + use crate::terminal::model::block::BlockId; + + use super::super::{ + build_request, normalize_external_tool_call_id, MAX_EXTERNAL_TOOL_CALL_ID_LEN, + }; + use super::request_params_with_ask_user_question_enabled; + + fn build_request_for_test(params: RequestParams) -> api::Request { + build_request(params, vec![], vec![]).expect("request should build") + } + + fn build_request_with_tool_call_id(tool_call_id: &str) -> api::Request { + let tool_call_id = tool_call_id.to_string(); + let params = request_params_with_inputs_and_tasks( + vec![ + AIAgentInput::UserQuery { + query: "continue".to_string(), + context: Arc::new([]), + static_query_type: None, + referenced_attachments: HashMap::new(), + user_query_mode: UserQueryMode::Normal, + running_command: Some(RunningCommand { + command: "sleep 1".to_string(), + block_id: BlockId::default(), + grid_contents: "running".to_string(), + cursor: String::new(), + requested_command_id: Some(tool_call_id.clone().into()), + is_alt_screen_active: false, + }), + intended_agent: None, + }, + AIAgentInput::ActionResult { + result: AIAgentActionResult { + id: tool_call_id.clone().into(), + task_id: TaskId::new("task".to_string()), + result: AIAgentActionResultType::TransferShellCommandControlToUser( + TransferShellCommandControlToUserResult::CommandFinished { + block_id: BlockId::default(), + output: "done".to_string(), + exit_code: ExitCode::from(0), + }, + ), + }, + context: Arc::new([]), + }, + ], + vec![task_with_tool_call_history(&tool_call_id)], + ); + + build_request_for_test(params) + } + + fn request_params_with_inputs_and_tasks( + inputs: Vec, + tasks: Vec, + ) -> RequestParams { + let mut params = request_params_with_ask_user_question_enabled(false); + params.input = inputs; + params.tasks = tasks; + params + } + + fn tool_call_message(tool_call_id: &str) -> api::Message { + api::Message { + id: "tool-call-message".to_string(), + task_id: "task".to_string(), + request_id: "request".to_string(), + message: Some(api::message::Message::ToolCall(api::message::ToolCall { + tool_call_id: tool_call_id.to_string(), + tool: Some(api::message::tool_call::Tool::RunShellCommand( + api::message::tool_call::RunShellCommand { + command: "echo hi".to_string(), + ..Default::default() + }, + )), + })), + ..Default::default() + } + } + + fn tool_call_result_message(tool_call_id: &str) -> api::Message { + api::Message { + id: "tool-call-result-message".to_string(), + task_id: "task".to_string(), + request_id: "request".to_string(), + message: Some(api::message::Message::ToolCallResult( + api::message::ToolCallResult { + tool_call_id: tool_call_id.to_string(), + result: Some(api::message::tool_call_result::Result::RunShellCommand( + api::RunShellCommandResult { + command: "echo hi".to_string(), + output: "done".to_string(), + exit_code: 0, + ..Default::default() + }, + )), + ..Default::default() + }, + )), + ..Default::default() + } + } + + fn task_with_tool_call_history(tool_call_id: &str) -> api::Task { + api::Task { + id: "task".to_string(), + messages: vec![ + tool_call_message(tool_call_id), + tool_call_result_message(tool_call_id), + ], + ..Default::default() + } + } + + fn request_tool_call_ids(request: &api::Request) -> (String, String, String, String) { + let task_context = request + .task_context + .as_ref() + .expect("task context should exist"); + let task = task_context.tasks.first().expect("task should exist"); + + let task_tool_call_id = match task + .messages + .first() + .and_then(|message| message.message.as_ref()) + { + Some(api::message::Message::ToolCall(tool_call)) => tool_call.tool_call_id.clone(), + other => panic!("expected tool call message, got {other:?}"), + }; + + let task_tool_call_result_id = match task + .messages + .get(1) + .and_then(|message| message.message.as_ref()) + { + Some(api::message::Message::ToolCallResult(tool_call_result)) => { + tool_call_result.tool_call_id.clone() + } + other => panic!("expected tool call result message, got {other:?}"), + }; + + let user_inputs = match request + .input + .as_ref() + .and_then(|input| input.r#type.as_ref()) + { + Some(api::request::input::Type::UserInputs(user_inputs)) => user_inputs, + other => panic!("expected user inputs request, got {other:?}"), + }; + + let cli_agent_tool_call_id = user_inputs + .inputs + .iter() + .find_map(|user_input| match user_input.input.as_ref() { + Some(api::request::input::user_inputs::user_input::Input::CliAgentUserQuery( + cli_agent_query, + )) => Some(cli_agent_query.run_shell_command_tool_call_id.clone()), + _ => None, + }) + .expect("cli agent user query should exist"); + + let action_result_tool_call_id = user_inputs + .inputs + .iter() + .find_map(|user_input| match user_input.input.as_ref() { + Some(api::request::input::user_inputs::user_input::Input::ToolCallResult( + tool_call_result, + )) => Some(tool_call_result.tool_call_id.clone()), + _ => None, + }) + .expect("tool call result input should exist"); + + ( + task_tool_call_id, + task_tool_call_result_id, + cli_agent_tool_call_id, + action_result_tool_call_id, + ) + } + + #[test] + fn build_request_normalizes_over_limit_tool_call_ids_consistently() { + let over_limit_tool_call_id = "a".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN + 1); + let normalized_tool_call_id = normalize_external_tool_call_id(&over_limit_tool_call_id); + let request = build_request_with_tool_call_id(&over_limit_tool_call_id); + let ( + task_tool_call_id, + task_tool_call_result_id, + cli_agent_tool_call_id, + action_result_tool_call_id, + ) = request_tool_call_ids(&request); + + assert!(normalized_tool_call_id.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN); + assert_ne!(normalized_tool_call_id, over_limit_tool_call_id); + assert_eq!(task_tool_call_id, normalized_tool_call_id); + assert_eq!(task_tool_call_result_id, normalized_tool_call_id); + assert_eq!(cli_agent_tool_call_id, normalized_tool_call_id); + assert_eq!(action_result_tool_call_id, normalized_tool_call_id); + } + + #[test] + fn build_request_preserves_at_limit_tool_call_ids() { + let at_limit_tool_call_id = "a".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN); + let request = build_request_with_tool_call_id(&at_limit_tool_call_id); + let ( + task_tool_call_id, + task_tool_call_result_id, + cli_agent_tool_call_id, + action_result_tool_call_id, + ) = request_tool_call_ids(&request); + + assert_eq!(task_tool_call_id, at_limit_tool_call_id); + assert_eq!(task_tool_call_result_id, at_limit_tool_call_id); + assert_eq!(cli_agent_tool_call_id, at_limit_tool_call_id); + assert_eq!(action_result_tool_call_id, at_limit_tool_call_id); + } + + #[test] + fn normalized_tool_call_ids_are_stable_and_distinct_for_over_limit_inputs() { + let first = "a".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN + 1); + let second = "b".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN + 1); + + let first_normalized = normalize_external_tool_call_id(&first); + let second_normalized = normalize_external_tool_call_id(&second); + + assert_eq!(first_normalized, normalize_external_tool_call_id(&first)); + assert_ne!(first_normalized, second_normalized); + assert_ne!(first_normalized, first); + assert_ne!(second_normalized, second); + assert!(first_normalized.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN); + assert!(second_normalized.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN); + } +}