diff --git a/src/server/handlers.rs b/src/server/handlers.rs index eddf4db..bf3e2b8 100644 --- a/src/server/handlers.rs +++ b/src/server/handlers.rs @@ -4,6 +4,7 @@ use base64::Engine; use bytes::{Bytes, BytesMut}; use futures::StreamExt; use jsonwebtoken::{DecodingKey, Validation}; +use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex}; use tokio::io::AsyncWriteExt; use tokio::sync::{mpsc, oneshot}; @@ -55,6 +56,7 @@ fn spawn_stream_response( encoding: state.traffic_config.encoding_type, max_download_bytes: state.traffic_config.max_download_bytes, handoff_tx: Mutex::new(None), + handoff_done: AtomicBool::new(false), }); match state.streams.entry(map_key.clone()) { diff --git a/src/server/state.rs b/src/server/state.rs index bf3677b..7f6a221 100644 --- a/src/server/state.rs +++ b/src/server/state.rs @@ -90,6 +90,7 @@ pub struct StreamBundle { pub encoding: EncodingType, pub max_download_bytes: Option, pub(crate) handoff_tx: Mutex>>, + pub(crate) handoff_done: AtomicBool, } impl StreamBundle { @@ -120,6 +121,7 @@ pub struct DownloadStream { impl DownloadStream { fn release_upstream(&self) { + self.bundle.handoff_done.store(true, Ordering::Release); if let Ok(mut guard) = self.bundle.handoff_tx.lock() && let Some(tx) = guard.take() { @@ -140,9 +142,17 @@ impl Stream for DownloadStream { if let Some(rx) = &mut this.handoff_rx { match rx.as_mut().poll(cx) { Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => { + this.bundle.handoff_done.store(false, Ordering::Release); this.handoff_rx = None; } - Poll::Pending => return Poll::Pending, + Poll::Pending => { + if this.bundle.handoff_done.load(Ordering::Acquire) { + this.bundle.handoff_done.store(false, Ordering::Release); + this.handoff_rx = None; + } else { + return Poll::Pending; + } + } } }