diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 46b9550c..5993c75b 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -328,34 +328,42 @@ fn bad_request_response(message: &str) -> BoxResponse { .expect("failed to build bad request response") } -fn parse_host_header(headers: &HeaderMap) -> Result { - let Some(host) = headers.get(http::header::HOST) else { - tracing::warn!("rejected request with missing Host header"); - return Err(bad_request_response("Bad Request: missing Host header")); - }; - - let host_str = host - .to_str() - .inspect_err(|_| { - tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header"); - }) - .map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?; - let authority = http::uri::Authority::try_from(host_str) - .inspect_err(|_| { - tracing::warn!( - host = host_str, - "rejected request with malformed Host header" - ); - }) - .map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?; +fn parse_host_header( + uri: &http::Uri, + headers: &HeaderMap, +) -> Result { + if let Some(host) = headers.get(http::header::HOST) { + let host_str = host + .to_str() + .inspect_err(|_| { + tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header"); + }) + .map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?; + let authority = http::uri::Authority::try_from(host_str) + .inspect_err(|_| { + tracing::warn!( + host = host_str, + "rejected request with malformed Host header" + ); + }) + .map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?; + return Ok(normalize_authority(authority.host(), authority.port_u16())); + } + // HTTP/2 carries the host in `:authority`; middleware such as + // `axum::Router::nest` can drop the `Host` header hyper synthesizes from it. + let authority = uri.authority().ok_or_else(|| { + tracing::warn!("rejected request with missing Host header and no :authority"); + bad_request_response("Bad Request: missing Host header") + })?; Ok(normalize_authority(authority.host(), authority.port_u16())) } fn validate_dns_rebinding_headers( + uri: &http::Uri, headers: &HeaderMap, config: &StreamableHttpServerConfig, ) -> Result<(), BoxResponse> { - let host = parse_host_header(headers)?; + let host = parse_host_header(uri, headers)?; if !host_is_allowed(&host, &config.allowed_hosts) { tracing::warn!( host = ?host, @@ -806,7 +814,9 @@ where B: Body + Send + 'static, B::Error: Display, { - if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) { + if let Err(response) = + validate_dns_rebinding_headers(request.uri(), request.headers(), &self.config) + { return response; } let method = request.method().clone(); diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs index d2d53682..9b9dfc05 100644 --- a/crates/rmcp/tests/test_custom_headers.rs +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -1031,6 +1031,95 @@ async fn test_server_validates_host_header_port_for_dns_rebinding_protection() { assert_eq!(response.status(), http::StatusCode::FORBIDDEN); } +/// Integration test: Verify the validator falls back to the URI authority when +/// the Host header is absent (HTTP/2 :authority pseudo-header scenario). +#[tokio::test] +#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))] +async fn test_server_falls_back_to_uri_authority_when_host_header_missing() { + use std::sync::Arc; + + use bytes::Bytes; + use http::{Method, Request, header::CONTENT_TYPE}; + use http_body_util::Full; + use rmcp::{ + handler::server::ServerHandler, + model::{ServerCapabilities, ServerInfo}, + transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, + }; + use serde_json::json; + + #[derive(Clone)] + struct TestHandler; + + impl ServerHandler for TestHandler { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().build()) + } + } + + let service = StreamableHttpService::new( + || Ok(TestHandler), + Arc::new(LocalSessionManager::default()), + StreamableHttpServerConfig::default(), + ); + + let init_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }); + + // Allowed authority via URI only — no Host header. + let allowed_request = Request::builder() + .method(Method::POST) + .uri("http://localhost:8080/") + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + assert!(allowed_request.headers().get("Host").is_none()); + + let response = service.handle(allowed_request).await; + assert_eq!(response.status(), http::StatusCode::OK); + + // Disallowed authority via URI only — no Host header. + let bad_request = Request::builder() + .method(Method::POST) + .uri("http://attacker.example/") + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + assert!(bad_request.headers().get("Host").is_none()); + + let response = service.handle(bad_request).await; + assert_eq!(response.status(), http::StatusCode::FORBIDDEN); + + // Neither Host header nor URI authority — still a 400. + let missing_request = Request::builder() + .method(Method::POST) + .uri("/") + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + assert!(missing_request.headers().get("Host").is_none()); + assert!(missing_request.uri().authority().is_none()); + + let response = service.handle(missing_request).await; + assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); +} + #[cfg(all(feature = "transport-streamable-http-server", feature = "server"))] mod origin_validation { use std::sync::Arc;