Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ bssh (Backend.AI SSH / Broadcast SSH) is a high-performance parallel SSH command
- SSH jump host support (-J)
- SSH configuration file parsing (-F)
- Interactive PTY sessions with single/multiplex modes
- SFTP file transfers (upload/download)
- SFTP file transfers (upload/download) with chunked streaming
- Backend.AI cluster auto-detection
- pdsh compatibility mode

Expand Down
2 changes: 2 additions & 0 deletions docs/architecture/ssh-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
- Support for SSH agent, key-based, and password authentication
- Configurable timeouts and retry logic
- Full SFTP support for file transfers
- SFTP uploads/downloads stream file payloads in 255 KiB chunks, matching the
default russh-sftp read/write packet limit and avoiding whole-file buffering
- SSH keepalive support via `SshConnectionConfig`:
- `keepalive_interval`: Interval between keepalive packets (default: 60s, 0 to disable)
- `keepalive_max`: Maximum unanswered keepalive packets before disconnect (default: 3)
Expand Down
124 changes: 100 additions & 24 deletions src/ssh/tokio_client/file_transfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,31 @@ use russh_sftp::{client::SftpSession, protocol::OpenFlags};
use std::path::Path;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

/// Chunk size used for streaming SFTP uploads/downloads.
///
/// Sized to match russh-sftp's default MAX_WRITE_LENGTH (255 KiB) so each
/// chunk maps to a single SFTP WRITE/READ packet without further fragmentation.
const STREAM_CHUNK_SIZE: usize = 255 * 1024;

/// Copy everything from `reader` into `writer` through one reusable
/// `STREAM_CHUNK_SIZE` buffer, so a transfer never holds more than a single
/// chunk of file payload in memory.
///
/// NOTE: the writer is intentionally not flushed here; callers flush (and,
/// for SFTP handles, shut down) the destination themselves.
async fn stream_copy<R, W>(reader: &mut R, writer: &mut W) -> std::io::Result<()>
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut chunk = vec![0u8; STREAM_CHUNK_SIZE];
    loop {
        match reader.read(&mut chunk).await? {
            // A zero-length read signals EOF: the copy is complete.
            0 => return Ok(()),
            // Forward exactly the bytes obtained; a short read is normal
            // and simply produces a smaller write.
            n => writer.write_all(&chunk[..n]).await?,
        }
    }
}

use super::connection::Client;
use crate::utils::buffer_pool::global;

impl Client {
/// Upload a file with sftp to the remote server.
Expand All @@ -46,19 +69,18 @@ impl Client {
channel.request_subsystem(true, "sftp").await?;
let sftp = SftpSession::new(channel.into_stream()).await?;

// read file contents locally
let file_contents = tokio::fs::read(src_file_path)
// Open local file for streaming reads (avoids loading whole file in memory).
let mut local_file = tokio::fs::File::open(src_file_path)
.await
.map_err(super::Error::IoError)?;

// interaction with i/o
let mut file = sftp
.open_with_flags(
dest_file_path,
OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ,
)
.await?;
file.write_all(&file_contents)
stream_copy(&mut local_file, &mut file)
.await
.map_err(super::Error::IoError)?;
file.flush().await.map_err(super::Error::IoError)?;
Expand Down Expand Up @@ -89,18 +111,16 @@ impl Client {
.open_with_flags(remote_file_path, OpenFlags::READ)
.await?;

// Use pooled buffer for reading file contents to reduce allocations
let mut pooled_buffer = global::get_large_buffer();
remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?;
let contents = pooled_buffer.as_vec().clone(); // Clone to owned Vec for writing

// write contents to local file
// Stream remote file directly to local disk to avoid buffering the
// whole file in memory.
let mut local_file = tokio::fs::File::create(local_file_path.as_ref())
.await
.map_err(super::Error::IoError)?;

local_file
.write_all(&contents)
stream_copy(&mut remote_file, &mut local_file)
.await
.map_err(super::Error::IoError)?;
remote_file
.shutdown()
.await
.map_err(super::Error::IoError)?;
local_file.flush().await.map_err(super::Error::IoError)?;
Expand Down Expand Up @@ -173,8 +193,9 @@ impl Client {
let _ = sftp.create_dir(&remote_path).await; // Ignore error if already exists
self.upload_dir_recursive(sftp, &path, &remote_path).await?;
} else if metadata.is_file() {
// Upload file
let file_contents = tokio::fs::read(&path)
// Stream local file to remote in chunks instead of loading
// the entire file in memory before send.
let mut local_file = tokio::fs::File::open(&path)
.await
.map_err(super::Error::IoError)?;

Expand All @@ -185,8 +206,7 @@ impl Client {
)
.await?;

remote_file
.write_all(&file_contents)
stream_copy(&mut local_file, &mut remote_file)
.await
.map_err(super::Error::IoError)?;
remote_file.flush().await.map_err(super::Error::IoError)?;
Expand Down Expand Up @@ -265,21 +285,77 @@ impl Client {
self.download_dir_recursive(sftp, &remote_path, &local_path)
.await?;
} else if metadata.file_type().is_file() {
// Download file using pooled buffer
// Stream remote file directly to local disk in chunks.
let mut remote_file =
sftp.open_with_flags(&remote_path, OpenFlags::READ).await?;

let mut pooled_buffer = global::get_large_buffer();
remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?;
let contents = pooled_buffer.as_vec().clone();

tokio::fs::write(&local_path, contents)
let mut local_file = tokio::fs::File::create(&local_path)
.await
.map_err(super::Error::IoError)?;
stream_copy(&mut remote_file, &mut local_file)
.await
.map_err(super::Error::IoError)?;
remote_file
.shutdown()
.await
.map_err(super::Error::IoError)?;
local_file.flush().await.map_err(super::Error::IoError)?;
}
}

Ok(())
})
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        io::Cursor,
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::AsyncWrite;

    /// Test double that captures every write it receives: the concatenated
    /// payload bytes plus the size of each individual `poll_write` call, so
    /// the test can verify both content and chunking behavior.
    #[derive(Default)]
    struct RecordingWriter {
        bytes: Vec<u8>,
        write_lengths: Vec<usize>,
    }

    impl AsyncWrite for RecordingWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            // RecordingWriter is Unpin, so we can safely take the inner
            // mutable reference out of the Pin.
            let this = self.get_mut();
            this.write_lengths.push(buf.len());
            this.bytes.extend_from_slice(buf);
            // Always accept the entire buffer in a single call.
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn stream_copy_writes_sftp_sized_chunks() {
        // Two full chunks plus a 17-byte tail exercises both the steady
        // state and the final short write.
        let payload = vec![0xAB; STREAM_CHUNK_SIZE * 2 + 17];
        let mut source = Cursor::new(payload.clone());
        let mut sink = RecordingWriter::default();

        stream_copy(&mut source, &mut sink).await.unwrap();

        assert_eq!(sink.bytes, payload);
        assert_eq!(
            sink.write_lengths,
            vec![STREAM_CHUNK_SIZE, STREAM_CHUNK_SIZE, 17]
        );
    }
}