From 5e0971edf2edc36884f3800ea9f28fef17a31837 Mon Sep 17 00:00:00 2001 From: WhySoBad <49595640+WhySoBad@users.noreply.github.com> Date: Sun, 22 Mar 2026 23:33:28 +0100 Subject: [PATCH] Add `shutdown` shim for network sockets --- src/shims/io_error.rs | 1 + src/shims/unix/foreign_items.rs | 10 +++ src/shims/unix/socket.rs | 53 +++++++++++++- tests/pass-dep/libc/libc-socket.rs | 107 +++++++++++++++++++++++++++++ tests/pass/shims/socket.rs | 43 +++++++++++- 5 files changed, 211 insertions(+), 3 deletions(-) diff --git a/src/shims/io_error.rs b/src/shims/io_error.rs index bb761980e6..cc9459318d 100644 --- a/src/shims/io_error.rs +++ b/src/shims/io_error.rs @@ -251,6 +251,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { fn io_error_to_errnum(&self, err: std::io::Error) -> InterpResult<'tcx, Scalar> { let this = self.eval_context_ref(); let target = &this.tcx.sess.target; + if target.families.iter().any(|f| f == "unix") { for &(name, kind) in UNIX_IO_ERROR_TABLE { if err.kind() == kind { diff --git a/src/shims/unix/foreign_items.rs b/src/shims/unix/foreign_items.rs index c384b24abb..8d1cb7e4d8 100644 --- a/src/shims/unix/foreign_items.rs +++ b/src/shims/unix/foreign_items.rs @@ -676,6 +676,16 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { )?; this.getpeername(socket, address, address_len, dest)?; } + "shutdown" => { + let [sockfd, how] = this.check_shim_sig( + shim_sig!(extern "C" fn(i32, i32) -> i32), + link_name, + abi, + args, + )?; + let result = this.shutdown(sockfd, how)?; + this.write_scalar(result, dest)?; + } // Time "gettimeofday" => { diff --git a/src/shims/unix/socket.rs b/src/shims/unix/socket.rs index c553cd1f70..1465780a14 100644 --- a/src/shims/unix/socket.rs +++ b/src/shims/unix/socket.rs @@ -1,6 +1,6 @@ use std::cell::{Cell, RefCell}; use std::io::Read; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; use std::{io, iter}; @@ -1053,6 +1053,48 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { ), ) } + + fn shutdown(&mut self, socket: &OpTy<'tcx>, how: &OpTy<'tcx>) -> InterpResult<'tcx, Scalar> { + let this = self.eval_context_mut(); + + let socket = this.read_scalar(socket)?.to_i32()?; + let how = this.read_scalar(how)?.to_i32()?; + + // Get the file handle + let Some(fd) = this.machine.fds.get(socket) else { + return this.set_last_error_and_return_i32(LibcError("EBADF")); + }; + + let Some(socket) = fd.downcast::() else { + // Man page specifies to return ENOTSOCK if `fd` is not a socket. + return this.set_last_error_and_return_i32(LibcError("ENOTSOCK")); + }; + + assert!(this.machine.communicate(), "cannot have `Socket` with isolation enabled!"); + + let state = socket.state.borrow(); + + let (SocketState::Connecting(stream) | SocketState::Connected(stream)) = &*state else { + return this.set_last_error_and_return_i32(LibcError("ENOTCONN")); + }; + + let shut_rd = this.eval_libc_i32("SHUT_RD"); + let shut_wr = this.eval_libc_i32("SHUT_WR"); + let shut_rdwr = this.eval_libc_i32("SHUT_RDWR"); + + let how = match () { + _ if how == shut_rd => Shutdown::Read, + _ if how == shut_wr => Shutdown::Write, + _ if how == shut_rdwr => Shutdown::Both, + // An invalid value was passed to `how`. + _ => return this.set_last_error_and_return_i32(LibcError("EINVAL")), + }; + + match stream.shutdown(how) { + Ok(_) => interp_ok(Scalar::from_i32(0)), + Err(e) => this.set_last_error_and_return_i32(e), + } + } } impl<'tcx> EvalContextPrivExt<'tcx> for crate::MiriInterpCx<'tcx> {} @@ -1487,6 +1529,15 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // would be returned on UNIX-like systems. We thus remap this error to an EWOULDBLOCK. interp_ok(Err(IoError::HostError(io::ErrorKind::WouldBlock.into()))) } + Err(IoError::HostError(e)) + if cfg!(windows) + && matches!(e.raw_os_error(), Some(/* WSAESHUTDOWN error code */ 10058)) => + { + // FIXME: This is a temporary workaround for handling WSAESHUTDOWN errors + // on Windows. A discussion on how those errors should be handled can be found here: + // + interp_ok(Err(IoError::HostError(io::ErrorKind::BrokenPipe.into()))) + } result => interp_ok(result), } } diff --git a/tests/pass-dep/libc/libc-socket.rs b/tests/pass-dep/libc/libc-socket.rs index c87677a756..1430260588 100644 --- a/tests/pass-dep/libc/libc-socket.rs +++ b/tests/pass-dep/libc/libc-socket.rs @@ -44,6 +44,10 @@ fn main() { test_getpeername_ipv4(); test_getpeername_ipv6(); + + test_shutdown(); + test_shutdown_readable_after_write_close(); + test_shutdown_writable_after_read_close(); } /// Test creating a socket and then closing it afterwards. @@ -488,3 +492,106 @@ fn test_getpeername_ipv6() { server_thread.join().unwrap(); } + +/// Test shutting down TCP streams. +fn test_shutdown() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + // Spawn the server thread. + let server_thread = thread::spawn(move || net::accept_ipv4(server_sockfd).unwrap()); + + let mut byte = [0u8]; + + net::connect_ipv4(client_sockfd, addr).unwrap(); + let client_dup_sockfd = unsafe { libc::dup(client_sockfd) }; + + // Closing should prevent reads/writes. + unsafe { + libc::shutdown(client_sockfd, libc::SHUT_RDWR); + let err = errno_result(libc::write(client_sockfd, [0u8].as_ptr().cast(), 1)).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::BrokenPipe); + let bytes_read = + errno_result(libc::read(client_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap(); + assert_eq!(bytes_read, 0); + } + + // TODO: Once epoll is available for TCP sockets, ensure that the rdhup and hup readiness + // are set. + + // Closing should affect previously duplicated handles. + unsafe { + let err = + errno_result(libc::write(client_dup_sockfd, [0u8].as_ptr().cast(), 1)).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::BrokenPipe); + let bytes_read = + errno_result(libc::read(client_dup_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap(); + assert_eq!(bytes_read, 0); + } + + // Closing should affect newly duplicated handles. + unsafe { + let client_dup2_sockfd = libc::dup(client_sockfd); + let err = + errno_result(libc::write(client_dup2_sockfd, [0u8].as_ptr().cast(), 1)).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::BrokenPipe); + let bytes_read = + errno_result(libc::read(client_dup2_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap(); + assert_eq!(bytes_read, 0); + } + + server_thread.join().unwrap(); +} + +/// Test that a socket is still readable after the write end has +/// been closed. +fn test_shutdown_readable_after_write_close() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + // Spawn the server thread. + let server_thread = thread::spawn(move || { + let (peerfd, _) = net::accept_ipv4(server_sockfd).unwrap(); + // Write a single byte which should be read later on. + unsafe { errno_result(libc::write(peerfd, [1u8].as_ptr().cast(), 1)).unwrap() }; + }); + + net::connect_ipv4(client_sockfd, addr).unwrap(); + + unsafe { + // Close the write end. + libc::shutdown(client_sockfd, libc::SHUT_WR); + + // Ensure that we're still readable. + let mut byte = [0u8]; + errno_result(libc::read(client_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap(); + assert_eq!(&byte, &[1u8]); + } + + server_thread.join().unwrap(); +} + +/// Test that a socket is still writable after the read end has +/// been closed. +fn test_shutdown_writable_after_read_close() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + // Spawn the server thread. + let server_thread = thread::spawn(move || net::accept_ipv4(server_sockfd).unwrap()); + + net::connect_ipv4(client_sockfd, addr).unwrap(); + + unsafe { + // Close the read end. + libc::shutdown(client_sockfd, libc::SHUT_RD); + + // Ensure that we're still writable. + errno_result(libc::write(client_sockfd, [1u8].as_ptr().cast(), 1)).unwrap(); + } + + server_thread.join().unwrap(); +} diff --git a/tests/pass/shims/socket.rs b/tests/pass/shims/socket.rs index ec2cc76700..584ce8f60c 100644 --- a/tests/pass/shims/socket.rs +++ b/tests/pass/shims/socket.rs @@ -1,8 +1,8 @@ //@ignore-target: windows # No libc socket on Windows //@compile-flags: -Zmiri-disable-isolation -use std::io::{Read, Write}; -use std::net::{TcpListener, TcpStream}; +use std::io::{ErrorKind, Read, Write}; +use std::net::{Shutdown, TcpListener, TcpStream}; use std::thread; const TEST_BYTES: &[u8] = b"these are some test bytes!"; @@ -14,6 +14,7 @@ fn main() { test_read_write(); test_peek(); test_peer_addr(); + test_shutdown(); } fn test_create_ipv4_listener() { @@ -113,3 +114,41 @@ fn test_peer_addr() { handle.join().unwrap(); } + +/// Test shutting down TCP streams. +fn test_shutdown() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + // Get local address with randomized port to know where + // we need to connect to. + let address = listener.local_addr().unwrap(); + + // Start server thread. + let handle = thread::spawn(move || { + let (stream, _addr) = listener.accept().unwrap(); + // Return stream from thread such that it doesn't get dropped too early. + stream + }); + + let mut byte = [0u8]; + let mut stream = TcpStream::connect(address).unwrap(); + let mut stream_clone = stream.try_clone().unwrap(); + + // Closing should prevent reads/writes. + stream.shutdown(Shutdown::Write).unwrap(); + stream.write(&[0]).unwrap_err(); + stream.shutdown(Shutdown::Read).unwrap(); + assert_eq!(stream.read(&mut byte).unwrap(), 0); + + // Closing should affect previously cloned handles. + let err = stream_clone.write(&[0]).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::BrokenPipe); + assert_eq!(stream_clone.read(&mut byte).unwrap(), 0); + + // Closing should affect newly cloned handles. + let mut stream_other_clone = stream.try_clone().unwrap(); + let err = stream_other_clone.write(&[0]).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::BrokenPipe); + assert_eq!(stream_other_clone.read(&mut byte).unwrap(), 0); + + let _stream = handle.join().unwrap(); +}