diff --git a/app/src/ai/llms.rs b/app/src/ai/llms.rs
index d2c220963..875e653b8 100644
--- a/app/src/ai/llms.rs
+++ b/app/src/ai/llms.rs
@@ -36,6 +36,7 @@ pub fn is_using_api_key_for_provider(provider: &LLMProvider, app: &AppContext) -
         LLMProvider::OpenAI => api_keys.is_some_and(|keys| keys.openai.is_some()),
         LLMProvider::Anthropic => api_keys.is_some_and(|keys| keys.anthropic.is_some()),
         LLMProvider::Google => api_keys.is_some_and(|keys| keys.google.is_some()),
+        LLMProvider::Ollama => api_keys.is_some_and(|keys| keys.ollama_url.is_some()),
         _ => false,
     }
 }
@@ -89,6 +90,7 @@ pub enum LLMProvider {
     Anthropic,
     Google,
     Xai,
+    Ollama,
     Unknown,
 }
 
@@ -100,6 +102,7 @@ impl LLMProvider {
             LLMProvider::Anthropic => Some(Icon::ClaudeLogo),
             LLMProvider::Google => Some(Icon::GeminiLogo),
             LLMProvider::Xai => None,
+            LLMProvider::Ollama => None, // TODO: Add Ollama icon
             LLMProvider::Unknown => None,
         }
     }
diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml
index e461ca44d..996c36c7b 100644
--- a/crates/ai/Cargo.toml
+++ b/crates/ai/Cargo.toml
@@ -56,6 +56,7 @@ derivative.workspace = true
 warp_core.workspace = true
 warp_terminal.workspace = true
 warp_util.workspace = true
+reqwest.workspace = true
 chrono.workspace = true
 persistence.workspace = true
 priority-queue = "2.3.1"
diff --git a/crates/ai/src/api_keys.rs b/crates/ai/src/api_keys.rs
index 5cc96c225..6d598bd0f 100644
--- a/crates/ai/src/api_keys.rs
+++ b/crates/ai/src/api_keys.rs
@@ -22,6 +22,8 @@ pub struct ApiKeys {
     pub anthropic: Option<String>,
     pub openai: Option<String>,
     pub open_router: Option<String>,
+    /// Ollama URL (e.g., "http://localhost:11434"). No API key needed.
+    pub ollama_url: Option<String>,
 }
 
 impl ApiKeys {
@@ -30,6 +32,7 @@ impl ApiKeys {
             || self.anthropic.is_some()
             || self.google.is_some()
             || self.open_router.is_some()
+            || self.ollama_url.is_some()
     }
 }
 
@@ -93,6 +96,12 @@ impl ApiKeyManager {
         self.write_keys_to_secure_storage(ctx);
     }
 
+    pub fn set_ollama_url(&mut self, url: Option<String>, ctx: &mut ModelContext<Self>) {
+        self.keys.ollama_url = url;
+        ctx.emit(ApiKeyManagerEvent::KeysUpdated);
+        self.write_keys_to_secure_storage(ctx);
+    }
+
     pub fn set_aws_credentials_state(
         &mut self,
         state: AwsCredentialsState,
@@ -138,6 +147,8 @@ impl ApiKeyManager {
             .then(|| self.keys.open_router.clone())
             .flatten()
             .unwrap_or_default();
+        // NOTE: ollama_url is NOT included here because it's a local endpoint
+        // and should never be sent through remote request settings.
         // Also include credentials when running with OIDC-managed Bedrock inference, regardless
         // of the per-user setting flag (which only applies to the local credential chain path).
         let include_aws = include_aws_bedrock_credentials
diff --git a/crates/ai/src/lib.rs b/crates/ai/src/lib.rs
index c4ad90bd7..92db90505 100644
--- a/crates/ai/src/lib.rs
+++ b/crates/ai/src/lib.rs
@@ -2,6 +2,7 @@ pub mod agent;
 pub mod api_keys;
 pub mod aws_credentials;
 pub mod llm_id;
+pub mod ollama_client;
 pub use llm_id::LLMId;
 pub mod diff_validation;
 
diff --git a/crates/ai/src/ollama_client.rs b/crates/ai/src/ollama_client.rs
new file mode 100644
index 000000000..c5144f510
--- /dev/null
+++ b/crates/ai/src/ollama_client.rs
@@ -0,0 +1,440 @@
+//! Ollama API client for local model inference.
+//!
+//! This module provides a client for interacting with Ollama's API,
+//! allowing Warp to use locally-hosted models.
+//!
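+//! A minimal usage sketch of the non-streaming path (assumes this crate is
+//! named `ai` and that the placeholder model `llama3` is already pulled on
+//! the local server):
+//!
+//! ```no_run
+//! # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
+//! use ai::ollama_client::{ChatMessage, MessageExt, OllamaClient};
+//!
+//! let client = OllamaClient::new(); // talks to http://localhost:11434
+//! let reply = client
+//!     .chat("llama3", vec![ChatMessage::user("Hello!")])
+//!     .await?;
+//! println!("{}", reply.message.content);
+//! # Ok(())
+//! # }
+//! ```
+//!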
+//! API Reference: https://github.com/ollama/ollama/blob/main/docs/api.md
+
+use futures::StreamExt;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+
+pub use crate::llm_id::LLMId;
+
+/// Default Ollama base URL.
+pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
+
+/// Default request timeout for Ollama API calls.
+const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120);
+
+/// A message in an Ollama chat conversation.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatMessage {
+    pub role: String,
+    pub content: String,
+}
+
+/// Request payload for Ollama chat API.
+#[derive(Debug, Clone, Serialize)]
+pub struct ChatRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
+}
+
+/// Response from Ollama chat API (non-streaming).
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatResponse {
+    pub model: String,
+    pub message: ChatMessage,
+    #[serde(default)]
+    pub done_reason: Option<String>,
+    pub done: bool,
+}
+
+/// A streaming chunk from Ollama's chat API.
+///
+/// Deserialized via [`StreamChunkHelper`], which routes a chunk to
+/// `Partial` or `Complete` based on the `done` flag.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(from = "StreamChunkHelper")]
+pub enum StreamChunk {
+    Partial {
+        model: String,
+        message: ChatMessage,
+        done: bool,
+    },
+    Complete {
+        model: String,
+        message: ChatMessage,
+        done: bool,
+        done_reason: Option<String>,
+        total_duration: Option<u64>,
+        eval_count: Option<u64>,
+        eval_duration: Option<u64>,
+        load_duration: Option<u64>,
+        prompt_eval_count: Option<u64>,
+        prompt_eval_duration: Option<u64>,
+    },
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct StreamChunkHelper {
+    done: bool,
+    model: String,
+    message: ChatMessage,
+    #[serde(default)]
+    done_reason: Option<String>,
+    #[serde(default)]
+    total_duration: Option<u64>,
+    #[serde(default)]
+    eval_count: Option<u64>,
+    #[serde(default)]
+    eval_duration: Option<u64>,
+    #[serde(default)]
+    load_duration: Option<u64>,
+    #[serde(default)]
+    prompt_eval_count: Option<u64>,
+    #[serde(default)]
+    prompt_eval_duration: Option<u64>,
+}
+
+impl From<StreamChunkHelper> for StreamChunk {
+    fn from(h: StreamChunkHelper) -> Self {
+        if h.done {
+            StreamChunk::Complete {
+                model: h.model,
+                message: h.message,
+                done: h.done,
+                done_reason: h.done_reason,
+                total_duration: h.total_duration,
+                eval_count: h.eval_count,
+                eval_duration: h.eval_duration,
+                load_duration: h.load_duration,
+                prompt_eval_count: h.prompt_eval_count,
+                prompt_eval_duration: h.prompt_eval_duration,
+            }
+        } else {
+            StreamChunk::Partial {
+                model: h.model,
+                message: h.message,
+                done: h.done,
+            }
+        }
+    }
+}
+
+/// Model info returned by Ollama's /api/tags endpoint.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ModelInfo {
+    pub name: String,
+    #[serde(default)]
+    pub model: Option<String>,
+    #[serde(default)]
+    pub modified_at: Option<String>,
+    #[serde(default)]
+    pub size: Option<u64>,
+    #[serde(default)]
+    pub digest: Option<String>,
+}
+
+/// Response from Ollama's /api/tags endpoint.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ListModelsResponse {
+    pub models: Vec<ModelInfo>,
+}
+
+/// Error types for Ollama client operations.
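+///
+/// A sketch of how a caller might branch on these errors (`client` is any
+/// `OllamaClient`; crate name `ai` is assumed):
+///
+/// ```no_run
+/// # async fn demo(client: &ai::ollama_client::OllamaClient) {
+/// use ai::ollama_client::OllamaError;
+///
+/// match client.list_models().await {
+///     Ok(models) => println!("{} models installed", models.len()),
+///     Err(OllamaError::ServerNotRunning(url)) => eprintln!("start Ollama at {url}"),
+///     Err(e) => eprintln!("Ollama request failed: {e}"),
+/// }
+/// # }
+/// ```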
+#[derive(Debug, thiserror::Error)]
+pub enum OllamaError {
+    #[error("Connection failed: {0}")]
+    ConnectionError(#[from] reqwest::Error),
+
+    #[error("Ollama server returned error: {0}")]
+    ServerError(String),
+
+    #[error("Ollama server not running at {0}")]
+    ServerNotRunning(String),
+
+    #[error("Parse error: {0}")]
+    ParseError(#[from] serde_json::Error),
+}
+
+/// Result type for Ollama operations.
+pub type OllamaResult<T> = std::result::Result<T, OllamaError>;
+
+/// Client for interacting with Ollama API.
+#[derive(Debug, Clone)]
+pub struct OllamaClient {
+    base_url: String,
+    http_client: Client,
+}
+
+impl OllamaClient {
+    /// Create a new Ollama client with the default base URL.
+    pub fn new() -> Self {
+        Self::with_base_url(DEFAULT_OLLAMA_URL)
+    }
+
+    /// Create a new Ollama client with a custom base URL.
+    pub fn with_base_url(base_url: impl Into<String>) -> Self {
+        let http_client = Client::builder()
+            .timeout(DEFAULT_TIMEOUT)
+            .build()
+            .expect("Failed to create HTTP client");
+
+        Self {
+            base_url: base_url.into().trim_end_matches('/').to_string(),
+            http_client,
+        }
+    }
+
+    /// Check if the Ollama server is running and accessible.
+    ///
+    /// Returns `Ok(false)` rather than an error when the connection is
+    /// refused, so callers can treat "not running" as a normal state.
+    pub async fn health_check(&self) -> OllamaResult<bool> {
+        match self
+            .http_client
+            .get(format!("{}/api/tags", self.base_url))
+            .send()
+            .await
+        {
+            Ok(response) => Ok(response.status().is_success()),
+            Err(e) => {
+                if e.is_connect() {
+                    Ok(false)
+                } else {
+                    Err(OllamaError::ConnectionError(e))
+                }
+            }
+        }
+    }
+
+    /// List all available models on the Ollama server.
+    pub async fn list_models(&self) -> OllamaResult<Vec<ModelInfo>> {
+        let response = self
+            .http_client
+            .get(format!("{}/api/tags", self.base_url))
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            return Err(OllamaError::ServerError(format!(
+                "Server returned status: {}",
+                status
+            )));
+        }
+
+        let body = response.text().await?;
+        let result: ListModelsResponse = serde_json::from_str(&body)?;
+        Ok(result.models)
+    }
+
+    /// Get a list of model names available on the server.
+    pub async fn available_model_names(&self) -> OllamaResult<Vec<String>> {
+        let models = self.list_models().await?;
+        Ok(models.into_iter().map(|m| m.name).collect())
+    }
+
+    /// Send a chat request and get a non-streaming response.
+    pub async fn chat(
+        &self,
+        model: &str,
+        messages: Vec<ChatMessage>,
+    ) -> OllamaResult<ChatResponse> {
+        let request = ChatRequest {
+            model: model.to_string(),
+            messages,
+            stream: Some(false),
+        };
+
+        let response = self
+            .http_client
+            .post(format!("{}/api/chat", self.base_url))
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let body = response.text().await.unwrap_or_default();
+            return Err(OllamaError::ServerError(format!(
+                "Server returned status {}: {}",
+                status, body
+            )));
+        }
+
+        let body = response.text().await?;
+        let result: ChatResponse = serde_json::from_str(&body)?;
+        Ok(result)
+    }
+
+    /// Send a chat request and stream responses.
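+    ///
+    /// A sketch of consuming the stream (placeholder model name `llama3`;
+    /// crate name `ai` is assumed):
+    ///
+    /// ```no_run
+    /// # async fn demo(client: &ai::ollama_client::OllamaClient) -> ai::ollama_client::OllamaResult<()> {
+    /// use ai::ollama_client::{ChatMessage, MessageExt, StreamChunk};
+    /// use futures::StreamExt;
+    ///
+    /// let stream = client
+    ///     .chat_streaming("llama3", vec![ChatMessage::user("Hi")])
+    ///     .await?;
+    /// futures::pin_mut!(stream);
+    /// while let Some(chunk) = stream.next().await {
+    ///     match chunk? {
+    ///         StreamChunk::Partial { message, .. } => print!("{}", message.content),
+    ///         StreamChunk::Complete { .. } => break,
+    ///     }
+    /// }
+    /// # Ok(())
+    /// # }
+    /// ```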
+    pub async fn chat_streaming(
+        &self,
+        model: &str,
+        messages: Vec<ChatMessage>,
+    ) -> OllamaResult<impl futures::Stream<Item = OllamaResult<StreamChunk>>> {
+        let request = ChatRequest {
+            model: model.to_string(),
+            messages,
+            stream: Some(true),
+        };
+
+        let response = self
+            .http_client
+            .post(format!("{}/api/chat", self.base_url))
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let body = response.text().await.unwrap_or_default();
+            return Err(OllamaError::ServerError(format!(
+                "Server returned status {}: {}",
+                status, body
+            )));
+        }
+
+        // Ollama streams newline-delimited JSON: each line is one JSON object.
+        // NOTE: this maps each network chunk to at most one parsed StreamChunk;
+        // any extra complete lines in the same chunk are logged and dropped, and
+        // a JSON object split across two network reads will fail to parse. A
+        // hardened client would buffer partial lines across chunks.
+        let stream = response.bytes_stream().map(|chunk_result| {
+            chunk_result
+                .map_err(OllamaError::ConnectionError)
+                .map(|bytes| {
+                    let text = String::from_utf8_lossy(&bytes);
+                    // `split('\n')` also yields any trailing data without a
+                    // newline, so a final unterminated line is still attempted.
+                    for line in text.split('\n') {
+                        let line = line.trim();
+                        if line.is_empty() {
+                            continue;
+                        }
+                        match serde_json::from_str::<StreamChunk>(line) {
+                            Ok(chunk) => return chunk,
+                            Err(e) => {
+                                log::debug!("Failed to parse Ollama stream chunk: {}", e);
+                            }
+                        }
+                    }
+                    // No parseable chunk in this read: emit an empty partial so
+                    // the stream keeps flowing.
+                    StreamChunk::Partial {
+                        model: String::new(),
+                        message: ChatMessage {
+                            role: String::new(),
+                            content: String::new(),
+                        },
+                        done: false,
+                    }
+                })
+        });
+
+        Ok(stream)
+    }
+
+    /// Create a chat message from a role and content.
+    pub fn message(role: impl Into<String>, content: impl Into<String>) -> ChatMessage {
+        ChatMessage {
+            role: role.into(),
+            content: content.into(),
+        }
+    }
+}
+
+impl Default for OllamaClient {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Extension trait for easily creating messages.
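+///
+/// A small sketch of building a conversation with these helpers (crate name
+/// `ai` is assumed):
+///
+/// ```no_run
+/// use ai::ollama_client::{ChatMessage, MessageExt};
+///
+/// let messages = vec![
+///     ChatMessage::system("You are a terminal assistant."),
+///     ChatMessage::user("List the files in the current directory."),
+/// ];
+/// assert_eq!(messages[0].role, "system");
+/// ```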
+pub trait MessageExt {
+    fn user(content: impl Into<String>) -> ChatMessage;
+    fn assistant(content: impl Into<String>) -> ChatMessage;
+    fn system(content: impl Into<String>) -> ChatMessage;
+}
+
+impl MessageExt for ChatMessage {
+    fn user(content: impl Into<String>) -> ChatMessage {
+        // `message` is an associated fn of OllamaClient, so call it explicitly
+        // (`Self::message` would look for it on ChatMessage and fail to compile).
+        OllamaClient::message("user", content)
+    }
+
+    fn assistant(content: impl Into<String>) -> ChatMessage {
+        OllamaClient::message("assistant", content)
+    }
+
+    fn system(content: impl Into<String>) -> ChatMessage {
+        OllamaClient::message("system", content)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_chat_message_creation() {
+        let msg = ChatMessage::user("Hello, world!");
+        assert_eq!(msg.role, "user");
+        assert_eq!(msg.content, "Hello, world!");
+    }
+
+    #[test]
+    fn test_chat_message_from_trait() {
+        let msg = ChatMessage::system("You are a helpful assistant.");
+        assert_eq!(msg.role, "system");
+        assert_eq!(msg.content, "You are a helpful assistant.");
+    }
+
+    #[test]
+    fn test_serialize_chat_request() {
+        let request = ChatRequest {
+            model: "llama3".to_string(),
+            messages: vec![ChatMessage::user("Hi")],
+            stream: Some(true),
+        };
+        let json = serde_json::to_string(&request).unwrap();
+        assert!(json.contains("\"model\":\"llama3\""));
+        assert!(json.contains("\"stream\":true"));
+    }
+
+    #[test]
+    fn test_client_creation() {
+        // No async work happens here, so a plain #[test] suffices.
+        let client = OllamaClient::new();
+        assert_eq!(client.base_url, DEFAULT_OLLAMA_URL);
+    }
+
+    #[test]
+    fn test_custom_base_url() {
+        let client = OllamaClient::with_base_url("http://192.168.1.100:11434");
+        assert_eq!(client.base_url, "http://192.168.1.100:11434");
+    }
+
+    #[test]
+    fn test_parse_stream_chunk() {
+        let json = r#"{"model":"llama3","message":{"role":"assistant","content":"Hello"},"done":false}"#;
+        let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+        assert!(matches!(chunk, StreamChunk::Partial { .. }));
+    }
+
+    #[test]
+    fn test_parse_complete_chunk() {
+        let json = r#"{"model":"llama3","message":{"role":"assistant","content":"Hello"},"done":true,"done_reason":"stop","eval_count":5}"#;
+        let chunk: StreamChunk = serde_json::from_str(json).unwrap();
+        assert!(matches!(chunk, StreamChunk::Complete { .. }));
+    }
+
+    #[test]
+    fn test_parse_model_info() {
+        let json = r#"{"name":"llama3:latest","model":"llama3","modified_at":"2024-01-01T00:00:00Z","size":3826793472,"digest":"sha256:..."}"#;
+        let model: ModelInfo = serde_json::from_str(json).unwrap();
+        assert_eq!(model.name, "llama3:latest");
+    }
+}
\ No newline at end of file