Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,19 @@ async fn handle_plain_proxy(
.instrument(tracing::Span::current()),
);

let download_fut = tunnel::download_loop(response, write_half, None, encoding);
let download_http_client = Arc::clone(&http_client);
let download_state = Arc::clone(&state);
let cookie_val_for_dl = stream_id.clone();

let download_fut = tunnel::download_loop(
response,
write_half,
None,
encoding,
cookie_val_for_dl,
download_http_client,
download_state,
);
tokio::pin!(download_fut);

let result: Result<()> = tokio::select! {
Expand Down
1 change: 1 addition & 0 deletions src/client/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub const UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
pub const MAX_BATCH_BYTES: usize = 1024 * 1024;
pub const MAX_IN_FLIGHT_BYTES: usize = 2 * 1024 * 1024;
pub const UPLOAD_CONCURRENCY: usize = 128;
pub const PREFETCH_LEAD_BYTES: u64 = 20 * 1024 * 1024;

pub const DECODE_BUF_CAPACITY: usize = 16 * 1024 + 2396;

Expand Down
20 changes: 18 additions & 2 deletions src/client/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,15 @@ pub async fn try_pq_connect(
.instrument(tracing::Span::current()),
);

let download_fut = download_loop(response, write_half, Some(download_cipher), encoding);
let download_fut = download_loop(
response,
write_half,
Some(download_cipher),
encoding,
cookie_val.clone(),
Arc::clone(http_client),
Arc::clone(state),
);
tokio::pin!(download_fut);

let result: Result<()> = tokio::select! {
Expand Down Expand Up @@ -348,7 +356,15 @@ pub async fn full_handshake(
.instrument(tracing::Span::current()),
);

let download_fut = download_loop(response, write_half, Some(download_cipher), encoding);
let download_fut = download_loop(
response,
write_half,
Some(download_cipher),
encoding,
cookie_val.clone(),
Arc::clone(http_client),
Arc::clone(state),
);
tokio::pin!(download_fut);

let result: Result<()> = tokio::select! {
Expand Down
1 change: 1 addition & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ pub fn build_state(cfg: &ClientTopConfig) -> Result<Arc<state::SharedState>> {
proxy_auth,
initial_master: Mutex::new(None),
handshake_lock: OnceCell::new(),
max_download_bytes: cfg.traffic_shaping.max_download_bytes,
}))
}
1 change: 1 addition & 0 deletions src/client/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ pub struct SharedState {
pub proxy_auth: Option<(String, String)>,
pub initial_master: Mutex<Option<InitialMasterEntry>>,
pub handshake_lock: OnceCell<tokio::sync::Mutex<()>>,
pub max_download_bytes: Option<u64>,
}
159 changes: 132 additions & 27 deletions src/client/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use futures::{FutureExt, StreamExt};
use http_body_util::BodyExt;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, oneshot};
use tokio::task::JoinSet;
use tracing::warn;

use crate::client::constants::{
DECODE_BUF_CAPACITY, MAX_BATCH_BYTES, MAX_IN_FLIGHT_BYTES, UPLOAD_CONCURRENCY,
UPLOAD_REQUEST_TIMEOUT,
DECODE_BUF_CAPACITY, DOWNLOAD_CONNECT_TIMEOUT, MAX_BATCH_BYTES, MAX_IN_FLIGHT_BYTES,
PREFETCH_LEAD_BYTES, UPLOAD_CONCURRENCY, UPLOAD_REQUEST_TIMEOUT,
};
use crate::client::utils;
use crate::crypto::AesFrameCipher;
Expand Down Expand Up @@ -190,37 +190,47 @@ pub async fn upload_loop(
Ok(())
}

pub async fn download_loop(
async fn download_single_response(
response: wreq::Response,
mut write_half: tokio::net::tcp::OwnedWriteHalf,
cipher: Option<Arc<AesFrameCipher>>,
write_half: &mut tokio::net::tcp::OwnedWriteHalf,
cipher: Option<&dyn FrameCipher>,
encoding: EncodingType,
) -> Result<()> {
start_seq: u64,
max_bytes: Option<u64>,
mut pre_fetch_trigger: Option<oneshot::Sender<()>>,
) -> Result<(u64, u64)> {
let pre_fetch_at = max_bytes.map(|m| m.saturating_sub(PREFETCH_LEAD_BYTES));

let mut buffer = BytesMut::with_capacity(DECODE_BUF_CAPACITY);
let mut data_stream = response.into_data_stream();
let cipher_ref: Option<&dyn FrameCipher> = cipher.as_deref().map(|c| c as &dyn FrameCipher);
let mut expected_seq: u64 = 0;
let mut expected_seq: u64 = start_seq;
let mut bytes_received: u64 = 0;

let result: Result<()> = async {
while let Some(chunk) = data_stream.next().await {
buffer.extend_from_slice(&chunk.context("response read error")?);
while let Some((seq, frame)) =
shaper::decode_from_buffer(&mut buffer, cipher_ref, encoding)?
{
if seq != expected_seq {
return Err(anyhow!(
"download frame seq {} out of order, expected {}",
seq,
expected_seq
));
}
expected_seq += 1;
write_half.write_all(&frame).await?;
while let Some(chunk) = data_stream.next().await {
let chunk = chunk.context("response read error")?;
bytes_received += chunk.len() as u64;

if let Some(at) = pre_fetch_at
&& bytes_received >= at
{
if let Some(tx) = pre_fetch_trigger.take() {
let _ = tx.send(());
}
}
Ok(())

buffer.extend_from_slice(&chunk);
while let Some((seq, frame)) = shaper::decode_from_buffer(&mut buffer, cipher, encoding)? {
if seq != expected_seq {
return Err(anyhow!(
"download frame seq {} out of order, expected {}",
seq,
expected_seq
));
}
expected_seq += 1;
write_half.write_all(&frame).await?;
}
}
.await;

if !buffer.is_empty() {
warn!(
Expand All @@ -229,6 +239,101 @@ pub async fn download_loop(
);
}

if let Some(tx) = pre_fetch_trigger.take() {
let _ = tx.send(());
}

Ok((bytes_received, expected_seq))
}

async fn send_continue_request(
http_client: &wreq::Client,
state: &SharedState,
cookie_val: &str,
) -> Result<wreq::Response> {
let mut cookie = String::new();
utils::build_tunnel_cookie(&mut cookie, cookie_val);

let mut req = http_client
.post(state.remote_str.as_str())
.header("Cookie", cookie);

if state.server_public_key.is_none() {
req = req.header("Authorization", state.auth_header.as_str());
}

let resp = tokio::time::timeout(DOWNLOAD_CONNECT_TIMEOUT, req.send())
.await
.context("continuation request timed out")?
.context("continuation request failed")?;

if !resp.status().is_success() {
let status = resp.status();
let _ = resp.bytes().await;
return Err(anyhow!("continuation rejected: {status}"));
}
Ok(resp)
}

pub async fn download_loop(
initial_response: wreq::Response,
mut write_half: tokio::net::tcp::OwnedWriteHalf,
cipher: Option<Arc<AesFrameCipher>>,
encoding: EncodingType,
cookie_val: String,
http_client: Arc<wreq::Client>,
state: Arc<SharedState>,
) -> Result<()> {
let cipher_dyn: Option<Arc<dyn FrameCipher>> = cipher.map(|c| c as Arc<dyn FrameCipher>);
let cipher_ref: Option<&dyn FrameCipher> = cipher_dyn.as_deref();

let max_bytes = state.max_download_bytes;
let rotate_enabled = max_bytes.is_some_and(|m| m > 0);

let mut response = initial_response;
let mut expected_seq: u64 = 0;

loop {
let (pre_fetch_trigger, pre_fetch_rx) = if rotate_enabled {
let (trigger_tx, trigger_rx) = oneshot::channel();
let (result_tx, result_rx) = oneshot::channel();
let pre_client = Arc::clone(&http_client);
let pre_state = Arc::clone(&state);
let pre_cookie = cookie_val.clone();
tokio::spawn(async move {
let _ = trigger_rx.await;
let result = send_continue_request(&pre_client, &pre_state, &pre_cookie).await;
let _ = result_tx.send(result);
});
(Some(trigger_tx), Some(result_rx))
} else {
(None, None)
};

let (bytes_received, next_seq) = download_single_response(
response,
&mut write_half,
cipher_ref,
encoding,
expected_seq,
max_bytes,
pre_fetch_trigger,
)
.await?;

expected_seq = next_seq;

let should_rotate = rotate_enabled && bytes_received >= max_bytes.unwrap();
if !should_rotate {
break;
}

response = pre_fetch_rx
.expect("pre_fetch_rx must be Some when rotate_enabled")
.await
.map_err(|_| anyhow!("pre-fetch task panicked"))??;
}

let _ = write_half.shutdown().await;
result
Ok(())
}
1 change: 1 addition & 0 deletions src/client/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ mod tests {
},
stages: vec![],
encoding_type: Default::default(),
max_download_bytes: None,
};
let (body, remaining, seq) =
encode_initial_payload(b"", shaper::MAX_RAW_PAYLOAD, None, &config).unwrap();
Expand Down
Loading
Loading