diff --git a/volo-http/src/server/response/sse.rs b/volo-http/src/server/response/sse.rs index 3588def7..a0ffa65e 100644 --- a/volo-http/src/server/response/sse.rs +++ b/volo-http/src/server/response/sse.rs @@ -20,6 +20,22 @@ use tokio::time::{Instant, Sleep}; use super::IntoResponse; use crate::{body::Body, error::BoxError, response::Response}; +/// Extension trait for [`Response`] to check if it's an SSE response. +pub trait ResponseExt { + /// Check if the response is an SSE response by checking `Content-Type` header. + fn is_sse(&self) -> bool; +} + +impl ResponseExt for Response { + fn is_sse(&self) -> bool { + self.headers() + .get(header::CONTENT_TYPE) // Get the Content-Type header + .and_then(|v| v.to_str().ok()) // Convert header value to &str + .map(|v| v.starts_with(mime::TEXT_EVENT_STREAM.essence_str())) // Check SSE type + .unwrap_or(false) // Return false if header is missing or invalid + } +} + /// Response of [SSE][sse] (Server-Sent Events), inclusing a stream with SSE [`Event`]s. /// /// [sse]: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events diff --git a/volo-http/src/server/route/router.rs b/volo-http/src/server/route/router.rs index 938bc29f..7fe5fe5c 100644 --- a/volo-http/src/server/route/router.rs +++ b/volo-http/src/server/route/router.rs @@ -22,7 +22,7 @@ use crate::{ context::ServerContext, request::Request, response::Response, - server::{IntoResponse, handler::Handler}, + server::{IntoResponse, handler::Handler, response::sse::ResponseExt}, }; /// The router for routing path to [`Service`]s or handlers. @@ -450,7 +450,32 @@ where req: Request, ) -> Result { match self { - Self::MethodRouter(mr) => mr.call(cx, req).await, + Self::MethodRouter(mr) => { + // Determine if the client accepts SSE by checking the `Accept` header. + // If the header is missing, assume the client accepts SSE (default true). + let accepts_sse = req + .headers() + .get(http::header::ACCEPT) + .and_then(|v| v.to_str().ok()) + .map(|v| v.contains(mime::TEXT_EVENT_STREAM.essence_str())) + .unwrap_or(true); + + // Call the inner router and get the response. + let resp = mr.call(cx, req).await?; + + // If the client does not explicitly accept SSE but the response is SSE, + // return 415 Unsupported Media Type. + if !accepts_sse && resp.is_sse() { + return Ok(Response::builder() + .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body("Not Acceptable".into()) + .unwrap()); + } + + // Otherwise, return the response as-is. + Ok(resp) + } + Self::Service(service) => service.call(cx, req).await, } } @@ -489,14 +514,22 @@ where #[cfg(test)] mod router_tests { + use std::convert::Infallible; + + use async_stream::stream; use faststr::FastStr; - use http::{method::Method, status::StatusCode, uri::Uri}; + use futures::Stream; + use http::{header, method::Method, status::StatusCode, uri::Uri}; use super::Router; use crate::{ body::{Body, BodyConversion}, server::{ - Server, param::PathParamsVec, route::method_router::any, test_helpers::TestServer, + IntoResponse, Request, Server, + param::PathParamsVec, + response::sse::{Event, Sse}, + route::method_router::any, + test_helpers::TestServer, }, }; @@ -814,4 +847,68 @@ mod router_tests { .is_err() ); } + + async fn sse_handler() -> Sse>> { + let stream = stream! { + yield Ok(Event::new().event("ping").data("hello")); + }; + Sse::new(stream) + } + + async fn hello_handler() -> &'static str { + "Hello, World" + } + + async fn get_status( + server: &TestServer>, Option>, + uri: &str, + accept: Option<&'static str>, + ) -> StatusCode { + let mut builder = Request::builder().method(Method::GET).uri(uri); + if let Some(accept) = accept { + builder = builder.header(header::ACCEPT, accept); + } + server + .call_without_cx(builder.body(None).expect("Failed to build request")) + .await + .into_response() + .status() + } + + #[tokio::test] + async fn sse_accepted() { + let router: Router> = Router::new().route("/sse", any(sse_handler)); + let server = Server::new(router).into_test_server(); + assert_eq!( + get_status(&server, "/sse", Some(mime::TEXT_EVENT_STREAM.essence_str())).await, + StatusCode::OK + ); + } + + #[tokio::test] + async fn sse_not_accepted_returns_415() { + let router: Router> = Router::new().route("/sse", any(sse_handler)); + let server = Server::new(router).into_test_server(); + assert_eq!( + get_status(&server, "/sse", Some("application/json")).await, + StatusCode::UNSUPPORTED_MEDIA_TYPE + ); + } + + #[tokio::test] + async fn sse_no_accept_header_defaults_true() { + let router: Router> = Router::new().route("/sse", any(sse_handler)); + let server = Server::new(router).into_test_server(); + assert_eq!(get_status(&server, "/sse", None).await, StatusCode::OK); + } + + #[tokio::test] + async fn non_sse_response_not_blocked() { + let router: Router> = Router::new().route("/hello", any(hello_handler)); + let server = Server::new(router).into_test_server(); + assert_eq!( + get_status(&server, "/hello", Some("application/json")).await, + StatusCode::OK + ); + } }