diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index d8924b6c..d9c5fa4f 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -12,7 +12,7 @@ bssh (Backend.AI SSH / Broadcast SSH) is a high-performance parallel SSH command - SSH jump host support (-J) - SSH configuration file parsing (-F) - Interactive PTY sessions with single/multiplex modes -- SFTP file transfers (upload/download) with chunked streaming +- SFTP file transfers (upload/download) with bounded pipelined streaming - Backend.AI cluster auto-detection - pdsh compatibility mode diff --git a/Cargo.lock b/Cargo.lock index edb80b38..7097badb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -581,6 +581,7 @@ dependencies = [ "bytes", "chrono", "flurry", + "futures", "log", "serde", "serde_bytes", diff --git a/crates/bssh-russh-sftp/Cargo.toml b/crates/bssh-russh-sftp/Cargo.toml index a69960c7..9bd21c6c 100644 --- a/crates/bssh-russh-sftp/Cargo.toml +++ b/crates/bssh-russh-sftp/Cargo.toml @@ -20,6 +20,7 @@ tokio = { version = "1", default-features = false, features = [ "macros", ] } tokio-util = "0.7" +futures = { version = "0.3", default-features = false, features = ["std", "async-await"] } serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" bitflags = { version = "2.9", features = ["serde"] } diff --git a/crates/bssh-russh-sftp/src/client/fs/file.rs b/crates/bssh-russh-sftp/src/client/fs/file.rs index 4b1cdcd2..a0e28d44 100644 --- a/crates/bssh-russh-sftp/src/client/fs/file.rs +++ b/crates/bssh-russh-sftp/src/client/fs/file.rs @@ -21,6 +21,10 @@ type StateFn = Option> + Send + Syn const MAX_READ_LENGTH: u64 = 261120; const MAX_WRITE_LENGTH: u64 = 261120; +fn bounded_chunk_size(limit: Option, default_limit: u64) -> usize { + limit.map_or(default_limit, |n| n.min(default_limit)) as usize +} + struct FileState { f_read: StateFn>>, f_seek: StateFn, @@ -92,6 +96,402 @@ impl File { self.session.fsync(self.handle.as_str()).await.map(|_| ()) } + + /// Streams `reader` to this remote file with up to 
`max_inflight` concurrent + /// SFTP `WRITE` requests in flight. Each request carries up to the negotiated + /// `write_len` (or [`MAX_WRITE_LENGTH`] when no limit is advertised). + /// + /// The high-level [`AsyncWrite`] impl issues one `WRITE` at a time and waits + /// for its `STATUS` reply before sending the next, so sustained throughput is + /// bounded by `chunk_size / RTT`. This helper hides the per-request RTT by + /// keeping multiple in-flight, mirroring how OpenSSH's `sftp` client behaves + /// (~64 outstanding requests by default). + /// + /// On success returns the number of bytes streamed. Updates `self.pos` to + /// the new write offset. Reading from `reader` and dispatching writes are + /// interleaved, so memory usage is bounded by `max_inflight * chunk_size`. + pub async fn write_all_pipelined( + &mut self, + reader: &mut R, + max_inflight: usize, + ) -> SftpResult + where + R: tokio::io::AsyncRead + Unpin, + { + use futures::stream::{FuturesUnordered, StreamExt}; + use tokio::io::AsyncReadExt; + + if max_inflight == 0 { + return Err(Error::UnexpectedBehavior( + "max_inflight must be at least 1".to_owned(), + )); + } + + let chunk_size = bounded_chunk_size( + self.extensions.limits.as_ref().and_then(|l| l.write_len), + MAX_WRITE_LENGTH, + ); + + let mut total: u64 = 0; + let mut offset = self.pos; + let mut in_flight = FuturesUnordered::new(); + let mut eof = false; + + loop { + // Top up the pipeline with new chunks until we hit the cap or EOF. 
+ while !eof && in_flight.len() < max_inflight { + let mut buf = vec![0u8; chunk_size]; + let n = reader.read(&mut buf).await?; + if n == 0 { + eof = true; + break; + } + buf.truncate(n); + + let session = self.session.clone(); + let handle = self.handle.clone(); + let off = offset; + + in_flight.push(async move { + session.write(handle, off, buf).await?; + SftpResult::Ok(n as u64) + }); + + offset += n as u64; + total += n as u64; + } + + // Drain at least one in-flight write before reading more, otherwise + // we busy-loop the read path while writes never get a chance to make + // progress. + match in_flight.next().await { + Some(Ok(_)) => {} + Some(Err(e)) => return Err(e), + None => break, // pipeline drained and no more data -> done + } + } + + self.pos = offset; + Ok(total) + } + + /// Streams the remote file from the current position to `writer` using up to + /// `max_inflight` concurrent SFTP `READ` requests. Each request asks for up + /// to the negotiated `read_len`, capped at [`MAX_READ_LENGTH`]. + /// + /// Like [`Self::write_all_pipelined`], this hides per-request RTT. Chunks + /// are reassembled in offset order before being written to `writer`, so the + /// output is identical to a sequential read. For regular files, the current + /// file size is used to avoid speculative reads beyond EOF; if the size is + /// unavailable, the transfer stops on EOF or the first short read. + /// + /// Returns the number of bytes streamed. Updates `self.pos`. 
    pub async fn read_to_writer_pipelined<W>(
        &mut self,
        writer: &mut W,
        max_inflight: usize,
    ) -> SftpResult<u64>
    where
        W: tokio::io::AsyncWrite + Unpin,
    {
        use futures::stream::{FuturesUnordered, StreamExt};
        use std::collections::BTreeMap;
        use tokio::io::AsyncWriteExt;

        if max_inflight == 0 {
            return Err(Error::UnexpectedBehavior(
                "max_inflight must be at least 1".to_owned(),
            ));
        }

        let chunk_size = bounded_chunk_size(
            self.extensions.limits.as_ref().and_then(|l| l.read_len),
            MAX_READ_LENGTH,
        );
        let file_end = self
            .metadata()
            .await
            .ok()
            .and_then(|m| m.size)
            .filter(|&size| size >= self.pos);

        let mut total: u64 = 0;
        let mut next_offset = self.pos;
        let mut next_to_write = self.pos;
        let mut pending: BTreeMap<u64, Vec<u8>> = BTreeMap::new();
        let mut in_flight = FuturesUnordered::new();
        let mut eof = false;
+ while !eof + && in_flight.len() + pending.len() < max_inflight + && file_end.is_none_or(|end| next_offset < end) + { + let session = self.session.clone(); + let handle = self.handle.clone(); + let off = next_offset; + let len = file_end.map_or(chunk_size as u64, |end| { + (end - next_offset).min(chunk_size as u64) + }) as u32; + + in_flight.push(async move { + match session.read(handle, off, len).await { + Ok(data) => SftpResult::Ok((off, len, Some(data.data))), + Err(Error::Status(s)) if s.status_code == StatusCode::Eof => { + SftpResult::Ok((off, len, None)) + } + Err(e) => Err(e), + } + }); + + next_offset += u64::from(len); + } + + match in_flight.next().await { + Some(Ok((off, len, Some(data)))) => { + if data.is_empty() { + eof = true; + } else { + if let Some(end) = file_end { + let got_end = off.saturating_add(data.len() as u64); + if data.len() != len as usize || got_end > end { + return Err(Error::UnexpectedBehavior(format!( + "short read before EOF at offset {off}: requested {len} bytes, received {} bytes", + data.len() + ))); + } + } else if data.len() < len as usize { + eof = true; + } + + pending.insert(off, data); + } + } + Some(Ok((off, _, None))) => { + if file_end.is_some_and(|end| off < end) { + return Err(Error::UnexpectedBehavior(format!( + "unexpected EOF before file size at offset {off}" + ))); + } + eof = true; + } + Some(Err(e)) => return Err(e), + None => break, + } + + // Flush in-order chunks to writer as they become available. 
+ while let Some(chunk) = pending.remove(&next_to_write) { + let n = chunk.len() as u64; + writer.write_all(&chunk).await?; + next_to_write += n; + total += n; + } + } + + self.pos = next_to_write; + Ok(total) + } +} + +#[cfg(test)] +mod tests { + use std::{ + future::Future, + sync::{Arc, Mutex}, + }; + + use tokio::io::duplex; + + use super::*; + use crate::{ + client::SftpSession, + protocol::{Attrs, Data, FileAttributes, Handle, OpenFlags, Status, Version}, + server, + server::Handler, + }; + + struct MemoryHandler { + data: Arc>>, + } + + impl MemoryHandler { + fn ok_status(id: u32) -> Status { + Status { + id, + status_code: StatusCode::Ok, + error_message: String::new(), + language_tag: String::new(), + } + } + } + + impl Handler for MemoryHandler { + type Error = StatusCode; + + fn unimplemented(&self) -> Self::Error { + StatusCode::OpUnsupported + } + + fn init( + &mut self, + _version: u32, + _extensions: std::collections::HashMap, + ) -> impl Future> + Send { + async { Ok(Version::new()) } + } + + fn open( + &mut self, + id: u32, + _filename: String, + _pflags: OpenFlags, + _attrs: FileAttributes, + ) -> impl Future> + Send { + async move { + Ok(Handle { + id, + handle: "memory".to_owned(), + }) + } + } + + fn close( + &mut self, + id: u32, + _handle: String, + ) -> impl Future> + Send { + async move { Ok(Self::ok_status(id)) } + } + + fn fstat( + &mut self, + id: u32, + _handle: String, + ) -> impl Future> + Send { + let data = self.data.clone(); + + async move { + let mut attrs = FileAttributes::empty(); + attrs.size = Some(data.lock().expect("memory file lock poisoned").len() as u64); + Ok(Attrs { id, attrs }) + } + } + + fn read( + &mut self, + id: u32, + _handle: String, + offset: u64, + len: u32, + ) -> impl Future> + Send { + let data = self.data.clone(); + + async move { + let data = data.lock().expect("memory file lock poisoned"); + let offset = usize::try_from(offset).map_err(|_| StatusCode::Failure)?; + if offset >= data.len() { + return 
Err(StatusCode::Eof); + } + let end = offset.saturating_add(len as usize).min(data.len()); + + Ok(Data { + id, + data: data[offset..end].to_vec(), + }) + } + } + + fn write( + &mut self, + id: u32, + _handle: String, + offset: u64, + bytes: Vec, + ) -> impl Future> + Send { + let data = self.data.clone(); + + async move { + let mut data = data.lock().expect("memory file lock poisoned"); + let offset = usize::try_from(offset).map_err(|_| StatusCode::Failure)?; + let end = offset.checked_add(bytes.len()).ok_or(StatusCode::Failure)?; + if data.len() < end { + data.resize(end, 0); + } + data[offset..end].copy_from_slice(&bytes); + + Ok(Self::ok_status(id)) + } + } + } + + async fn memory_session(data: Arc>>) -> SftpSession { + let (client, server_stream) = duplex(64 * 1024); + server::run(server_stream, MemoryHandler { data }).await; + SftpSession::new(client).await.expect("memory SFTP init") + } + + #[test] + fn advertised_chunk_sizes_are_capped() { + assert_eq!( + bounded_chunk_size(None, MAX_READ_LENGTH), + MAX_READ_LENGTH as usize + ); + assert_eq!(bounded_chunk_size(Some(1024), MAX_READ_LENGTH), 1024); + assert_eq!( + bounded_chunk_size(Some(MAX_READ_LENGTH * 4), MAX_READ_LENGTH), + MAX_READ_LENGTH as usize + ); + } + + #[tokio::test] + async fn write_all_pipelined_streams_all_bytes() { + let remote_data = Arc::new(Mutex::new(Vec::new())); + let sftp = memory_session(remote_data.clone()).await; + let input: Vec = (0..(MAX_WRITE_LENGTH as usize * 2 + 123)) + .map(|n| (n % 251) as u8) + .collect(); + let mut reader = &input[..]; + let mut file = sftp + .open_with_flags( + "ignored", + OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE, + ) + .await + .expect("open memory file"); + + let written = file + .write_all_pipelined(&mut reader, 4) + .await + .expect("pipelined write"); + + assert_eq!(written as usize, input.len()); + assert_eq!( + *remote_data.lock().expect("memory file lock poisoned"), + input + ); + } + + #[tokio::test] + async fn 
read_to_writer_pipelined_streams_all_bytes() { + let input: Vec = (0..(MAX_READ_LENGTH as usize * 2 + 123)) + .map(|n| (n % 251) as u8) + .collect(); + let remote_data = Arc::new(Mutex::new(input.clone())); + let sftp = memory_session(remote_data).await; + let mut file = sftp.open("ignored").await.expect("open memory file"); + let mut output = Vec::new(); + + let read = file + .read_to_writer_pipelined(&mut output, 4) + .await + .expect("pipelined read"); + + assert_eq!(read as usize, input.len()); + assert_eq!(output, input); + } } impl Drop for File { diff --git a/docs/architecture/ssh-client.md b/docs/architecture/ssh-client.md index ab713d73..29df9d26 100644 --- a/docs/architecture/ssh-client.md +++ b/docs/architecture/ssh-client.md @@ -29,8 +29,7 @@ - Support for SSH agent, key-based, and password authentication - Configurable timeouts and retry logic - Full SFTP support for file transfers -- SFTP uploads/downloads stream file payloads in 255 KiB chunks, matching the - default russh-sftp read/write packet limit and avoiding whole-file buffering +- SFTP uploads/downloads use bounded pipelined streaming with up to 64 concurrent SFTP READ/WRITE requests, capped at the default russh-sftp packet size to avoid whole-file buffering and oversized server-advertised limits - SSH keepalive support via `SshConnectionConfig`: - `keepalive_interval`: Interval between keepalive packets (default: 60s, 0 to disable) - `keepalive_max`: Maximum unanswered keepalive packets before disconnect (default: 3) diff --git a/src/ssh/tokio_client/file_transfer.rs b/src/ssh/tokio_client/file_transfer.rs index 2f8b8b67..84836190 100644 --- a/src/ssh/tokio_client/file_transfer.rs +++ b/src/ssh/tokio_client/file_transfer.rs @@ -21,34 +21,16 @@ use russh_sftp::{client::SftpSession, protocol::OpenFlags}; use std::path::Path; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -/// Chunk size used for streaming SFTP uploads/downloads. 
-/// -/// Sized to match russh-sftp's default MAX_WRITE_LENGTH (255 KiB) so each -/// chunk maps to a single SFTP WRITE/READ packet without further fragmentation. -const STREAM_CHUNK_SIZE: usize = 255 * 1024; - -/// Stream `reader` to `writer` in fixed-size chunks so a single transfer never -/// holds more than `STREAM_CHUNK_SIZE` of file payload in memory at once. -async fn stream_copy(reader: &mut R, writer: &mut W) -> std::io::Result<()> -where - R: tokio::io::AsyncRead + Unpin, - W: tokio::io::AsyncWrite + Unpin, -{ - let mut buf = vec![0u8; STREAM_CHUNK_SIZE]; - loop { - let n = reader.read(&mut buf).await?; - if n == 0 { - break; - } - writer.write_all(&buf[..n]).await?; - } - Ok(()) -} +use tokio::io::AsyncWriteExt; use super::connection::Client; +/// Maximum number of concurrent SFTP `WRITE`/`READ` requests held in flight per +/// transfer. Mirrors OpenSSH `sftp(1)`'s default (`-R 64`) — large enough to +/// hide per-request RTT on intra-DC and intercontinental links, small enough to +/// keep peak buffer memory bounded (`MAX_INFLIGHT * MAX_WRITE_LENGTH ≈ 16 MiB`). +const MAX_INFLIGHT_REQUESTS: usize = 64; + impl Client { /// Upload a file with sftp to the remote server. /// @@ -69,7 +51,8 @@ impl Client { channel.request_subsystem(true, "sftp").await?; let sftp = SftpSession::new(channel.into_stream()).await?; - // Open local file for streaming reads (avoids loading whole file in memory). + // Stream local file with multiple SFTP WRITE requests in flight to + // hide per-request RTT and avoid loading the entire file in memory. 
let mut local_file = tokio::fs::File::open(src_file_path) .await .map_err(super::Error::IoError)?; @@ -80,9 +63,8 @@ impl Client { OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ, ) .await?; - stream_copy(&mut local_file, &mut file) - .await - .map_err(super::Error::IoError)?; + file.write_all_pipelined(&mut local_file, MAX_INFLIGHT_REQUESTS) + .await?; file.flush().await.map_err(super::Error::IoError)?; file.shutdown().await.map_err(super::Error::IoError)?; @@ -106,19 +88,18 @@ impl Client { channel.request_subsystem(true, "sftp").await?; let sftp = SftpSession::new(channel.into_stream()).await?; - // open remote file for reading + // Stream remote file with multiple SFTP READ requests in flight; chunks + // are reassembled in offset order before being written to disk. let mut remote_file = sftp .open_with_flags(remote_file_path, OpenFlags::READ) .await?; - // Stream remote file directly to local disk to avoid buffering the - // whole file in memory. let mut local_file = tokio::fs::File::create(local_file_path.as_ref()) .await .map_err(super::Error::IoError)?; - stream_copy(&mut remote_file, &mut local_file) - .await - .map_err(super::Error::IoError)?; + remote_file + .read_to_writer_pipelined(&mut local_file, MAX_INFLIGHT_REQUESTS) + .await?; remote_file .shutdown() .await @@ -193,8 +174,7 @@ impl Client { let _ = sftp.create_dir(&remote_path).await; // Ignore error if already exists self.upload_dir_recursive(sftp, &path, &remote_path).await?; } else if metadata.is_file() { - // Stream local file to remote in chunks instead of loading - // the entire file in memory before send. + // Stream local file with pipelined SFTP WRITEs. 
let mut local_file = tokio::fs::File::open(&path) .await .map_err(super::Error::IoError)?; @@ -206,9 +186,9 @@ impl Client { ) .await?; - stream_copy(&mut local_file, &mut remote_file) - .await - .map_err(super::Error::IoError)?; + remote_file + .write_all_pipelined(&mut local_file, MAX_INFLIGHT_REQUESTS) + .await?; remote_file.flush().await.map_err(super::Error::IoError)?; remote_file .shutdown() @@ -285,16 +265,16 @@ impl Client { self.download_dir_recursive(sftp, &remote_path, &local_path) .await?; } else if metadata.file_type().is_file() { - // Stream remote file directly to local disk in chunks. + // Stream remote file with pipelined SFTP READs. let mut remote_file = sftp.open_with_flags(&remote_path, OpenFlags::READ).await?; let mut local_file = tokio::fs::File::create(&local_path) .await .map_err(super::Error::IoError)?; - stream_copy(&mut remote_file, &mut local_file) - .await - .map_err(super::Error::IoError)?; + remote_file + .read_to_writer_pipelined(&mut local_file, MAX_INFLIGHT_REQUESTS) + .await?; remote_file .shutdown() .await @@ -307,55 +287,3 @@ impl Client { }) } } - -#[cfg(test)] -mod tests { - use super::*; - use std::{ - io::Cursor, - pin::Pin, - task::{Context, Poll}, - }; - use tokio::io::AsyncWrite; - - #[derive(Default)] - struct RecordingWriter { - bytes: Vec, - write_lengths: Vec, - } - - impl AsyncWrite for RecordingWriter { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.write_lengths.push(buf.len()); - self.bytes.extend_from_slice(buf); - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - } - - #[tokio::test] - async fn stream_copy_writes_sftp_sized_chunks() { - let input = vec![0xAB; STREAM_CHUNK_SIZE * 2 + 17]; - let mut reader = Cursor::new(input.clone()); - let mut writer = 
RecordingWriter::default(); - - stream_copy(&mut reader, &mut writer).await.unwrap(); - - assert_eq!(writer.bytes, input); - assert_eq!( - writer.write_lengths, - vec![STREAM_CHUNK_SIZE, STREAM_CHUNK_SIZE, 17] - ); - } -}