Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion crates/goat-agent/src/compaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ struct Measured {
#[derive(Default)]
pub(crate) struct ContextTracker {
measured: Option<Measured>,
tool_def_cache: std::cell::Cell<Option<(usize, usize, usize, u32)>>,
}

impl ContextTracker {
Expand All @@ -29,12 +30,30 @@ impl ContextTracker {
self.measured = None;
}

fn tool_def_tokens(&self, tool_defs: &[ToolDefinition]) -> u32 {
let fingerprint: usize = tool_defs
.iter()
.map(|def| def.name.len().wrapping_add(def.description.len()))
.fold(0usize, usize::wrapping_add);
let key = (tool_defs.as_ptr() as usize, tool_defs.len(), fingerprint);
if let Some((ptr, len, fp, tokens)) = self.tool_def_cache.get()
&& (ptr, len, fp) == key
{
return tokens;
}
let tokens = estimate_tool_defs(tool_defs);
self.tool_def_cache.set(Some((key.0, key.1, key.2, tokens)));
tokens
}

pub(crate) fn estimate(&self, messages: &[Message], tool_defs: &[ToolDefinition]) -> u32 {
match &self.measured {
Some(measured) if measured.history_len <= messages.len() => measured
.tokens
.saturating_add(estimate_messages(&messages[measured.history_len..])),
_ => estimate_tool_defs(tool_defs).saturating_add(estimate_messages(messages)),
_ => self
.tool_def_tokens(tool_defs)
.saturating_add(estimate_messages(messages)),
}
}
}
Expand Down Expand Up @@ -516,6 +535,22 @@ mod tests {
assert!((100..=120).contains(&estimate), "got {estimate}");
}

#[test]
fn tool_def_estimate_is_cached_and_consistent() {
use goat_provider::ToolDefinition;
let tracker = ContextTracker::new();
let tool_defs = vec![ToolDefinition {
name: "Read".to_owned(),
description: "read a file".to_owned(),
input_schema: serde_json::json!({"type":"object","properties":{"path":{"type":"string"}}}),
}];
let messages = vec![Message::text(MessageRole::User, "hi")];
let first = tracker.estimate(&messages, &tool_defs);
let second = tracker.estimate(&messages, &tool_defs);
assert_eq!(first, second);
assert!(first > tracker.estimate(&messages, &[]));
}

#[test]
fn measured_estimate_adds_only_the_delta() {
let mut tracker = ContextTracker::new();
Expand Down
11 changes: 9 additions & 2 deletions crates/goat-agent/src/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,20 @@ pub(crate) async fn run_round_with_retry(
ctx: &Ctx<'_>,
run: &Run<'_>,
env: &LoopEnv<'_>,
request: &Request,
messages: &[goat_provider::Message],
token: &CancellationToken,
) -> RoundResult {
let started = Instant::now();
let mut attempt = 1u32;
loop {
let result = run_round(ctx, run, env.provider, request.clone(), token).await;
let request = Request {
model: env.target.model.clone(),
messages: messages.to_vec(),
tools: env.tool_defs.to_vec(),
effort: env.target.effort,
tool_choice: goat_provider::ToolChoice::Auto,
};
let result = run_round(ctx, run, env.provider, request, token).await;
let RoundEnd::Failed(error) = &result.end else {
return result;
};
Expand Down
12 changes: 3 additions & 9 deletions crates/goat-agent/src/rounds.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use goat_protocol::Event;
use goat_provider::{
ContentBlock, Message, MessageRole, Provider, Request, StreamError, StreamEvent, ToolChoice,
ContentBlock, Message, MessageRole, Provider, Request, StreamError, StreamEvent,
};
use goat_tool::ToolContext;
use tokio::sync::mpsc;
Expand Down Expand Up @@ -430,14 +430,8 @@ pub(crate) async fn core_loop(
}
}
}
let request = Request {
model: env.target.model.clone(),
messages: conversation.messages().to_vec(),
tools: env.tool_defs.to_vec(),
effort: env.target.effort,
tool_choice: ToolChoice::Auto,
};
let round = crate::retry::run_round_with_retry(ctx, run, env, &request, token).await;
let round =
crate::retry::run_round_with_retry(ctx, run, env, conversation.messages(), token).await;
match &round.end {
RoundEnd::Cancelled => return LoopOutcome::Cancelled,
RoundEnd::Failed(StreamError::ContextOverflow { .. }) if !compacted_for_overflow => {
Expand Down
174 changes: 166 additions & 8 deletions crates/goat-auth/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{
collections::HashMap,
fmt, fs,
path::PathBuf,
path::{Path, PathBuf},
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};

Expand Down Expand Up @@ -93,6 +95,25 @@ impl TokenSet {
}
}

fn refresh_locks() -> &'static std::sync::Mutex<HashMap<CredentialKey, Arc<tokio::sync::Mutex<()>>>>
{
static LOCKS: std::sync::OnceLock<
std::sync::Mutex<HashMap<CredentialKey, Arc<tokio::sync::Mutex<()>>>>,
> = std::sync::OnceLock::new();
LOCKS.get_or_init(|| std::sync::Mutex::new(HashMap::new()))
}

fn refresh_lock_for(key: &CredentialKey) -> Arc<tokio::sync::Mutex<()>> {
let mut map = refresh_locks()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
map.entry(key.clone())
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone()
}

const REFRESH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

pub async fn ensure_valid<F, Fut>(
tokens: TokenSet,
store: &CredentialStore,
Expand All @@ -106,18 +127,30 @@ where
if !tokens.is_expired() {
return Some(tokens);
}
let lock = refresh_lock_for(key);
let _guard = lock.lock().await;
if let Some(Credential::OAuth(current)) = store.file_get(key) {
let changed = current.access_token.expose() != tokens.access_token.expose();
if changed && !current.is_expired() {
return Some(current);
}
}
let refresh_token = tokens.refresh_token.as_ref()?.expose().to_owned();
match refresh(refresh_token).await {
Ok(fresh) => {
match tokio::time::timeout(REFRESH_TIMEOUT, refresh(refresh_token)).await {
Ok(Ok(fresh)) => {
if let Err(err) = store.store(key, Credential::OAuth(fresh.clone())) {
tracing::warn!(%err, "failed to persist refreshed oauth tokens");
}
Some(fresh)
}
Err(err) => {
Ok(Err(err)) => {
tracing::warn!(%err, "token refresh failed; treating as logged out");
None
}
Err(_) => {
tracing::warn!("token refresh timed out; treating as logged out");
None
}
}
}

Expand Down Expand Up @@ -290,6 +323,24 @@ pub struct CredentialStore {
path: PathBuf,
}

struct TempCleanup {
path: Option<PathBuf>,
}

impl TempCleanup {
fn disarm(mut self) {
self.path = None;
}
}

impl Drop for TempCleanup {
fn drop(&mut self) {
if let Some(path) = self.path.take() {
let _ = fs::remove_file(path);
}
}
}

impl CredentialStore {
pub fn new(path: PathBuf) -> Self {
Self { path }
Expand Down Expand Up @@ -359,11 +410,40 @@ impl CredentialStore {
if let Some(parent) = self.path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(&self.path, serde_json::to_string_pretty(file)?)?;
#[cfg(unix)]
let contents = serde_json::to_string_pretty(file)?;
let parent = self.path.parent().unwrap_or_else(|| Path::new("."));
let file_name = self.path.file_name().map_or_else(
|| "auth.json".to_owned(),
|name| name.to_string_lossy().into_owned(),
);
let tmp_path = parent.join(format!("{file_name}.tmp-{}", std::process::id()));
let cleanup = TempCleanup {
path: Some(tmp_path.clone()),
};
{
use std::os::unix::fs::PermissionsExt;
fs::set_permissions(&self.path, fs::Permissions::from_mode(0o600))?;
let mut options = fs::OpenOptions::new();
options.write(true).create_new(true);
#[cfg(unix)]
{
use std::os::unix::fs::OpenOptionsExt;
options.mode(0o600);
}
let mut handle = match options.open(&tmp_path) {
Ok(handle) => handle,
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
let _ = fs::remove_file(&tmp_path);
options.open(&tmp_path)?
}
Err(err) => return Err(err.into()),
};
std::io::Write::write_all(&mut handle, contents.as_bytes())?;
handle.sync_all()?;
}
fs::rename(&tmp_path, &self.path)?;
cleanup.disarm();
#[cfg(unix)]
if let Ok(dir) = fs::File::open(parent) {
let _ = dir.sync_all();
}
Ok(())
}
Expand Down Expand Up @@ -395,8 +475,56 @@ impl CredentialStore {
mod tests {
use super::{
Credential, CredentialKey, CredentialKind, CredentialStore, Pkce, SecretString, TokenSet,
ensure_valid, now_secs,
};

#[tokio::test]
async fn ensure_valid_single_flights_concurrent_refresh() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
let path = std::env::temp_dir().join("goat-auth-singleflight-test.json");
let _ = std::fs::remove_file(&path);
let store = CredentialStore::new(path.clone());
let key = CredentialKey {
provider: "goat-singleflight".into(),
account: "a".into(),
};
let expired = TokenSet {
access_token: SecretString::from("old"),
refresh_token: Some(SecretString::from("refresh")),
expires_at: Some(now_secs() - 100),
};
let calls = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::new();
for _ in 0..8 {
let store = store.clone();
let key = key.clone();
let tokens = expired.clone();
let calls = calls.clone();
handles.push(tokio::spawn(async move {
ensure_valid(tokens, &store, &key, |_| {
let calls = calls.clone();
async move {
calls.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
Ok(TokenSet {
access_token: SecretString::from("new"),
refresh_token: Some(SecretString::from("refresh2")),
expires_at: Some(now_secs() + 3600),
})
}
})
.await
}));
}
for handle in handles {
let result = handle.await.unwrap();
assert!(matches!(result, Some(t) if t.access_token.expose() == "new"));
}
assert_eq!(calls.load(Ordering::SeqCst), 1);
let _ = std::fs::remove_file(&path);
}

#[test]
fn pkce_generates_s256_challenge() {
use base64::Engine;
Expand Down Expand Up @@ -428,6 +556,36 @@ mod tests {
assert_eq!(cred.kind(), CredentialKind::ApiKey);
}

#[cfg(unix)]
#[test]
fn saved_file_is_owner_only_and_atomic() {
use std::os::unix::fs::PermissionsExt;
let path = std::env::temp_dir().join("goat-auth-perms-test.json");
let _ = std::fs::remove_file(&path);
let store = CredentialStore::new(path.clone());
let key = CredentialKey {
provider: "p".into(),
account: "a".into(),
};
store
.file_set(&key, Credential::ApiKey(SecretString::from("secret")))
.unwrap();
let mode = std::fs::metadata(&path).unwrap().permissions().mode();
assert_eq!(mode & 0o777, 0o600);
let got = store.file_get(&key).unwrap();
assert!(matches!(got, Credential::ApiKey(secret) if secret.expose() == "secret"));
let leftover = std::fs::read_dir(path.parent().unwrap())
.unwrap()
.filter_map(Result::ok)
.any(|e| {
e.file_name()
.to_string_lossy()
.contains("goat-auth-perms-test.json.tmp-")
});
assert!(!leftover, "temp file should be cleaned up");
let _ = std::fs::remove_file(&path);
}

#[test]
fn file_store_roundtrip() {
let path = std::env::temp_dir().join("goat-auth-file-roundtrip-test.json");
Expand Down
11 changes: 1 addition & 10 deletions crates/goat-daemon/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,7 @@ async fn dispatch(
ClientFrame::Hello { .. } => true,
ClientFrame::OpenSession { cwd, resume } => {
let cwd_path = PathBuf::from(&cwd);
let existing = if matches!(resume, goat_wire::ResumeMode::Latest) {
manager.find_live_by_cwd(&cwd_path).await
} else {
None
};
let opened = match existing {
Some(session) => Ok(session),
None => manager.open_session(cwd_path, resume).await,
};
match opened {
match manager.open_or_attach(cwd_path, resume).await {
Ok(session) => {
let _ = out_tx
.send(ServerFrame::SessionOpened {
Expand Down
Loading
Loading