diff --git a/src/server/handlers.rs b/src/server/handlers.rs index d2ca84c..b1ef2e0 100644 --- a/src/server/handlers.rs +++ b/src/server/handlers.rs @@ -1,3 +1,4 @@ +use axum::body::HttpBody; use axum::{body::Body, extract::State, http::HeaderMap, response::Response}; use base64::Engine; use bytes::{Bytes, BytesMut}; @@ -29,39 +30,28 @@ pub async fn dispatch( ) -> Result { let span = tracing::Span::current(); - let body_bytes = axum::body::to_bytes(body, MAX_UPLOAD_BODY_SIZE) - .await - .map_err(|e| ServerError::bad_request(format!("failed to read body: {e}")))?; - let has_x_target = headers.get("X-Target").is_some(); let session_cookie = utils::extract_cookie_value(&headers, "session"); if let Some(cookie_val) = session_cookie && state.streams.contains_key(cookie_val) { - if body_bytes.is_empty() { + if is_body_empty(&headers, &body) { return handle_download_continuation(state, cookie_val, span).await; - } else { - let session_id = cookie_val.split(':').next().unwrap_or(cookie_val); - if let Some(entry) = state.master_store.get(session_id) { - span.record("user", &entry.value().0); - } - return handle_stream_upload( - state, - cookie_val.to_owned(), - Body::from(body_bytes), - span, - ) - .await; } + let session_id = cookie_val.split(':').next().unwrap_or(cookie_val); + if let Some(entry) = state.master_store.get(session_id) { + span.record("user", &entry.value().0); + } + return handle_stream_upload(state, cookie_val.to_owned(), body, span).await; } if has_x_target { - return handle_plaintext_download(state, headers, body_bytes, span).await; + return handle_plaintext_download(state, headers, body, span).await; } if session_cookie.is_none() { - return handle_fresh_handshake(state, headers, Body::from(body_bytes), span).await; + return handle_fresh_handshake(state, headers, body, span).await; } let cookie_val = session_cookie.unwrap(); @@ -72,22 +62,19 @@ pub async fn dispatch( return Err(ServerError::precondition_required("session not found")); } - if !body_bytes.is_empty() && !state.streams.contains_key(cookie_val) { - return handle_pq_download(state, cookie_val, body_bytes, span).await; - } + let body_bytes = axum::body::to_bytes(body, MAX_UPLOAD_BODY_SIZE) + .await + .map_err(|e| ServerError::bad_request(format!("failed to read body: {e}")))?; - if body_bytes.is_empty() { - handle_pq_download(state, cookie_val, body_bytes, span).await - } else { - let user = &state - .master_store - .get(session_id) - .map(|e| e.value().0.clone()) - .unwrap_or_default(); - span.record("user", user); - let upload_body = Body::from(body_bytes); - handle_stream_upload(state, cookie_val.to_owned(), upload_body, span).await + handle_pq_download(state, cookie_val, body_bytes, span).await +} + +#[inline] +fn is_body_empty(headers: &HeaderMap, body: &Body) -> bool { + if let Some(cl) = headers.get("content-length").and_then(|v| v.to_str().ok()) { + return cl == "0"; } + body.is_end_stream() } #[inline] @@ -106,12 +93,16 @@ fn build_download_response( async fn handle_plaintext_download( state: Arc, headers: HeaderMap, - early_data: Bytes, + body: Body, span: tracing::Span, ) -> Result { let user = validate_jwt_if_needed(&headers, false, &state.decoding_key, &state.jwt_validation)?; span.record("user", &user); + let early_data = axum::body::to_bytes(body, MAX_UPLOAD_BODY_SIZE) + .await + .map_err(|e| ServerError::bad_request(format!("failed to read body: {e}")))?; + let target = headers .get("X-Target") .and_then(|v| v.to_str().ok())