diff --git a/.pi/todos/58e24853.md b/.pi/todos/58e24853.md new file mode 100644 index 00000000..d5107ae0 --- /dev/null +++ b/.pi/todos/58e24853.md @@ -0,0 +1,20 @@ +--- +id: "58e24853" +title: "Integrate turmoil with feature-gated type replacement" +tags: + - "turmoil" + - "feature-gate" + - "network-simulation" +status: "completed" +created_at: "2026-04-16T08:02:46.705Z" +--- +Add turmoil as a feature-gated dependency that replaces tokio::net types with turmoil::net types when enabled. + +## Tasks +- [x] Add turmoil to workspace dependencies +- [x] Add turmoil feature to msg-transport Cargo.toml +- [x] Create type alias module for feature-gated type resolution +- [x] Update TCP transport to use type aliases with channel-based accept for turmoil +- [x] QUIC uses real UDP sockets (documented limitation) +- [x] IPC uses Unix sockets which turmoil doesn't support (expected) +- [x] Test the integration with cargo check - all builds pass diff --git a/.pi/todos/settings.json b/.pi/todos/settings.json new file mode 100644 index 00000000..a355918b --- /dev/null +++ b/.pi/todos/settings.json @@ -0,0 +1,4 @@ +{ + "gc": true, + "gcDays": 7 +} diff --git a/Cargo.lock b/Cargo.lock index 50490d26..bd9af04c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -971,6 +971,7 @@ dependencies = [ "tokio-util", "tracing", "tracing-subscriber", + "turmoil", ] [[package]] @@ -991,6 +992,7 @@ dependencies = [ "tokio-openssl", "tracing", "tracing-subscriber", + "turmoil", ] [[package]] @@ -1139,6 +1141,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -1457,6 +1460,16 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rayon" version = "1.11.0" @@ -1717,6 +1730,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -1810,6 +1829,16 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -2069,7 +2098,9 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.61.2", @@ -2184,6 +2215,21 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "turmoil" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5384da930ba6d7e467030c421a7332726755d548ba38058aed30c2c30d991d2" +dependencies = [ + "bytes", + "indexmap", + "rand", + "rand_distr", + "scoped-tls", + "tokio", + "tracing", +] + [[package]] name = "twox-hash" version = "2.1.2" diff --git a/Cargo.toml b/Cargo.toml index 0caed8d7..5fa94ef5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ tokio = { version = "1", features = [ "io-util", "macros", "fs", + "sync", ] } tokio-util = { version = "0.7", features = ["codec"] } tokio-stream = { version = "0.1", features = ["sync"] } @@ -77,6 +78,10 @@ quinn = "0.11.9" rcgen = "0.14" openssl = { version = "0.10" } +# turmoil simulation +# Note: version must match tokio version for compatibility +turmoil = { version = "0.7" } + # benchmarking & profiling criterion = { version = "0.5", features = ["async_tokio"] } pprof = { version = "0.15", features = ["flamegraph", "criterion"] } diff --git a/libmsg/Cargo.toml b/libmsg/Cargo.toml index 25cd22e4..904b002a 100644 --- a/libmsg/Cargo.toml +++ b/libmsg/Cargo.toml @@ -34,8 +34,12 @@ tikv-jemallocator = { version = "0.6.1", features = ["profiling"] } [features] default = [] -quic = ["msg-transport/quic"] -tcp-tls = ["msg-transport/tcp-tls"] +quic = ["msg-transport/quic", "msg-socket/quic"] +tcp-tls = ["msg-transport/tcp-tls", "msg-socket/tcp-tls"] +# Enables turmoil-based network simulation across the transport and socket +# layers. TCP and TCP-TLS traffic both flow through `turmoil::net`, so TLS +# integration tests can run inside a deterministic simulation. +turmoil = ["msg-transport/turmoil", "msg-socket/turmoil"] [[bench]] name = "reqrep" diff --git a/msg-socket/Cargo.toml b/msg-socket/Cargo.toml index cd75bd9c..397a0e97 100644 --- a/msg-socket/Cargo.toml +++ b/msg-socket/Cargo.toml @@ -26,12 +26,25 @@ tracing.workspace = true tokio-stream.workspace = true parking_lot.workspace = true arc-swap.workspace = true +turmoil = { workspace = true, optional = true } derive_more = { workspace = true, features = ["deref"] } [dev-dependencies] rand.workspace = true -msg-transport = { workspace = true, features = ["quic", "tcp-tls"] } +# Transport features are forwarded through this crate's own feature flags below, +# so tests pick them up via `--features` rather than being hardcoded here. +msg-transport.workspace = true openssl.workspace = true +turmoil.workspace = true tracing-subscriber = "0.3" + +[features] +default = [] +# Transport forwarding. Downstream users (and this crate's integration tests) +# opt into each transport via these flags without depending on `msg-transport` +# directly. `turmoil` composes with `tcp-tls` and `quic`. +tcp-tls = ["msg-transport/tcp-tls"] +quic = ["msg-transport/quic"] +turmoil = ["dep:turmoil", "msg-transport/turmoil"] diff --git a/msg-socket/src/lib.rs b/msg-socket/src/lib.rs index d47d2c1b..1f0caf00 100644 --- a/msg-socket/src/lib.rs +++ b/msg-socket/src/lib.rs @@ -35,6 +35,8 @@ pub use sub::*; mod connection; pub use connection::*; +mod resolve; + /// The default buffer size for a socket. pub const DEFAULT_BUFFER_SIZE: usize = 8192; diff --git a/msg-socket/src/pub/mod.rs b/msg-socket/src/pub/mod.rs index d8c66fd6..db15d8ca 100644 --- a/msg-socket/src/pub/mod.rs +++ b/msg-socket/src/pub/mod.rs @@ -227,12 +227,14 @@ impl Default for SocketState { } } -#[cfg(test)] +#[cfg(all(test, not(feature = "turmoil")))] mod tests { use std::time::Duration; use futures::StreamExt; - use msg_transport::{quic::Quic, tcp::Tcp}; + #[cfg(feature = "quic")] + use msg_transport::quic::Quic; + use msg_transport::tcp::Tcp; use msg_wire::compression::GzipCompressor; use tracing::info; @@ -291,6 +293,7 @@ mod tests { assert_eq!("WORLD", msg.payload()); } + #[cfg(feature = "quic")] #[tokio::test] async fn pubsub_auth_quic() { let _ = tracing_subscriber::fmt::try_init(); @@ -408,6 +411,7 @@ mod tests { assert_eq!("WORLD", msg.payload()); } + #[cfg(feature = "quic")] #[tokio::test] async fn pubsub_durable_quic() { let _ = tracing_subscriber::fmt::try_init(); diff --git a/msg-socket/src/pub/socket.rs b/msg-socket/src/pub/socket.rs index 55c178b7..d60df469 100644 --- a/msg-socket/src/pub/socket.rs +++ b/msg-socket/src/pub/socket.rs @@ -3,15 +3,14 @@ use std::{net::SocketAddr, path::PathBuf, sync::Arc}; use arc_swap::Guard; use bytes::Bytes; use futures::stream::FuturesUnordered; -use tokio::{ - net::{ToSocketAddrs, lookup_host}, - sync::broadcast, - task::JoinSet, -}; +use tokio::{sync::broadcast, task::JoinSet}; use tracing::{debug, trace, warn}; use super::{PubError, PubMessage, PubOptions, SocketState, driver::PubDriver, stats::PubStats}; -use crate::{ConnectionHook, ConnectionHookErased}; +use crate::{ + ConnectionHook, ConnectionHookErased, + resolve::{ToSocketAddrs, lookup_host}, +}; use msg_transport::{Address, Transport}; use msg_wire::compression::Compressor; diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index 0b22a9ed..f552c0db 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -188,7 +188,7 @@ impl Request { } } -#[cfg(test)] +#[cfg(all(test, not(feature = "turmoil")))] mod tests { use std::{net::SocketAddr, time::Duration}; diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index 5ef4d346..ef6f1f3e 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -8,7 +8,6 @@ use std::{ use futures::{Stream, stream::FuturesUnordered}; use tokio::{ - net::{ToSocketAddrs, lookup_host}, sync::mpsc, task::{JoinHandle, JoinSet}, }; @@ -18,6 +17,7 @@ use tracing::{debug, warn}; use crate::{ ConnectionHook, ConnectionHookErased, DEFAULT_QUEUE_SIZE, RepOptions, Request, rep::{RepError, SocketState, driver::RepDriver}, + resolve::{ToSocketAddrs, lookup_host}, }; use msg_transport::{Address, Transport}; diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 5ec9c494..8065fcf5 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -3,10 +3,7 @@ use std::{marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc}; use arc_swap::Guard; use bytes::Bytes; use rustc_hash::FxHashMap; -use tokio::{ - net::{ToSocketAddrs, lookup_host}, - sync::{mpsc, mpsc::error::TrySendError, oneshot}, -}; +use tokio::sync::{mpsc, mpsc::error::TrySendError, oneshot}; use tokio_util::codec::Framed; use msg_common::span::WithSpan; @@ -23,6 +20,7 @@ use crate::{ driver::ReqDriver, stats::ReqStats, }, + resolve::{ToSocketAddrs, lookup_host}, stats::SocketStats, }; use std::sync::atomic::Ordering; diff --git a/msg-socket/src/resolve.rs b/msg-socket/src/resolve.rs new file mode 100644 index 00000000..e0abd929 --- /dev/null +++ b/msg-socket/src/resolve.rs @@ -0,0 +1,26 @@ +use std::{io, net::SocketAddr}; + +#[cfg(not(feature = "turmoil"))] +pub(crate) use tokio::net::ToSocketAddrs; +#[cfg(feature = "turmoil")] +pub(crate) use turmoil::ToSocketAddrs; + +#[cfg(not(feature = "turmoil"))] +pub(crate) async fn lookup_host( + addr: impl ToSocketAddrs, +) -> io::Result> { + tokio::net::lookup_host(addr).await +} + +#[cfg(feature = "turmoil")] +pub(crate) async fn lookup_host( + addr: impl ToSocketAddrs, +) -> io::Result> { + if !turmoil::in_simulation() { + return Err(io::Error::other( + "hostname resolution under the `turmoil` feature requires a running turmoil simulation", + )); + } + + turmoil::net::lookup_host(addr).await +} diff --git a/msg-socket/src/sub/mod.rs b/msg-socket/src/sub/mod.rs index eb82351e..a6ecb20b 100644 --- a/msg-socket/src/sub/mod.rs +++ b/msg-socket/src/sub/mod.rs @@ -170,7 +170,7 @@ impl Default for SocketState { } } -#[cfg(test)] +#[cfg(all(test, not(feature = "turmoil")))] mod tests { use std::net::SocketAddr; diff --git a/msg-socket/src/sub/socket.rs b/msg-socket/src/sub/socket.rs index cb87e93b..743b2d94 100644 --- a/msg-socket/src/sub/socket.rs +++ b/msg-socket/src/sub/socket.rs @@ -9,17 +9,14 @@ use std::{ use futures::Stream; use rustc_hash::FxHashMap; -use tokio::{ - net::{ToSocketAddrs, lookup_host}, - sync::mpsc, - task::JoinSet, -}; +use tokio::{sync::mpsc, task::JoinSet}; use msg_common::{IpAddrExt, JoinMap}; use msg_transport::{Address, Transport}; use crate::{ ConnectionHook, ConnectionHookErased, + resolve::{ToSocketAddrs, lookup_host}, sub::{ Command, DEFAULT_BUFFER_SIZE, PubMessage, SocketState, SubDriver, SubError, SubOptions, stats::SubStats, diff --git a/msg-socket/tests/it/main.rs b/msg-socket/tests/it/main.rs index 1a0fb94a..65423d0c 100644 --- a/msg-socket/tests/it/main.rs +++ b/msg-socket/tests/it/main.rs @@ -1,4 +1,8 @@ +#[cfg(not(feature = "turmoil"))] mod pubsub; +#[cfg(not(feature = "turmoil"))] mod reqrep; +#[cfg(feature = "turmoil")] +mod turmoil; fn main() {} diff --git a/msg-socket/tests/it/pubsub.rs b/msg-socket/tests/it/pubsub.rs index c24a9fb5..1cdbe02c 100644 --- a/msg-socket/tests/it/pubsub.rs +++ b/msg-socket/tests/it/pubsub.rs @@ -7,7 +7,9 @@ use tokio_stream::StreamExt; use tracing::info; use msg_socket::{PubSocket, SubSocket}; -use msg_transport::{Address, Transport, quic::Quic, tcp::Tcp}; +#[cfg(feature = "quic")] +use msg_transport::quic::Quic; +use msg_transport::{Address, Transport, tcp::Tcp}; const TOPIC: &str = "test"; @@ -20,9 +22,12 @@ async fn pubsub_channel() { assert!(result.is_ok()); - let result = pubsub_channel_transport(build_quic, "127.0.0.1:9879".parse().unwrap()).await; + #[cfg(feature = "quic")] + { + let result = pubsub_channel_transport(build_quic, "127.0.0.1:9879".parse().unwrap()).await; - assert!(result.is_ok()); + assert!(result.is_ok()); + } } async fn pubsub_channel_transport( @@ -70,9 +75,13 @@ async fn pubsub_fan_out() { assert!(result.is_ok()); - let result = pubsub_fan_out_transport(build_quic, 10, "127.0.0.1:9880".parse().unwrap()).await; + #[cfg(feature = "quic")] + { + let result = + pubsub_fan_out_transport(build_quic, 10, "127.0.0.1:9880".parse().unwrap()).await; - assert!(result.is_ok()); + assert!(result.is_ok()); + } } async fn pubsub_fan_out_transport< @@ -135,9 +144,13 @@ async fn pubsub_fan_in() { assert!(result.is_ok()); - let result = pubsub_fan_in_transport(build_quic, 20, "127.0.0.1:9881".parse().unwrap()).await; + #[cfg(feature = "quic")] + { + let result = + pubsub_fan_in_transport(build_quic, 20, "127.0.0.1:9881".parse().unwrap()).await; - assert!(result.is_ok()); + assert!(result.is_ok()); + } } async fn pubsub_fan_in_transport< @@ -216,6 +229,7 @@ fn build_tcp() -> Tcp { Tcp::default() } +#[cfg(feature = "quic")] fn build_quic() -> Quic { Quic::default() } diff --git a/msg-socket/tests/it/reqrep.rs b/msg-socket/tests/it/reqrep.rs index 697b7306..cf54885f 100644 --- a/msg-socket/tests/it/reqrep.rs +++ b/msg-socket/tests/it/reqrep.rs @@ -2,14 +2,15 @@ use std::time::Duration; use bytes::Bytes; use msg_socket::{DEFAULT_QUEUE_SIZE, RepSocket, ReqOptions, ReqSocket}; -use msg_transport::{ - tcp::Tcp, - tcp_tls::{self, TcpTls}, -}; +use msg_transport::tcp::Tcp; +#[cfg(feature = "tcp-tls")] +use msg_transport::tcp_tls::{self, TcpTls}; +#[cfg(feature = "tcp-tls")] use openssl::ssl::{SslAcceptor, SslMethod}; use tokio_stream::StreamExt; /// Helper functions. +#[cfg(feature = "tcp-tls")] mod helpers { use std::{path::PathBuf, str::FromStr as _}; @@ -83,6 +84,7 @@ async fn reqrep_works() { assert_eq!(hello, response, "expected {hello:?}, got {response:?}"); } +#[cfg(feature = "tcp-tls")] #[tokio::test] async fn reqrep_tls_works() { let _ = tracing_subscriber::fmt::try_init(); @@ -116,6 +118,7 @@ async fn reqrep_tls_works() { /// Test that changing the [`SslAcceptor`] at runtime works and results in not accepting the /// connection after modification. +#[cfg(feature = "tcp-tls")] #[tokio::test] async fn reqrep_tls_control_works() { let _ = tracing_subscriber::fmt::try_init(); @@ -176,6 +179,7 @@ async fn reqrep_tls_control_works() { tokio::time::timeout(Duration::from_secs(1), req.request(hello.clone())).await.unwrap_err(); } +#[cfg(feature = "tcp-tls")] #[tokio::test] async fn reqrep_mutual_tls_works() { let _ = tracing_subscriber::fmt::try_init(); diff --git a/msg-socket/tests/it/turmoil.rs b/msg-socket/tests/it/turmoil.rs new file mode 100644 index 00000000..f8aca918 --- /dev/null +++ b/msg-socket/tests/it/turmoil.rs @@ -0,0 +1,210 @@ +use std::net::{Ipv4Addr, SocketAddr}; +#[cfg(feature = "tcp-tls")] +use std::path::PathBuf; + +use bytes::Bytes; +use msg_socket::{PubSocket, RepSocket, ReqSocket, SubSocket}; +use msg_transport::tcp::Tcp; +#[cfg(feature = "tcp-tls")] +use msg_transport::tcp_tls::{self, TcpTls}; +#[cfg(feature = "tcp-tls")] +use openssl::ssl::{ + SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, SslFiletype, SslMethod, +}; +use tokio::time::{Duration, sleep}; +use tokio_stream::StreamExt; +use turmoil::{Builder, IpVersion, Result}; + +const SERVER_HOST: &str = "server"; +const TCP_PORT: u16 = 17_301; +#[cfg(feature = "tcp-tls")] +const TLS_PORT: u16 = 17_302; +const PUBSUB_PORT: u16 = 17_303; +#[cfg(feature = "tcp-tls")] +const PUBSUB_TLS_PORT: u16 = 17_304; +const TOPIC: &str = "test"; + +fn build_sim() -> turmoil::Sim<'static> { + let mut builder = Builder::new(); + builder.ip_version(IpVersion::V4); + builder.build() +} + +fn bind_addr(port: u16) -> SocketAddr { + SocketAddr::from((Ipv4Addr::UNSPECIFIED, port)) +} + +#[cfg(feature = "tcp-tls")] +fn certificate_dir() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../testdata/certificates") +} + +#[cfg(feature = "tcp-tls")] +fn default_acceptor_builder() -> SslAcceptorBuilder { + let base = certificate_dir(); + let certificate_path = base.join("server-cert.pem"); + let private_key_path = base.join("server-key.pem"); + let ca_certificate_path = base.join("ca-cert.pem"); + + let mut acceptor_builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + acceptor_builder.set_certificate_file(certificate_path, SslFiletype::PEM).unwrap(); + acceptor_builder.set_private_key_file(private_key_path, SslFiletype::PEM).unwrap(); + acceptor_builder.set_ca_file(ca_certificate_path).unwrap(); + acceptor_builder +} + +#[cfg(feature = "tcp-tls")] +fn default_connector_builder() -> SslConnectorBuilder { + let base = certificate_dir(); + let certificate_path = base.join("client-cert.pem"); + let private_key_path = base.join("client-key.pem"); + let ca_certificate_path = base.join("ca-cert.pem"); + + let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap(); + connector_builder.set_certificate_file(certificate_path, SslFiletype::PEM).unwrap(); + connector_builder.set_private_key_file(private_key_path, SslFiletype::PEM).unwrap(); + connector_builder.set_ca_file(ca_certificate_path).unwrap(); + connector_builder +} + +#[test] +fn reqrep_tcp_works_in_turmoil() -> Result { + let _ = tracing_subscriber::fmt::try_init(); + + let mut sim = build_sim(); + + sim.host(SERVER_HOST, || async { + let mut rep = RepSocket::new(Tcp::default()); + rep.bind(bind_addr(TCP_PORT)).await.unwrap(); + + let request = rep.next().await.unwrap(); + let msg = request.msg().clone(); + request.respond(msg).unwrap(); + + Ok(()) + }); + + sim.client("client", async { + let mut req = ReqSocket::new(Tcp::default()); + req.connect(format!("{SERVER_HOST}:{TCP_PORT}")).await.unwrap(); + + let hello = Bytes::from_static(b"hello over turmoil"); + let response = req.request(hello.clone()).await.unwrap(); + assert_eq!(hello, response); + + Ok(()) + }); + + sim.run() +} + +#[test] +fn pubsub_tcp_works_in_turmoil() -> Result { + let _ = tracing_subscriber::fmt::try_init(); + + let mut sim = build_sim(); + + sim.host(SERVER_HOST, || async { + let mut publisher = PubSocket::new(Tcp::default()); + publisher.bind(bind_addr(PUBSUB_PORT)).await.unwrap(); + + for _ in 0..20 { + sleep(Duration::from_millis(50)).await; + publisher.publish(TOPIC, Bytes::from_static(b"hello pubsub")).await.unwrap(); + } + + Ok(()) + }); + + sim.client("client", async { + let mut subscriber = SubSocket::new(Tcp::default()); + subscriber.connect(format!("{SERVER_HOST}:{PUBSUB_PORT}")).await.unwrap(); + subscriber.subscribe(TOPIC).await.unwrap(); + + let msg = subscriber.next().await.unwrap(); + assert_eq!(TOPIC, msg.topic()); + assert_eq!(b"hello pubsub".as_slice(), msg.payload()); + + Ok(()) + }); + + sim.run() +} + +#[cfg(feature = "tcp-tls")] +#[test] +fn pubsub_tcp_tls_works_in_turmoil() -> Result { + let _ = tracing_subscriber::fmt::try_init(); + + let mut sim = build_sim(); + + sim.host(SERVER_HOST, || async { + let server_config = tcp_tls::config::Server::new(default_acceptor_builder().build().into()); + let mut publisher = PubSocket::new(TcpTls::new_server(server_config)); + publisher.bind(bind_addr(PUBSUB_TLS_PORT)).await.unwrap(); + + for _ in 0..20 { + sleep(Duration::from_millis(50)).await; + publisher.publish(TOPIC, Bytes::from_static(b"hello pubsub tls")).await.unwrap(); + } + + Ok(()) + }); + + sim.client("client", async { + let domain = "localhost".to_string(); + let ssl_connector = default_connector_builder().build(); + let tcp_tls_client = TcpTls::new_client( + tcp_tls::config::Client::new(domain).with_ssl_connector(ssl_connector), + ); + let mut subscriber = SubSocket::new(tcp_tls_client); + subscriber.connect(format!("{SERVER_HOST}:{PUBSUB_TLS_PORT}")).await.unwrap(); + subscriber.subscribe(TOPIC).await.unwrap(); + + let msg = subscriber.next().await.unwrap(); + assert_eq!(TOPIC, msg.topic()); + assert_eq!(b"hello pubsub tls".as_slice(), msg.payload()); + + Ok(()) + }); + + sim.run() +} + +#[cfg(feature = "tcp-tls")] +#[test] +fn reqrep_tcp_tls_works_in_turmoil() -> Result { + let _ = tracing_subscriber::fmt::try_init(); + + let mut sim = build_sim(); + + sim.host(SERVER_HOST, || async { + let server_config = tcp_tls::config::Server::new(default_acceptor_builder().build().into()); + let mut rep = RepSocket::new(TcpTls::new_server(server_config)); + rep.bind(bind_addr(TLS_PORT)).await.unwrap(); + + let request = rep.next().await.unwrap(); + let msg = request.msg().clone(); + request.respond(msg).unwrap(); + + Ok(()) + }); + + sim.client("client", async { + let domain = "localhost".to_string(); + let ssl_connector = default_connector_builder().build(); + let tcp_tls_client = TcpTls::new_client( + tcp_tls::config::Client::new(domain).with_ssl_connector(ssl_connector), + ); + let mut req = ReqSocket::new(tcp_tls_client); + req.connect(format!("{SERVER_HOST}:{TLS_PORT}")).await.unwrap(); + + let hello = Bytes::from_static(b"hello over turmoil tls"); + let response = req.request(hello.clone()).await.unwrap(); + assert_eq!(hello, response); + + Ok(()) + }); + + sim.run() +} diff --git a/msg-transport/Cargo.toml b/msg-transport/Cargo.toml index 1776a5c3..2a8553bd 100644 --- a/msg-transport/Cargo.toml +++ b/msg-transport/Cargo.toml @@ -33,6 +33,9 @@ derive_more = { workspace = true, features = [ quinn = { workspace = true, optional = true } rcgen = { workspace = true, optional = true } +# Turmoil simulation +turmoil = { workspace = true, optional = true } + # TLS openssl = { workspace = true, optional = true } tokio-openssl = { workspace = true, optional = true } @@ -49,3 +52,7 @@ tcp-tls = [ "dep:thiserror", "dep:tokio-openssl", ] +# Routes the TCP and TCP-TLS transports through `turmoil::net` for deterministic +# network simulation. Composes with `tcp-tls`: under this feature, TLS traffic +# also runs through the simulator via `SslStream`. +turmoil = ["dep:turmoil"] diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index e5b6e4f8..4e7e5257 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -2,6 +2,21 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(not(test), warn(unused_crate_dependencies))] +// Suppress unused crate warning for libc - it's used conditionally in tcp/stats.rs +#[cfg(all(not(feature = "turmoil"), any(target_os = "macos", target_os = "linux")))] +extern crate libc; + +/// `Send + Sync` sibling of [`futures::future::BoxFuture`]. [`Transport`] requires +/// `Self: Sync`, so any future stored on a transport field has to be `Sync` too, +/// which [`BoxFuture`](futures::future::BoxFuture) (bounded only by `Send`) is not. +/// +/// Used by the TCP and TCP-TLS transports under the `turmoil` feature to hold an +/// in-progress `accept()` future, since `turmoil::net::TcpListener` only exposes +/// an `async fn accept(&self)` (no `poll_accept`). +#[cfg(feature = "turmoil")] +pub(crate) type SyncBoxFuture<'a, T> = + std::pin::Pin + Send + Sync + 'a>>; + use std::{ fmt::Debug, hash::Hash, @@ -27,6 +42,9 @@ pub mod tcp; #[cfg(feature = "tcp-tls")] pub mod tcp_tls; +/// Network type aliases for feature-gated turmoil integration. +pub mod net; + /// A trait for address types that can be used by any transport. pub trait Address: Clone + Debug + Send + Sync + Unpin + Hash + Eq + 'static {} diff --git a/msg-transport/src/net.rs b/msg-transport/src/net.rs new file mode 100644 index 00000000..48a6d095 --- /dev/null +++ b/msg-transport/src/net.rs @@ -0,0 +1,24 @@ +//! Network type aliases for feature-gated turmoil integration. +//! +//! When the `turmoil` feature is enabled, this module exports types from +//! `turmoil::net` that mirror tokio's networking types. This allows +//! applications to use simulated networking without any code changes. +//! +//! # Usage +//! Instead of using `tokio::net::TcpListener` or `tokio::net::TcpStream` directly, +//! use the types exported from this module: +//! +//! ```rust +//! use msg_transport::net::{TcpListener, TcpStream}; +//! ``` +//! +//! # Note on QUIC +//! The QUIC transport uses the [`quinn`] crate which requires real UDP sockets. +//! When using turmoil, only TCP transport participates in the simulation. +//! QUIC connections will use real networking even with the turmoil feature enabled. + +#[cfg(feature = "turmoil")] +pub use turmoil::net::{TcpListener, TcpStream}; + +#[cfg(not(feature = "turmoil"))] +pub use tokio::net::{TcpListener, TcpStream}; diff --git a/msg-transport/src/tcp/mod.rs b/msg-transport/src/tcp/mod.rs index 2b5b61fc..443714dc 100644 --- a/msg-transport/src/tcp/mod.rs +++ b/msg-transport/src/tcp/mod.rs @@ -1,15 +1,22 @@ use futures::future::BoxFuture; +#[cfg(feature = "turmoil")] +use std::sync::Arc; use std::{ io, net::SocketAddr, + pin::Pin, task::{Context, Poll}, }; -use tokio::net::{TcpListener, TcpStream}; use tracing::debug; use msg_common::async_error; -use crate::{Acceptor, PeerAddress, Transport, TransportExt}; +#[cfg(feature = "turmoil")] +use crate::SyncBoxFuture; +use crate::{ + Acceptor, PeerAddress, Transport, TransportExt, + net::{TcpListener, TcpStream}, +}; mod stats; pub use stats::TcpStats; @@ -17,16 +24,55 @@ pub use stats::TcpStats; #[derive(Debug, Default)] pub struct Config; -#[derive(Debug, Default)] +/// TCP transport implementation. +/// +/// When the `turmoil` feature is enabled, this transport uses turmoil's simulated +/// networking types instead of real TCP sockets. This allows for deterministic testing +/// of distributed systems. +#[derive(Default)] pub struct Tcp { #[allow(unused)] config: Config, - listener: Option, + /// The bound listener. + /// + /// Under turmoil the listener is wrapped in an [`Arc`] so that the in-progress + /// accept future created in [`Tcp::poll_accept`] can hold a strong reference for + /// its `'static` lifetime. Both references are owned by this struct, so dropping + /// `Tcp` drops the listener synchronously, which in turmoil unbinds the port + /// from the simulated host immediately. + #[cfg(not(feature = "turmoil"))] + listener: Option, + #[cfg(feature = "turmoil")] + listener: Option>, + /// For turmoil: the in-progress accept future, if any. + /// + /// `turmoil::net::TcpListener` only exposes an `async fn accept(&self)`, so we + /// drive it by holding a pinned future here and polling it from `poll_accept`. + /// Back-pressure is therefore provided by turmoil's listener queue directly, + /// the same way `tokio::net::TcpListener::poll_accept` works in the default + /// build. There's no intermediate mpsc buffer that can fill up and stall the + /// listener. + #[cfg(feature = "turmoil")] + accept_fut: Option>>, +} + +impl std::fmt::Debug for Tcp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tcp") + .field("config", &self.config) + .field("listener", &self.listener.as_ref().map(|_| "")) + .finish() + } } impl Tcp { pub fn new(config: Config) -> Self { - Self { config, listener: None } + Self { + config, + listener: None, + #[cfg(feature = "turmoil")] + accept_fut: None, + } } } @@ -53,9 +99,27 @@ impl Transport for Tcp { } async fn bind(&mut self, addr: SocketAddr) -> Result<(), Self::Error> { + // Bind first, then commit. A failed bind must leave the transport in its + // previous state, matching the non-turmoil path where the old listener is + // never disturbed unless a replacement is ready to take over. let listener = TcpListener::bind(addr).await?; - self.listener = Some(listener); + #[cfg(feature = "turmoil")] + { + // Drop the in-progress accept future before replacing the listener so + // that the old listener's last `Arc` is released in order, triggering + // its synchronous `Drop` which unbinds the port from turmoil's host. + // Because we always bind to a fresh address (the simulated host cannot + // hold two listeners on the same port), the new listener is already + // installed on its own port and will not conflict with the release. + self.accept_fut = None; + self.listener = Some(Arc::new(listener)); + } + + #[cfg(not(feature = "turmoil"))] + { + self.listener = Some(listener); + } Ok(()) } @@ -63,30 +127,60 @@ impl Transport for Tcp { fn connect(&mut self, addr: SocketAddr) -> Self::Connect { Box::pin(async move { let stream = TcpStream::connect(addr).await?; + #[cfg(not(feature = "turmoil"))] stream.set_nodelay(true)?; Ok(stream) }) } - fn poll_accept(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - let Some(ref listener) = this.listener else { - return Poll::Ready(async_error(io::ErrorKind::NotConnected.into())); - }; - - match listener.poll_accept(cx) { - Poll::Ready(Ok((io, addr))) => { - debug!(%addr, "accepted connection"); + #[cfg(not(feature = "turmoil"))] + { + let Some(ref listener) = this.listener else { + return Poll::Ready(async_error(io::ErrorKind::NotConnected.into())); + }; + + match listener.poll_accept(cx) { + Poll::Ready(Ok((io, addr))) => { + debug!(%addr, "accepted connection"); + + Poll::Ready(Box::pin(async move { + io.set_nodelay(true)?; + Ok(io) + })) + } + Poll::Ready(Err(e)) => Poll::Ready(async_error(e)), + Poll::Pending => Poll::Pending, + } + } - Poll::Ready(Box::pin(async move { - io.set_nodelay(true)?; - Ok(io) - })) + #[cfg(feature = "turmoil")] + { + let Some(listener) = this.listener.as_ref().cloned() else { + return Poll::Ready(async_error(io::ErrorKind::NotConnected.into())); + }; + + // Lazily build a `'static` accept future that owns its own strong + // reference to the listener, and keep it alive across pending polls. + let fut = this + .accept_fut + .get_or_insert_with(|| Box::pin(async move { listener.accept().await })); + + match fut.as_mut().poll(cx) { + Poll::Ready(Ok((io, addr))) => { + this.accept_fut = None; + debug!(%addr, "accepted connection"); + Poll::Ready(Box::pin(async move { Ok(io) })) + } + Poll::Ready(Err(e)) => { + this.accept_fut = None; + Poll::Ready(async_error(e)) + } + Poll::Pending => Poll::Pending, } - Poll::Ready(Err(e)) => Poll::Ready(async_error(e)), - Poll::Pending => Poll::Pending, } } } diff --git a/msg-transport/src/tcp/stats.rs b/msg-transport/src/tcp/stats.rs index 28dda56c..9b1af8c4 100644 --- a/msg-transport/src/tcp/stats.rs +++ b/msg-transport/src/tcp/stats.rs @@ -1,7 +1,9 @@ -use std::{os::fd::AsRawFd, time::Duration}; - -use tokio::net::TcpStream; +use std::time::Duration; +/// TCP connection statistics. +/// +/// When using simulated networking (e.g., with the `turmoil` feature), these stats +/// may be unavailable or return default values since there's no real OS-level TCP socket. #[derive(Debug, Default)] pub struct TcpStats { /// The congestion window in bytes. @@ -32,114 +34,124 @@ pub struct TcpStats { pub retransmission_timeout: Duration, } -#[cfg(target_os = "macos")] -impl TryFrom<&TcpStream> for TcpStats { - type Error = std::io::Error; - - /// Gathers stats from the given TCP socket file descriptor, sourced from the OS with - /// [`libc::getsockopt`]. - fn try_from(stream: &TcpStream) -> Result { - let info = getsockopt::(stream, libc::TCP_CONNECTION_INFO)?; +/// Real TCP stats implementation using OS socket options. +/// This is always compiled for tokio's TcpStream (used by tcp-tls and normal tcp). +#[cfg(any(not(feature = "turmoil"), feature = "tcp-tls"))] +mod os_stats { + use super::*; + use std::os::fd::AsRawFd; + + /// Helper function to get a socket option from a TCP stream. + fn getsockopt(stream: &impl AsRawFd, option: libc::c_int) -> std::io::Result { + let mut info = unsafe { std::mem::zeroed::() }; + let mut len = std::mem::size_of::() as libc::socklen_t; + let dst = &mut info as *mut _ as *mut _; + + let result = unsafe { + libc::getsockopt(stream.as_raw_fd(), libc::IPPROTO_TCP, option, dst, &mut len) + }; + + if result != 0 { + return Err(std::io::Error::last_os_error()); + } - Ok(info.into()) + Ok(info) } -} -#[cfg(target_os = "macos")] -impl From for TcpStats { - /// Converts a [`libc::tcp_connection_info`] into [`TcpStats`]. - fn from(info: libc::tcp_connection_info) -> Self { - // Window sizes - let congestion_window = info.tcpi_snd_cwnd; - let receive_window = info.tcpi_rcv_wnd; - let send_window = info.tcpi_snd_wnd; - - // RTT - let last_rtt = Duration::from_millis(info.tcpi_rttcur as u64); - let smoothed_rtt = Duration::from_millis(info.tcpi_srtt as u64); - let rtt_variance = Duration::from_millis(info.tcpi_rttvar as u64); - - // Volumes - let tx_bytes = info.tcpi_txbytes; - let rx_bytes = info.tcpi_rxbytes; - - // Retransmissions - let retransmitted_bytes = info.tcpi_txretransmitbytes; - let retransmitted_packets = info.tcpi_rxretransmitpackets; - let retransmission_timeout = Duration::from_millis(info.tcpi_rto as u64); - - Self { - congestion_window, - receive_window, - send_window, - last_rtt, - smoothed_rtt, - rtt_variance, - tx_bytes, - rx_bytes, - retransmitted_bytes, - retransmitted_packets, - retransmission_timeout, + /// Implement stats for tokio's TcpStream. + #[cfg(target_os = "macos")] + impl TryFrom<&tokio::net::TcpStream> for TcpStats { + type Error = std::io::Error; + + fn try_from(stream: &tokio::net::TcpStream) -> Result { + let info = getsockopt::(stream, libc::TCP_CONNECTION_INFO)?; + Ok(info.into()) } } -} -#[cfg(target_os = "linux")] -impl TryFrom<&TcpStream> for TcpStats { - type Error = std::io::Error; + #[cfg(target_os = "macos")] + impl From for TcpStats { + fn from(info: libc::tcp_connection_info) -> Self { + let congestion_window = info.tcpi_snd_cwnd; + let receive_window = info.tcpi_rcv_wnd; + let send_window = info.tcpi_snd_wnd; + let last_rtt = Duration::from_millis(info.tcpi_rttcur as u64); + let smoothed_rtt = Duration::from_millis(info.tcpi_srtt as u64); + let rtt_variance = Duration::from_millis(info.tcpi_rttvar as u64); + let tx_bytes = info.tcpi_txbytes; + let rx_bytes = info.tcpi_rxbytes; + let retransmitted_bytes = info.tcpi_txretransmitbytes; + let retransmitted_packets = info.tcpi_rxretransmitpackets; + let retransmission_timeout = Duration::from_millis(info.tcpi_rto as u64); + + Self { + congestion_window, + receive_window, + send_window, + last_rtt, + smoothed_rtt, + rtt_variance, + tx_bytes, + rx_bytes, + retransmitted_bytes, + retransmitted_packets, + retransmission_timeout, + } + } + } - /// Gathers stats from the given TCP socket file descriptor, sourced from the OS with - /// [`libc::getsockopt`]. - fn try_from(stream: &TcpStream) -> Result { - let info = getsockopt::(stream, libc::TCP_INFO)?; + #[cfg(target_os = "linux")] + impl TryFrom<&tokio::net::TcpStream> for TcpStats { + type Error = std::io::Error; - Ok(info.into()) + fn try_from(stream: &tokio::net::TcpStream) -> Result { + let info = getsockopt::(stream, libc::TCP_INFO)?; + Ok(info.into()) + } } -} -#[cfg(target_os = "linux")] -impl From for TcpStats { - /// Converts a [`libc::tcp_info`] into [`TcpStats`]. - fn from(info: libc::tcp_info) -> Self { - // On Linux, tcpi_snd_cwnd is in segments; convert to bytes using snd_mss. - let congestion_window = info.tcpi_snd_cwnd.saturating_mul(info.tcpi_snd_mss); - // Local advertised receive window (bytes). - let receive_window = info.tcpi_rcv_space; - - // RTT fields are reported in microseconds. - let smoothed_rtt = Duration::from_micros(info.tcpi_rtt as u64); - let rtt_variance = Duration::from_micros(info.tcpi_rttvar as u64); - - // Retransmissions - let retransmitted_packets = info.tcpi_total_retrans as u64; - let retransmitted_bytes = retransmitted_packets.saturating_mul(info.tcpi_snd_mss as u64); - // RTO is in microseconds. - let retransmission_timeout = Duration::from_micros(info.tcpi_rto as u64); - - Self { - congestion_window, - receive_window, - smoothed_rtt, - rtt_variance, - retransmitted_bytes, - retransmitted_packets, - retransmission_timeout, + #[cfg(target_os = "linux")] + impl From for TcpStats { + fn from(info: libc::tcp_info) -> Self { + let congestion_window = info.tcpi_snd_cwnd.saturating_mul(info.tcpi_snd_mss); + let receive_window = info.tcpi_rcv_space; + let smoothed_rtt = Duration::from_micros(info.tcpi_rtt as u64); + let rtt_variance = Duration::from_micros(info.tcpi_rttvar as u64); + let retransmitted_packets = info.tcpi_total_retrans as u64; + let retransmitted_bytes = + retransmitted_packets.saturating_mul(info.tcpi_snd_mss as u64); + let retransmission_timeout = Duration::from_micros(info.tcpi_rto as u64); + + Self { + congestion_window, + receive_window, + smoothed_rtt, + rtt_variance, + retransmitted_bytes, + retransmitted_packets, + retransmission_timeout, + } } } } -/// Helper function to get a socket option from a TCP stream. -fn getsockopt(stream: &TcpStream, option: libc::c_int) -> std::io::Result { - let mut info = unsafe { std::mem::zeroed::() }; - let mut len = std::mem::size_of::() as libc::socklen_t; - let dst = &mut info as *mut _ as *mut _; - - let result = - unsafe { libc::getsockopt(stream.as_raw_fd(), libc::IPPROTO_TCP, option, dst, &mut len) }; - - if result != 0 { - return Err(std::io::Error::last_os_error()); +// Turmoil simulated stats: OS-level TCP counters don't exist for simulated sockets, +// so we return the default (zeroed) stats. Surfacing an error here would cause +// `MeteredIo::maybe_refresh` to log at error level on every refresh of every active +// connection, which is expected behavior in simulation rather than a failure. +// +// Without the turmoil feature, `crate::net::TcpStream` is `tokio::net::TcpStream`, +// and the impl in `os_stats` already satisfies `Transport::Stats`. +#[cfg(feature = "turmoil")] +mod turmoil_stats { + use super::*; + use crate::net::TcpStream; + + // `From` yields `TryFrom` through the standard library's blanket impl (with + // `Error = Infallible`), which is enough to satisfy `Transport::Stats`. + impl From<&TcpStream> for TcpStats { + fn from(_stream: &TcpStream) -> Self { + TcpStats::default() + } } - - Ok(info) } diff --git a/msg-transport/src/tcp_tls/mod.rs b/msg-transport/src/tcp_tls/mod.rs index ded495c2..e42203b1 100644 --- a/msg-transport/src/tcp_tls/mod.rs +++ b/msg-transport/src/tcp_tls/mod.rs @@ -10,13 +10,18 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tokio::net::{TcpListener, TcpStream}; use tokio_openssl::SslStream; use tracing::debug; use msg_common::async_error; -use crate::{Acceptor, PeerAddress, Transport, TransportExt, tcp::TcpStats}; +#[cfg(feature = "turmoil")] +use crate::SyncBoxFuture; +use crate::{ + Acceptor, PeerAddress, Transport, TransportExt, + net::{TcpListener, TcpStream}, + tcp::TcpStats, +}; pub mod config; @@ -61,14 +66,36 @@ impl Client { /// A TCP-TLS server. pub struct Server { /// The underlying TCP listener. + /// + /// Under turmoil the listener is wrapped in an [`Arc`] so that the in-progress + /// accept future stored in [`Server::accept_fut`] can hold a strong reference + /// for its `'static` lifetime. Both references are owned by this struct, so + /// dropping the server drops the listener synchronously, which under turmoil + /// unbinds the port from the simulated host immediately. + #[cfg(not(feature = "turmoil"))] listener: Option, + #[cfg(feature = "turmoil")] + listener: Option>, + /// For turmoil: the in-progress raw-TCP accept future, if any. + /// + /// `turmoil::net::TcpListener` only exposes `async fn accept(&self)`, so we + /// drive it by holding a pinned future here and polling it from `poll_accept`. + /// This mirrors the tokio `poll_accept` back-pressure behavior exactly; the + /// TLS handshake itself runs in the transport accept future returned afterwards. + #[cfg(feature = "turmoil")] + accept_fut: Option>>, /// The OpenSSL acceptor for TLS handshake requests. acceptor: ArcSwap, } impl Server { pub fn new(acceptor: Arc) -> Self { - Self { listener: None, acceptor: ArcSwap::new(acceptor) } + Self { + listener: None, + #[cfg(feature = "turmoil")] + accept_fut: None, + acceptor: ArcSwap::new(acceptor), + } } pub fn swap_acceptor(&mut self, acceptor: Arc) { @@ -79,7 +106,7 @@ impl Server { impl fmt::Debug for Server { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Server") - .field("listener", &self.listener) + .field("listener", &self.listener.as_ref().map(|_| "")) .field("acceptor", &"SslAcceptor") .finish() } @@ -185,6 +212,18 @@ impl PeerAddress for TcpTlsStream { } } +// Under turmoil there are no OS-level TCP counters to gather, so defer to the +// infallible `From<&TcpStream>` impl that returns default stats. Keeping the +// impls mutually exclusive avoids overlapping the blanket `TryFrom` that +// `From` yields with a second concrete `TryFrom` impl. +#[cfg(feature = "turmoil")] +impl From<&TcpTlsStream> for TcpStats { + fn from(stream: &TcpTlsStream) -> Self { + TcpStats::from(stream.get_ref()) + } +} + +#[cfg(not(feature = "turmoil"))] impl TryFrom<&TcpTlsStream> for TcpStats { type Error = std::io::Error; @@ -222,8 +261,22 @@ impl Transport for TcpTls { return Err(InvalidOperation::BindAsClient.into()); }; + // Bind first, then commit, so that a failed bind leaves any previously + // bound listener intact. let listener = TcpListener::bind(addr).await?; - server.listener = Some(listener); + + #[cfg(feature = "turmoil")] + { + // Drop the in-progress accept future before replacing the listener so + // the old listener's last `Arc` is released in order, triggering its + // synchronous `Drop` which unbinds the port from turmoil's host. + server.accept_fut = None; + server.listener = Some(Arc::new(listener)); + } + #[cfg(not(feature = "turmoil"))] + { + server.listener = Some(listener); + } Ok(()) } @@ -245,8 +298,11 @@ impl Transport for TcpTls { }; let tls_session_state = connector.configure()?.into_ssl(&config.domain)?; - // 2. Establish the TCP connection + // 2. Establish the TCP connection. `set_nodelay` is skipped under turmoil since its + // `TcpStream` stub returns Ok without effect, and tokio's real TCP call is + // unnecessary in the simulator (Nagle doesn't apply). let stream = TcpStream::connect(addr).await?; + #[cfg(not(feature = "turmoil"))] stream.set_nodelay(true)?; // 3. Perform the TLS handshake @@ -263,27 +319,64 @@ impl Transport for TcpTls { return Poll::Ready(async_error(InvalidOperation::AcceptAsClient.into())); }; - let Some(ref listener) = server.listener else { - return Poll::Ready(async_error(Error::IoKind(io::ErrorKind::NotConnected))); + let tls_acceptor = server.acceptor.load_full(); + + // Shared TLS-handshake tail: wrap a raw accepted `TcpStream` in an + // `SslStream` and run `accept()`. Returning an `async move` keeps the + // `Self::Accept` future identical between the tokio and turmoil paths. + let handshake = move |io: TcpStream| -> Self::Accept { + Box::pin(async move { + #[cfg(not(feature = "turmoil"))] + io.set_nodelay(true)?; + + let tls_session_state = Ssl::new(tls_acceptor.context())?; + let mut stream = SslStream::new(tls_session_state, io)?; + Pin::new(&mut stream).accept().await?; + + Ok(stream.into()) + }) }; - let tls_acceptor = server.acceptor.load(); - match listener.poll_accept(cx) { - Poll::Ready(Ok((io, addr))) => { - debug!(%addr, "accepted connection"); + #[cfg(not(feature = "turmoil"))] + { + let Some(ref listener) = server.listener else { + return Poll::Ready(async_error(Error::IoKind(io::ErrorKind::NotConnected))); + }; - Poll::Ready(Box::pin(async move { - io.set_nodelay(true)?; + match listener.poll_accept(cx) { + Poll::Ready(Ok((io, addr))) => { + debug!(%addr, "accepted connection"); + Poll::Ready(handshake(io)) + } + Poll::Ready(Err(e)) => Poll::Ready(async_error(e.into())), + Poll::Pending => Poll::Pending, + } + } - let tls_session_state = Ssl::new(tls_acceptor.context())?; - let mut stream = SslStream::new(tls_session_state, io)?; - Pin::new(&mut stream).accept().await?; + #[cfg(feature = "turmoil")] + { + let Some(listener) = server.listener.as_ref().cloned() else { + return Poll::Ready(async_error(Error::IoKind(io::ErrorKind::NotConnected))); + }; - Ok(stream.into()) - })) + // Lazily build a `'static` accept future that owns its own strong + // reference to the listener, and keep it alive across pending polls. + let fut = server + .accept_fut + .get_or_insert_with(|| Box::pin(async move { listener.accept().await })); + + match fut.as_mut().poll(cx) { + Poll::Ready(Ok((io, addr))) => { + server.accept_fut = None; + debug!(%addr, "accepted connection"); + Poll::Ready(handshake(io)) + } + Poll::Ready(Err(e)) => { + server.accept_fut = None; + Poll::Ready(async_error(e.into())) + } + Poll::Pending => Poll::Pending, } - Poll::Ready(Err(e)) => Poll::Ready(async_error(e.into())), - Poll::Pending => Poll::Pending, } } @@ -309,7 +402,9 @@ impl TransportExt for TcpTls { } } -#[cfg(test)] +// Reaches out to the public internet, which is not routable under turmoil's +// simulated topology. Only build this test in the real-network configuration. +#[cfg(all(test, not(feature = "turmoil")))] mod tests { use tokio::net::lookup_host;