diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index cc098daf..d8924b6c 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -12,7 +12,7 @@ bssh (Backend.AI SSH / Broadcast SSH) is a high-performance parallel SSH command - SSH jump host support (-J) - SSH configuration file parsing (-F) - Interactive PTY sessions with single/multiplex modes -- SFTP file transfers (upload/download) +- SFTP file transfers (upload/download) with chunked streaming - Backend.AI cluster auto-detection - pdsh compatibility mode diff --git a/docs/architecture/ssh-client.md b/docs/architecture/ssh-client.md index e435ceaf..ab713d73 100644 --- a/docs/architecture/ssh-client.md +++ b/docs/architecture/ssh-client.md @@ -29,6 +29,8 @@ - Support for SSH agent, key-based, and password authentication - Configurable timeouts and retry logic - Full SFTP support for file transfers +- SFTP uploads/downloads stream file payloads in 255 KiB chunks, matching the + default russh-sftp read/write packet limit and avoiding whole-file buffering - SSH keepalive support via `SshConnectionConfig`: - `keepalive_interval`: Interval between keepalive packets (default: 60s, 0 to disable) - `keepalive_max`: Maximum unanswered keepalive packets before disconnect (default: 3) diff --git a/src/ssh/tokio_client/file_transfer.rs b/src/ssh/tokio_client/file_transfer.rs index 5fda7622..2f8b8b67 100644 --- a/src/ssh/tokio_client/file_transfer.rs +++ b/src/ssh/tokio_client/file_transfer.rs @@ -23,8 +23,31 @@ use russh_sftp::{client::SftpSession, protocol::OpenFlags}; use std::path::Path; use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// Chunk size used for streaming SFTP uploads/downloads. +/// +/// Sized to match russh-sftp's default MAX_WRITE_LENGTH (255 KiB) so each +/// chunk maps to a single SFTP WRITE/READ packet without further fragmentation. +const STREAM_CHUNK_SIZE: usize = 255 * 1024; + +/// Stream `reader` to `writer` in fixed-size chunks so a single transfer never +/// holds more than `STREAM_CHUNK_SIZE` of file payload in memory at once. +async fn stream_copy(reader: &mut R, writer: &mut W) -> std::io::Result<()> +where + R: tokio::io::AsyncRead + Unpin, + W: tokio::io::AsyncWrite + Unpin, +{ + let mut buf = vec![0u8; STREAM_CHUNK_SIZE]; + loop { + let n = reader.read(&mut buf).await?; + if n == 0 { + break; + } + writer.write_all(&buf[..n]).await?; + } + Ok(()) +} + use super::connection::Client; -use crate::utils::buffer_pool::global; impl Client { /// Upload a file with sftp to the remote server. @@ -46,19 +69,18 @@ impl Client { channel.request_subsystem(true, "sftp").await?; let sftp = SftpSession::new(channel.into_stream()).await?; - // read file contents locally - let file_contents = tokio::fs::read(src_file_path) + // Open local file for streaming reads (avoids loading whole file in memory). + let mut local_file = tokio::fs::File::open(src_file_path) .await .map_err(super::Error::IoError)?; - // interaction with i/o let mut file = sftp .open_with_flags( dest_file_path, OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ, ) .await?; - file.write_all(&file_contents) + stream_copy(&mut local_file, &mut file) .await .map_err(super::Error::IoError)?; file.flush().await.map_err(super::Error::IoError)?; @@ -89,18 +111,16 @@ impl Client { .open_with_flags(remote_file_path, OpenFlags::READ) .await?; - // Use pooled buffer for reading file contents to reduce allocations - let mut pooled_buffer = global::get_large_buffer(); - remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?; - let contents = pooled_buffer.as_vec().clone(); // Clone to owned Vec for writing - - // write contents to local file + // Stream remote file directly to local disk to avoid buffering the + // whole file in memory. let mut local_file = tokio::fs::File::create(local_file_path.as_ref()) .await .map_err(super::Error::IoError)?; - - local_file - .write_all(&contents) + stream_copy(&mut remote_file, &mut local_file) + .await + .map_err(super::Error::IoError)?; + remote_file + .shutdown() .await .map_err(super::Error::IoError)?; local_file.flush().await.map_err(super::Error::IoError)?; @@ -173,8 +193,9 @@ impl Client { let _ = sftp.create_dir(&remote_path).await; // Ignore error if already exists self.upload_dir_recursive(sftp, &path, &remote_path).await?; } else if metadata.is_file() { - // Upload file - let file_contents = tokio::fs::read(&path) + // Stream local file to remote in chunks instead of loading + // the entire file in memory before send. + let mut local_file = tokio::fs::File::open(&path) .await .map_err(super::Error::IoError)?; @@ -185,8 +206,7 @@ impl Client { ) .await?; - remote_file - .write_all(&file_contents) + stream_copy(&mut local_file, &mut remote_file) .await .map_err(super::Error::IoError)?; remote_file.flush().await.map_err(super::Error::IoError)?; @@ -265,17 +285,21 @@ impl Client { self.download_dir_recursive(sftp, &remote_path, &local_path) .await?; } else if metadata.file_type().is_file() { - // Download file using pooled buffer + // Stream remote file directly to local disk in chunks. let mut remote_file = sftp.open_with_flags(&remote_path, OpenFlags::READ).await?; - let mut pooled_buffer = global::get_large_buffer(); - remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?; - let contents = pooled_buffer.as_vec().clone(); - - tokio::fs::write(&local_path, contents) + let mut local_file = tokio::fs::File::create(&local_path) + .await + .map_err(super::Error::IoError)?; + stream_copy(&mut remote_file, &mut local_file) + .await + .map_err(super::Error::IoError)?; + remote_file + .shutdown() .await .map_err(super::Error::IoError)?; + local_file.flush().await.map_err(super::Error::IoError)?; } } @@ -283,3 +307,55 @@ impl Client { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + io::Cursor, + pin::Pin, + task::{Context, Poll}, + }; + use tokio::io::AsyncWrite; + + #[derive(Default)] + struct RecordingWriter { + bytes: Vec, + write_lengths: Vec, + } + + impl AsyncWrite for RecordingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.write_lengths.push(buf.len()); + self.bytes.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[tokio::test] + async fn stream_copy_writes_sftp_sized_chunks() { + let input = vec![0xAB; STREAM_CHUNK_SIZE * 2 + 17]; + let mut reader = Cursor::new(input.clone()); + let mut writer = RecordingWriter::default(); + + stream_copy(&mut reader, &mut writer).await.unwrap(); + + assert_eq!(writer.bytes, input); + assert_eq!( + writer.write_lengths, + vec![STREAM_CHUNK_SIZE, STREAM_CHUNK_SIZE, 17] + ); + } +}