diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 747a15e0..913009ce 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -66,9 +66,16 @@ impl SessionManager for LocalSessionManager { Ok(response) } async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> { - let mut sessions = self.sessions.write().await; - if let Some(handle) = sessions.remove(id) { - handle.close().await?; + let handle = { + let mut sessions = self.sessions.write().await; + sessions.remove(id) + }; + if let Some(handle) = handle { + match handle.close().await { + // Worker already exited — nothing left to clean up. + Ok(()) | Err(SessionError::SessionServiceTerminated) => {} + Err(e) => return Err(e.into()), + } } Ok(()) } @@ -928,8 +935,6 @@ pub enum LocalSessionWorkerError { FailToSendInitializeRequest(SessionError), #[error("fail to handle message: {0}")] FailToHandleMessage(SessionError), - #[error("keep alive timeout after {}ms", _0.as_millis())] - KeepAliveTimeout(Duration), #[error("Transport closed")] TransportClosed, #[error("Tokio join error {0}")] @@ -1008,7 +1013,7 @@ impl Worker for LocalSessionWorker { return Err(WorkerQuitReason::Cancelled) } _ = keep_alive_timeout => { - return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event")) + return Err(WorkerQuitReason::IdleTimeout(keep_alive)) } }; match event { diff --git a/crates/rmcp/src/transport/worker.rs b/crates/rmcp/src/transport/worker.rs index a5d722d4..090381d0 100644 --- a/crates/rmcp/src/transport/worker.rs +++ b/crates/rmcp/src/transport/worker.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::{borrow::Cow, time::Duration}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, Level}; @@ -22,6 +22,8 @@ pub enum WorkerQuitReason { TransportClosed, #[error("Handler terminated")] HandlerTerminated, + #[error("Worker idle timeout ({}ms)", _0.as_millis())] + IdleTimeout(Duration), } impl WorkerQuitReason { @@ -122,7 +124,8 @@ impl WorkerTransport { .inspect_err(|e| match e { WorkerQuitReason::Cancelled | WorkerQuitReason::TransportClosed - | WorkerQuitReason::HandlerTerminated => { + | WorkerQuitReason::HandlerTerminated + | WorkerQuitReason::IdleTimeout(_) => { tracing::debug!("worker quit with reason: {:?}", e); } WorkerQuitReason::Join(e) => { diff --git a/crates/rmcp/tests/test_streamable_http_idle_timeout_log.rs b/crates/rmcp/tests/test_streamable_http_idle_timeout_log.rs new file mode 100644 index 00000000..6f20864f --- /dev/null +++ b/crates/rmcp/tests/test_streamable_http_idle_timeout_log.rs @@ -0,0 +1,246 @@ +#![cfg(all( + feature = "transport-streamable-http-server", + feature = "transport-streamable-http-client-reqwest", + not(feature = "local") +))] + +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; + +use rmcp::transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, + session::{SessionManager, local::LocalSessionManager}, +}; +use tokio_util::sync::CancellationToken; +use tracing_subscriber::layer::SubscriberExt; + +mod common; +use common::calculator::Calculator; + +struct CapturedEvent { + level: tracing::Level, + message: String, +} + +struct CapturingLayer { + events: Arc>>, +} + +impl tracing_subscriber::Layer for CapturingLayer { + fn on_event( + &self, + event: &tracing::Event<'_>, + _ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + let mut visitor = MessageVisitor(String::new()); + event.record(&mut visitor); + self.events.lock().unwrap().push(CapturedEvent { + level: *event.metadata().level(), + message: visitor.0, + }); + } +} + +struct MessageVisitor(String); + +impl tracing::field::Visit for MessageVisitor { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { + if field.name() == "message" { + self.0 = format!("{:?}", value); + } + } +} + +#[tokio::test(flavor = "current_thread")] +async fn test_keep_alive_timeout_does_not_emit_error_log() { + let events = Arc::new(Mutex::new(Vec::::new())); + + let subscriber = tracing_subscriber::registry().with(CapturingLayer { + events: events.clone(), + }); + + let _guard = tracing::subscriber::set_default(subscriber); + + let ct = CancellationToken::new(); + let mut session_manager = LocalSessionManager::default(); + session_manager.session_config.keep_alive = Some(Duration::from_millis(200)); + let session_manager = Arc::new(session_manager); + + let service = StreamableHttpService::new( + || Ok(Calculator::new()), + session_manager.clone(), + StreamableHttpServerConfig::default() + .with_sse_keep_alive(None) + .with_cancellation_token(ct.child_token()), + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let client = reqwest::Client::new(); + + let response = client + .post(format!("http://{addr}/mcp")) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#) + .send() + .await + .unwrap(); + assert_eq!(response.status(), 200); + let session_id = response.headers()["mcp-session-id"] + .to_str() + .unwrap() + .to_string(); + + client + .post(format!("http://{addr}/mcp")) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .header("mcp-session-id", &session_id) + .header("Mcp-Protocol-Version", "2025-06-18") + .body(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#) + .send() + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(400)).await; + + // Wait until close_session() has completed so all logs are captured. + let session_id_parsed: Arc = Arc::from(session_id.as_str()); + for _ in 0..20 { + if !session_manager + .has_session(&session_id_parsed) + .await + .unwrap() + { + break; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + assert!( + !session_manager + .has_session(&session_id_parsed) + .await + .unwrap(), + "session should have been removed after idle reap" + ); + + let captured = events.lock().unwrap(); + + let error_events: Vec<_> = captured + .iter() + .filter(|e| e.level == tracing::Level::ERROR) + .collect(); + assert!( + error_events.is_empty(), + "idle reap should not produce any ERROR logs, found {}: {:?}", + error_events.len(), + error_events.iter().map(|e| &e.message).collect::>() + ); + + let debug_events: Vec<_> = captured + .iter() + .filter(|e| e.level == tracing::Level::DEBUG && e.message.contains("IdleTimeout")) + .collect(); + assert!( + !debug_events.is_empty(), + "expected a DEBUG log with IdleTimeout, but found none" + ); + + ct.cancel(); +} + +#[tokio::test(flavor = "current_thread")] +async fn test_explicit_close_on_live_session_succeeds() { + let ct = CancellationToken::new(); + let mut session_manager = LocalSessionManager::default(); + session_manager.session_config.keep_alive = Some(Duration::from_secs(60)); + let session_manager = Arc::new(session_manager); + + let service = StreamableHttpService::new( + || Ok(Calculator::new()), + session_manager.clone(), + StreamableHttpServerConfig::default() + .with_sse_keep_alive(None) + .with_cancellation_token(ct.child_token()), + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let client = reqwest::Client::new(); + + let response = client + .post(format!("http://{addr}/mcp")) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#) + .send() + .await + .unwrap(); + assert_eq!(response.status(), 200); + let session_id = response.headers()["mcp-session-id"] + .to_str() + .unwrap() + .to_string(); + + client + .post(format!("http://{addr}/mcp")) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .header("mcp-session-id", &session_id) + .header("Mcp-Protocol-Version", "2025-06-18") + .body(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#) + .send() + .await + .unwrap(); + + let session_id_parsed: Arc = Arc::from(session_id.as_str()); + + assert!( + session_manager + .has_session(&session_id_parsed) + .await + .unwrap(), + "session should exist before explicit close" + ); + + let result = session_manager.close_session(&session_id_parsed).await; + assert!( + result.is_ok(), + "close_session on a live worker should succeed: {result:?}" + ); + + assert!( + !session_manager + .has_session(&session_id_parsed) + .await + .unwrap(), + "session should not exist after explicit close" + ); + + ct.cancel(); +}