Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions volo-http/src/server/response/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ use tokio::time::{Instant, Sleep};
use super::IntoResponse;
use crate::{body::Body, error::BoxError, response::Response};

/// Extension trait for [`Response`] to check if it's an SSE response.
pub trait ResponseExt {
/// Check if the response is an SSE response by checking `Content-Type` header.
fn is_sse(&self) -> bool;
}

impl ResponseExt for Response {
fn is_sse(&self) -> bool {
self.headers()
.get(header::CONTENT_TYPE) // Get the Content-Type header
.and_then(|v| v.to_str().ok()) // Convert header value to &str
.map(|v| v.starts_with(mime::TEXT_EVENT_STREAM.essence_str())) // Check SSE type
.unwrap_or(false) // Return false if header is missing or invalid
}
}

/// Response of [SSE][sse] (Server-Sent Events), inclusing a stream with SSE [`Event`]s.
///
/// [sse]: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events
Expand Down
105 changes: 101 additions & 4 deletions volo-http/src/server/route/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
context::ServerContext,
request::Request,
response::Response,
server::{IntoResponse, handler::Handler},
server::{IntoResponse, handler::Handler, response::sse::ResponseExt},
};

/// The router for routing path to [`Service`]s or handlers.
Expand Down Expand Up @@ -450,7 +450,32 @@ where
req: Request<B>,
) -> Result<Self::Response, Self::Error> {
match self {
Self::MethodRouter(mr) => mr.call(cx, req).await,
Self::MethodRouter(mr) => {
// Determine if the client accepts SSE by checking the `Accept` header.
// If the header is missing, assume the client accepts SSE (default true).
let accepts_sse = req
.headers()
.get(http::header::ACCEPT)
.and_then(|v| v.to_str().ok())
.map(|v| v.contains(mime::TEXT_EVENT_STREAM.essence_str()))
.unwrap_or(true);

// Call the inner router and get the response.
let resp = mr.call(cx, req).await?;

// If the client does not explicitly accept SSE but the response is SSE,
// return 415 Unsupported Media Type.
if !accepts_sse && resp.is_sse() {
return Ok(Response::builder()
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
.body("Not Acceptable".into())
.unwrap());
}

// Otherwise, return the response as-is.
Ok(resp)
}

Self::Service(service) => service.call(cx, req).await,
}
}
Expand Down Expand Up @@ -489,14 +514,22 @@ where

#[cfg(test)]
mod router_tests {
use std::convert::Infallible;

use async_stream::stream;
use faststr::FastStr;
use http::{method::Method, status::StatusCode, uri::Uri};
use futures::Stream;
use http::{header, method::Method, status::StatusCode, uri::Uri};

use super::Router;
use crate::{
body::{Body, BodyConversion},
server::{
Server, param::PathParamsVec, route::method_router::any, test_helpers::TestServer,
IntoResponse, Request, Server,
param::PathParamsVec,
response::sse::{Event, Sse},
route::method_router::any,
test_helpers::TestServer,
},
};

Expand Down Expand Up @@ -814,4 +847,68 @@ mod router_tests {
.is_err()
);
}

async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let stream = stream! {
yield Ok(Event::new().event("ping").data("hello"));
};
Sse::new(stream)
}

async fn hello_handler() -> &'static str {
"Hello, World"
}

async fn get_status(
server: &TestServer<Router<Option<Body>>, Option<Body>>,
uri: &str,
accept: Option<&'static str>,
) -> StatusCode {
let mut builder = Request::builder().method(Method::GET).uri(uri);
if let Some(accept) = accept {
builder = builder.header(header::ACCEPT, accept);
}
server
.call_without_cx(builder.body(None).expect("Failed to build request"))
.await
.into_response()
.status()
}

#[tokio::test]
async fn sse_accepted() {
let router: Router<Option<Body>> = Router::new().route("/sse", any(sse_handler));
let server = Server::new(router).into_test_server();
assert_eq!(
get_status(&server, "/sse", Some(mime::TEXT_EVENT_STREAM.essence_str())).await,
StatusCode::OK
);
}

#[tokio::test]
async fn sse_not_accepted_returns_415() {
let router: Router<Option<Body>> = Router::new().route("/sse", any(sse_handler));
let server = Server::new(router).into_test_server();
assert_eq!(
get_status(&server, "/sse", Some("application/json")).await,
StatusCode::UNSUPPORTED_MEDIA_TYPE
);
}

#[tokio::test]
async fn sse_no_accept_header_defaults_true() {
let router: Router<Option<Body>> = Router::new().route("/sse", any(sse_handler));
let server = Server::new(router).into_test_server();
assert_eq!(get_status(&server, "/sse", None).await, StatusCode::OK);
}

#[tokio::test]
async fn non_sse_response_not_blocked() {
let router: Router<Option<Body>> = Router::new().route("/hello", any(hello_handler));
let server = Server::new(router).into_test_server();
assert_eq!(
get_status(&server, "/hello", Some("application/json")).await,
StatusCode::OK
);
}
}
Loading