Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,34 +328,42 @@ fn bad_request_response(message: &str) -> BoxResponse {
.expect("failed to build bad request response")
}

fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
let Some(host) = headers.get(http::header::HOST) else {
tracing::warn!("rejected request with missing Host header");
return Err(bad_request_response("Bad Request: missing Host header"));
};

let host_str = host
.to_str()
.inspect_err(|_| {
tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header");
})
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host_str)
.inspect_err(|_| {
tracing::warn!(
host = host_str,
"rejected request with malformed Host header"
);
})
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
fn parse_host_header(
uri: &http::Uri,
headers: &HeaderMap,
) -> Result<NormalizedAuthority, BoxResponse> {
if let Some(host) = headers.get(http::header::HOST) {
let host_str = host
.to_str()
.inspect_err(|_| {
tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header");
})
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host_str)
.inspect_err(|_| {
tracing::warn!(
host = host_str,
"rejected request with malformed Host header"
);
})
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
return Ok(normalize_authority(authority.host(), authority.port_u16()));
}
// HTTP/2 carries the host in `:authority`; middleware such as
// `axum::Router::nest` can drop the `Host` header hyper synthesizes from it.
let authority = uri.authority().ok_or_else(|| {
tracing::warn!("rejected request with missing Host header and no :authority");
bad_request_response("Bad Request: missing Host header")
})?;
Ok(normalize_authority(authority.host(), authority.port_u16()))
}

fn validate_dns_rebinding_headers(
uri: &http::Uri,
headers: &HeaderMap,
config: &StreamableHttpServerConfig,
) -> Result<(), BoxResponse> {
let host = parse_host_header(headers)?;
let host = parse_host_header(uri, headers)?;
if !host_is_allowed(&host, &config.allowed_hosts) {
tracing::warn!(
host = ?host,
Expand Down Expand Up @@ -806,7 +814,9 @@ where
B: Body + Send + 'static,
B::Error: Display,
{
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
if let Err(response) =
validate_dns_rebinding_headers(request.uri(), request.headers(), &self.config)
{
return response;
}
let method = request.method().clone();
Expand Down
89 changes: 89 additions & 0 deletions crates/rmcp/tests/test_custom_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,95 @@ async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}

/// Integration test: Verify the validator falls back to the URI authority when
/// the Host header is absent (HTTP/2 :authority pseudo-header scenario).
#[tokio::test]
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
async fn test_server_falls_back_to_uri_authority_when_host_header_missing() {
use std::sync::Arc;

use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;

#[derive(Clone)]
struct TestHandler;

impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}

let service = StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
);

let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});

// Allowed authority via URI only — no Host header.
let allowed_request = Request::builder()
.method(Method::POST)
.uri("http://localhost:8080/")
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
assert!(allowed_request.headers().get("Host").is_none());

let response = service.handle(allowed_request).await;
assert_eq!(response.status(), http::StatusCode::OK);

// Disallowed authority via URI only — no Host header.
let bad_request = Request::builder()
.method(Method::POST)
.uri("http://attacker.example/")
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
assert!(bad_request.headers().get("Host").is_none());

let response = service.handle(bad_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);

// Neither Host header nor URI authority — still a 400.
let missing_request = Request::builder()
.method(Method::POST)
.uri("/")
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
assert!(missing_request.headers().get("Host").is_none());
assert!(missing_request.uri().authority().is_none());

let response = service.handle(missing_request).await;
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
}

#[cfg(all(feature = "transport-streamable-http-server", feature = "server"))]
mod origin_validation {
use std::sync::Arc;
Expand Down
Loading