Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 100 additions & 10 deletions app/src/ai/agent/api/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ use std::{collections::HashMap, sync::Arc};

use crate::{ai::agent::redaction, terminal::model::session::SessionType};
use futures_util::StreamExt;
use sha2::{Digest as _, Sha256};
use warp_core::features::FeatureFlag;
use warp_multi_agent_api as api;

use crate::server::server_api::ServerApi;

use super::{convert_to::convert_input, ConvertToAPITypeError, RequestParams, ResponseStream};

const MAX_EXTERNAL_TOOL_CALL_ID_LEN: usize = 64;
const NORMALIZED_TOOL_CALL_ID_PREFIX: &str = "tc_";
const NORMALIZED_TOOL_CALL_ID_HASH_HEX_LEN: usize = 60;

pub async fn generate_multi_agent_output(
server_api: Arc<ServerApi>,
mut params: RequestParams,
Expand All @@ -19,6 +24,27 @@ pub async fn generate_multi_agent_output(
.take()
.unwrap_or_else(|| get_supported_tools(&params));
let supported_cli_agent_tools = get_supported_cli_agent_tools(&params);
let request = build_request(params, supported_tools, supported_cli_agent_tools)?;

let response_stream = server_api.generate_multi_agent_output(&request).await;
match response_stream {
Ok(stream) => {
let output_stream = stream.take_until(cancellation_rx);
Ok(Box::pin(output_stream))
}
Err(e) => {
let (tx, rx) = async_channel::unbounded();
let _ = tx.send(Err(e)).await;
Ok(Box::pin(rx))
}
}
}

fn build_request(
mut params: RequestParams,
supported_tools: Vec<api::ToolType>,
supported_cli_agent_tools: Vec<api::ToolType>,
) -> Result<api::Request, ConvertToAPITypeError> {
let mut logging_metadata = HashMap::new();
if let Some(metadata) = params.metadata {
logging_metadata.insert(
Expand Down Expand Up @@ -56,7 +82,7 @@ pub async fn generate_multi_agent_output(
api_keys.allow_use_of_warp_credits = params.allow_use_of_warp_credits_with_byok;
}

let request = api::Request {
let mut request = api::Request {
task_context: Some(api::request::TaskContext {
tasks: params.tasks,
}),
Expand Down Expand Up @@ -129,20 +155,84 @@ pub async fn generate_multi_agent_output(
mcp_context: params.mcp_context.map(Into::into),
};

let response_stream = server_api.generate_multi_agent_output(&request).await;
match response_stream {
Ok(stream) => {
let output_stream = stream.take_until(cancellation_rx);
Ok(Box::pin(output_stream))
normalize_external_tool_call_ids(&mut request);
Ok(request)
}

// Some providers emit opaque tool call IDs that exceed downstream API limits.
// We normalize those IDs on the fully assembled outbound request so every
// provider-facing reference stays in sync without mutating internal client state.
fn normalize_external_tool_call_ids(request: &mut api::Request) {
if let Some(task_context) = &mut request.task_context {
for task in &mut task_context.tasks {
for message in &mut task.messages {
match &mut message.message {
Some(api::message::Message::ToolCall(tool_call)) => {
normalize_external_tool_call_id_in_place(&mut tool_call.tool_call_id);
}
Some(api::message::Message::ToolCallResult(tool_call_result)) => {
normalize_external_tool_call_id_in_place(
&mut tool_call_result.tool_call_id,
);
}
_ => {}
}
}
}
Err(e) => {
let (tx, rx) = async_channel::unbounded();
let _ = tx.send(Err(e)).await;
Ok(Box::pin(rx))
}

if let Some(input) = &mut request.input {
normalize_external_tool_call_ids_in_input(input);
}
}

fn normalize_external_tool_call_ids_in_input(input: &mut api::request::Input) {
let Some(api::request::input::Type::UserInputs(user_inputs)) = input.r#type.as_mut() else {
return;
};

for user_input in &mut user_inputs.inputs {
match user_input.input.as_mut() {
Some(api::request::input::user_inputs::user_input::Input::ToolCallResult(
tool_call_result,
)) => {
normalize_external_tool_call_id_in_place(&mut tool_call_result.tool_call_id);
}
Some(api::request::input::user_inputs::user_input::Input::CliAgentUserQuery(
cli_agent_user_query,
)) => {
normalize_external_tool_call_id_in_place(
&mut cli_agent_user_query.run_shell_command_tool_call_id,
);
}
_ => {}
}
}
}

fn normalize_external_tool_call_id_in_place(tool_call_id: &mut String) {
if tool_call_id.len() > MAX_EXTERNAL_TOOL_CALL_ID_LEN {
*tool_call_id = normalize_external_tool_call_id(tool_call_id);
}
}

fn normalize_external_tool_call_id(tool_call_id: &str) -> String {
if tool_call_id.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN {
return tool_call_id.to_owned();
}

debug_assert!(
NORMALIZED_TOOL_CALL_ID_PREFIX.len() + NORMALIZED_TOOL_CALL_ID_HASH_HEX_LEN
<= MAX_EXTERNAL_TOOL_CALL_ID_LEN
);

let hash = hex::encode(Sha256::digest(tool_call_id.as_bytes()));
format!(
"{NORMALIZED_TOOL_CALL_ID_PREFIX}{}",
&hash[..NORMALIZED_TOOL_CALL_ID_HASH_HEX_LEN]
)
}

fn get_supported_tools(params: &RequestParams) -> Vec<api::ToolType> {
let mut supported_tools = vec![
api::ToolType::Grep,
Expand Down
246 changes: 246 additions & 0 deletions app/src/ai/agent/api/impl_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,249 @@ fn supported_tools_omit_upload_artifact_when_feature_flag_is_disabled() {

assert!(!supported_tools.contains(&api::ToolType::UploadFileArtifact));
}

mod normalize_external_tool_call_ids {
use std::{collections::HashMap, sync::Arc};

use warp_core::command::ExitCode;
use warp_multi_agent_api as api;

use crate::ai::agent::api::RequestParams;
use crate::ai::agent::task::TaskId;
use crate::ai::agent::{
AIAgentActionResult, AIAgentActionResultType, AIAgentInput, RunningCommand,
TransferShellCommandControlToUserResult, UserQueryMode,
};
use crate::terminal::model::block::BlockId;

use super::super::{
build_request, normalize_external_tool_call_id, MAX_EXTERNAL_TOOL_CALL_ID_LEN,
};
use super::request_params_with_ask_user_question_enabled;

fn build_request_for_test(params: RequestParams) -> api::Request {
build_request(params, vec![], vec![]).expect("request should build")
}

fn build_request_with_tool_call_id(tool_call_id: &str) -> api::Request {
let tool_call_id = tool_call_id.to_string();
let params = request_params_with_inputs_and_tasks(
vec![
AIAgentInput::UserQuery {
query: "continue".to_string(),
context: Arc::new([]),
static_query_type: None,
referenced_attachments: HashMap::new(),
user_query_mode: UserQueryMode::Normal,
running_command: Some(RunningCommand {
command: "sleep 1".to_string(),
block_id: BlockId::default(),
grid_contents: "running".to_string(),
cursor: String::new(),
requested_command_id: Some(tool_call_id.clone().into()),
is_alt_screen_active: false,
}),
intended_agent: None,
},
AIAgentInput::ActionResult {
result: AIAgentActionResult {
id: tool_call_id.clone().into(),
task_id: TaskId::new("task".to_string()),
result: AIAgentActionResultType::TransferShellCommandControlToUser(
TransferShellCommandControlToUserResult::CommandFinished {
block_id: BlockId::default(),
output: "done".to_string(),
exit_code: ExitCode::from(0),
},
),
},
context: Arc::new([]),
},
],
vec![task_with_tool_call_history(&tool_call_id)],
);

build_request_for_test(params)
}

fn request_params_with_inputs_and_tasks(
inputs: Vec<AIAgentInput>,
tasks: Vec<api::Task>,
) -> RequestParams {
let mut params = request_params_with_ask_user_question_enabled(false);
params.input = inputs;
params.tasks = tasks;
params
}

fn tool_call_message(tool_call_id: &str) -> api::Message {
api::Message {
id: "tool-call-message".to_string(),
task_id: "task".to_string(),
request_id: "request".to_string(),
message: Some(api::message::Message::ToolCall(api::message::ToolCall {
tool_call_id: tool_call_id.to_string(),
tool: Some(api::message::tool_call::Tool::RunShellCommand(
api::message::tool_call::RunShellCommand {
command: "echo hi".to_string(),
..Default::default()
},
)),
})),
..Default::default()
}
}

fn tool_call_result_message(tool_call_id: &str) -> api::Message {
api::Message {
id: "tool-call-result-message".to_string(),
task_id: "task".to_string(),
request_id: "request".to_string(),
message: Some(api::message::Message::ToolCallResult(
api::message::ToolCallResult {
tool_call_id: tool_call_id.to_string(),
result: Some(api::message::tool_call_result::Result::RunShellCommand(
api::RunShellCommandResult {
command: "echo hi".to_string(),
output: "done".to_string(),
exit_code: 0,
..Default::default()
},
)),
..Default::default()
},
)),
..Default::default()
}
}

fn task_with_tool_call_history(tool_call_id: &str) -> api::Task {
api::Task {
id: "task".to_string(),
messages: vec![
tool_call_message(tool_call_id),
tool_call_result_message(tool_call_id),
],
..Default::default()
}
}

fn request_tool_call_ids(request: &api::Request) -> (String, String, String, String) {
let task_context = request
.task_context
.as_ref()
.expect("task context should exist");
let task = task_context.tasks.first().expect("task should exist");

let task_tool_call_id = match task
.messages
.first()
.and_then(|message| message.message.as_ref())
{
Some(api::message::Message::ToolCall(tool_call)) => tool_call.tool_call_id.clone(),
other => panic!("expected tool call message, got {other:?}"),
};

let task_tool_call_result_id = match task
.messages
.get(1)
.and_then(|message| message.message.as_ref())
{
Some(api::message::Message::ToolCallResult(tool_call_result)) => {
tool_call_result.tool_call_id.clone()
}
other => panic!("expected tool call result message, got {other:?}"),
};

let user_inputs = match request
.input
.as_ref()
.and_then(|input| input.r#type.as_ref())
{
Some(api::request::input::Type::UserInputs(user_inputs)) => user_inputs,
other => panic!("expected user inputs request, got {other:?}"),
};

let cli_agent_tool_call_id = user_inputs
.inputs
.iter()
.find_map(|user_input| match user_input.input.as_ref() {
Some(api::request::input::user_inputs::user_input::Input::CliAgentUserQuery(
cli_agent_query,
)) => Some(cli_agent_query.run_shell_command_tool_call_id.clone()),
_ => None,
})
.expect("cli agent user query should exist");

let action_result_tool_call_id = user_inputs
.inputs
.iter()
.find_map(|user_input| match user_input.input.as_ref() {
Some(api::request::input::user_inputs::user_input::Input::ToolCallResult(
tool_call_result,
)) => Some(tool_call_result.tool_call_id.clone()),
_ => None,
})
.expect("tool call result input should exist");

(
task_tool_call_id,
task_tool_call_result_id,
cli_agent_tool_call_id,
action_result_tool_call_id,
)
}

#[test]
fn build_request_normalizes_over_limit_tool_call_ids_consistently() {
let over_limit_tool_call_id = "a".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN + 1);
let normalized_tool_call_id = normalize_external_tool_call_id(&over_limit_tool_call_id);
let request = build_request_with_tool_call_id(&over_limit_tool_call_id);
let (
task_tool_call_id,
task_tool_call_result_id,
cli_agent_tool_call_id,
action_result_tool_call_id,
) = request_tool_call_ids(&request);

assert!(normalized_tool_call_id.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN);
assert_ne!(normalized_tool_call_id, over_limit_tool_call_id);
assert_eq!(task_tool_call_id, normalized_tool_call_id);
assert_eq!(task_tool_call_result_id, normalized_tool_call_id);
assert_eq!(cli_agent_tool_call_id, normalized_tool_call_id);
assert_eq!(action_result_tool_call_id, normalized_tool_call_id);
}

#[test]
fn build_request_preserves_at_limit_tool_call_ids() {
let at_limit_tool_call_id = "a".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN);
let request = build_request_with_tool_call_id(&at_limit_tool_call_id);
let (
task_tool_call_id,
task_tool_call_result_id,
cli_agent_tool_call_id,
action_result_tool_call_id,
) = request_tool_call_ids(&request);

assert_eq!(task_tool_call_id, at_limit_tool_call_id);
assert_eq!(task_tool_call_result_id, at_limit_tool_call_id);
assert_eq!(cli_agent_tool_call_id, at_limit_tool_call_id);
assert_eq!(action_result_tool_call_id, at_limit_tool_call_id);
}

#[test]
fn normalized_tool_call_ids_are_stable_and_distinct_for_over_limit_inputs() {
let first = "a".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN + 1);
let second = "b".repeat(MAX_EXTERNAL_TOOL_CALL_ID_LEN + 1);

let first_normalized = normalize_external_tool_call_id(&first);
let second_normalized = normalize_external_tool_call_id(&second);

assert_eq!(first_normalized, normalize_external_tool_call_id(&first));
assert_ne!(first_normalized, second_normalized);
assert_ne!(first_normalized, first);
assert_ne!(second_normalized, second);
assert!(first_normalized.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN);
assert!(second_normalized.len() <= MAX_EXTERNAL_TOOL_CALL_ID_LEN);
}
}