Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/shims/io_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
fn io_error_to_errnum(&self, err: std::io::Error) -> InterpResult<'tcx, Scalar> {
let this = self.eval_context_ref();
let target = &this.tcx.sess.target;

if target.families.iter().any(|f| f == "unix") {
for &(name, kind) in UNIX_IO_ERROR_TABLE {
if err.kind() == kind {
Expand Down
10 changes: 10 additions & 0 deletions src/shims/unix/foreign_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,16 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
)?;
this.getpeername(socket, address, address_len, dest)?;
}
"shutdown" => {
let [sockfd, how] = this.check_shim_sig(
shim_sig!(extern "C" fn(i32, i32) -> i32),
link_name,
abi,
args,
)?;
let result = this.shutdown(sockfd, how)?;
this.write_scalar(result, dest)?;
}

// Time
"gettimeofday" => {
Expand Down
53 changes: 52 additions & 1 deletion src/shims/unix/socket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::cell::{Cell, RefCell};
use std::io::Read;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration;
use std::{io, iter};

Expand Down Expand Up @@ -1053,6 +1053,48 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
),
)
}

fn shutdown(&mut self, socket: &OpTy<'tcx>, how: &OpTy<'tcx>) -> InterpResult<'tcx, Scalar> {
let this = self.eval_context_mut();

let socket = this.read_scalar(socket)?.to_i32()?;
let how = this.read_scalar(how)?.to_i32()?;

// Get the file handle
let Some(fd) = this.machine.fds.get(socket) else {
return this.set_last_error_and_return_i32(LibcError("EBADF"));
};

let Some(socket) = fd.downcast::<Socket>() else {
// Man page specifies to return ENOTSOCK if `fd` is not a socket.
return this.set_last_error_and_return_i32(LibcError("ENOTSOCK"));
};

assert!(this.machine.communicate(), "cannot have `Socket` with isolation enabled!");

let state = socket.state.borrow();

let (SocketState::Connecting(stream) | SocketState::Connected(stream)) = &*state else {
return this.set_last_error_and_return_i32(LibcError("ENOTCONN"));
};

let shut_rd = this.eval_libc_i32("SHUT_RD");
let shut_wr = this.eval_libc_i32("SHUT_WR");
let shut_rdwr = this.eval_libc_i32("SHUT_RDWR");

let how = match () {
_ if how == shut_rd => Shutdown::Read,
_ if how == shut_wr => Shutdown::Write,
_ if how == shut_rdwr => Shutdown::Both,
// An invalid value was passed to `how`.
_ => return this.set_last_error_and_return_i32(LibcError("EINVAL")),
};

match stream.shutdown(how) {
Ok(_) => interp_ok(Scalar::from_i32(0)),
Err(e) => this.set_last_error_and_return_i32(e),
}
}
}

impl<'tcx> EvalContextPrivExt<'tcx> for crate::MiriInterpCx<'tcx> {}
Expand Down Expand Up @@ -1487,6 +1529,15 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// would be returned on UNIX-like systems. We thus remap this error to an EWOULDBLOCK.
interp_ok(Err(IoError::HostError(io::ErrorKind::WouldBlock.into())))
}
Err(IoError::HostError(e))
if cfg!(windows)
&& matches!(e.raw_os_error(), Some(/* WSAESHUTDOWN error code */ 10058)) =>
{
// FIXME: This is a temporary workaround for handling WSAESHUTDOWN errors
// on Windows. A discussion on how those errors should be handled can be found here:
// <https://rust-lang.zulipchat.com/#narrow/channel/219381-t-libs/topic/WSAESHUTDOWN.20error.20on.20Windows/near/591883531>
interp_ok(Err(IoError::HostError(io::ErrorKind::BrokenPipe.into())))
}
result => interp_ok(result),
}
}
Expand Down
107 changes: 107 additions & 0 deletions tests/pass-dep/libc/libc-socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ fn main() {

test_getpeername_ipv4();
test_getpeername_ipv6();

test_shutdown();
test_shutdown_readable_after_write_close();
test_shutdown_writable_after_read_close();
}

/// Test creating a socket and then closing it afterwards.
Expand Down Expand Up @@ -488,3 +492,106 @@ fn test_getpeername_ipv6() {

server_thread.join().unwrap();
}

/// Test shutting down TCP streams.
fn test_shutdown() {
let (server_sockfd, addr) = net::make_listener_ipv4().unwrap();
let client_sockfd =
unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() };

// Spawn the server thread.
let server_thread = thread::spawn(move || net::accept_ipv4(server_sockfd).unwrap());

let mut byte = [0u8];

net::connect_ipv4(client_sockfd, addr).unwrap();
let client_dup_sockfd = unsafe { libc::dup(client_sockfd) };

// Closing should prevent reads/writes.
unsafe {
libc::shutdown(client_sockfd, libc::SHUT_RDWR);
let err = errno_result(libc::write(client_sockfd, [0u8].as_ptr().cast(), 1)).unwrap_err();
assert_eq!(err.kind(), ErrorKind::BrokenPipe);
let bytes_read =
errno_result(libc::read(client_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap();
assert_eq!(bytes_read, 0);
}

// TODO: Once epoll is available for TCP sockets, ensure that the rdhup and hup readiness
// are set.

// Closing should affect previously duplicated handles.
unsafe {
let err =
errno_result(libc::write(client_dup_sockfd, [0u8].as_ptr().cast(), 1)).unwrap_err();
assert_eq!(err.kind(), ErrorKind::BrokenPipe);
let bytes_read =
errno_result(libc::read(client_dup_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap();
assert_eq!(bytes_read, 0);
}

// Closing should affect newly duplicated handles.
unsafe {
let client_dup2_sockfd = libc::dup(client_sockfd);
let err =
errno_result(libc::write(client_dup2_sockfd, [0u8].as_ptr().cast(), 1)).unwrap_err();
assert_eq!(err.kind(), ErrorKind::BrokenPipe);
let bytes_read =
errno_result(libc::read(client_dup2_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap();
assert_eq!(bytes_read, 0);
}

server_thread.join().unwrap();
}

/// Test that a socket is still readable after the write end has
/// been closed.
fn test_shutdown_readable_after_write_close() {
let (server_sockfd, addr) = net::make_listener_ipv4().unwrap();
let client_sockfd =
unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() };

// Spawn the server thread.
let server_thread = thread::spawn(move || {
let (peerfd, _) = net::accept_ipv4(server_sockfd).unwrap();
// Write a single byte which should be read later on.
unsafe { errno_result(libc::write(peerfd, [1u8].as_ptr().cast(), 1)).unwrap() };
});

net::connect_ipv4(client_sockfd, addr).unwrap();

unsafe {
// Close the write end.
libc::shutdown(client_sockfd, libc::SHUT_WR);

// Ensure that we're still readable.
let mut byte = [0u8];
errno_result(libc::read(client_sockfd, byte.as_mut_ptr().cast(), 1)).unwrap();
assert_eq!(&byte, &[1u8]);
}

server_thread.join().unwrap();
}

/// Test that a socket is still writable after the read end has
/// been closed.
fn test_shutdown_writable_after_read_close() {
let (server_sockfd, addr) = net::make_listener_ipv4().unwrap();
let client_sockfd =
unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() };

// Spawn the server thread.
let server_thread = thread::spawn(move || net::accept_ipv4(server_sockfd).unwrap());

net::connect_ipv4(client_sockfd, addr).unwrap();

unsafe {
// Close the read end.
libc::shutdown(client_sockfd, libc::SHUT_RD);

// Ensure that we're still writable.
errno_result(libc::write(client_sockfd, [1u8].as_ptr().cast(), 1)).unwrap();
}

server_thread.join().unwrap();
}
43 changes: 41 additions & 2 deletions tests/pass/shims/socket.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//@ignore-target: windows # No libc socket on Windows
//@compile-flags: -Zmiri-disable-isolation

use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream};
use std::thread;

const TEST_BYTES: &[u8] = b"these are some test bytes!";
Expand All @@ -14,6 +14,7 @@ fn main() {
test_read_write();
test_peek();
test_peer_addr();
test_shutdown();
}

fn test_create_ipv4_listener() {
Expand Down Expand Up @@ -113,3 +114,41 @@ fn test_peer_addr() {

handle.join().unwrap();
}

/// Test shutting down TCP streams.
fn test_shutdown() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
// Get local address with randomized port to know where
// we need to connect to.
let address = listener.local_addr().unwrap();

// Start server thread.
let handle = thread::spawn(move || {
let (stream, _addr) = listener.accept().unwrap();
// Return stream from thread such that it doesn't get dropped too early.
stream
});

let mut byte = [0u8];
let mut stream = TcpStream::connect(address).unwrap();
let mut stream_clone = stream.try_clone().unwrap();

// Closing should prevent reads/writes.
stream.shutdown(Shutdown::Write).unwrap();
stream.write(&[0]).unwrap_err();
stream.shutdown(Shutdown::Read).unwrap();
assert_eq!(stream.read(&mut byte).unwrap(), 0);

// Closing should affect previously cloned handles.
let err = stream_clone.write(&[0]).unwrap_err();
assert_eq!(err.kind(), ErrorKind::BrokenPipe);
assert_eq!(stream_clone.read(&mut byte).unwrap(), 0);

// Closing should affect newly cloned handles.
let mut stream_other_clone = stream.try_clone().unwrap();
let err = stream_other_clone.write(&[0]).unwrap_err();
assert_eq!(err.kind(), ErrorKind::BrokenPipe);
assert_eq!(stream_other_clone.read(&mut byte).unwrap(), 0);

let _stream = handle.join().unwrap();
}