Skip to content

Commit

Permalink
sys::socket listen's Backlog wrapper type addition.
Browse files Browse the repository at this point in the history
changing the sys::socket::listen backlog type from `usize` to
a `i32` wrapper, offering known sane values, from -1, SOMAXCONN to
 511.

close nix-rustgh-2264
  • Loading branch information
devnexen committed Dec 30, 2023
1 parent 4e2d917 commit 792b568
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ targets = [
]

[dependencies]
libc = { version = "0.2.151", features = ["extra_traits"] }
libc = { git = "https://github.com/rust-lang/libc", rev = "6a203e955b60cca48562f020f0e4e003079f3199", features = ["extra_traits"] }
bitflags = "2.3.1"
cfg-if = "1.0"
pin-utils = { version = "0.1.0", optional = true }
Expand Down
1 change: 1 addition & 0 deletions changelog/2276.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added the `Backlog` wrapper type for the `listen` call.
137 changes: 135 additions & 2 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2009,12 +2009,145 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
unsafe { Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1]))) }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Backlog(i32);

impl Backlog {
/// Sets the listen queue size to system `SOMAXCONN` value
pub const MAXCONN: Self = Self(libc::SOMAXCONN);
/// Sets the listen queue size to -1 for system supporting it
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
pub const MAXALLOWABLE: Self = Self(-1);

// Create a `Backlog`, an `EINVAL` will be returned if `val` is invalid.
pub fn new<I: Into<i32> + PartialOrd<I> + From<i32>>(val: I) -> Result<Self> {
cfg_if! {
if #[cfg(any(target_os = "linux", target_os = "freebsd"))] {
const MIN: i32 = -1;
} else {
const MIN: i32 = 0;
}
}

if val < MIN.into() || val > Self::MAXCONN.0.into() {
return Err(Errno::EINVAL);
}

Ok(Self(val.into()))
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BacklogTryFromError {
TooNegative,
TooPositive,
}

impl std::fmt::Display for BacklogTryFromError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
Self::TooNegative => write!(f, "Passed a positive backlog less than -1"),
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
Self::TooNegative => write!(f, "Passed a positive backlog less than 0"),
Self::TooPositive => write!(f, "Passed a positive backlog greater than `{:?}`", Backlog::MAXCONN)
}
}
}

impl std::error::Error for BacklogTryFromError {}

impl From<u16> for Backlog {
fn from(backlog: u16) -> Self {
Self(i32::from(backlog))
}
}

impl From<u8> for Backlog {
fn from(backlog: u8) -> Self {
Self(i32::from(backlog))
}
}

impl From<Backlog> for i32 {
fn from(backlog: Backlog) -> Self {
backlog.0
}
}

impl TryFrom<i64> for Backlog {
type Error = BacklogTryFromError;
fn try_from(backlog: i64) -> std::result::Result<Self, Self::Error> {
match backlog {
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
..=-2 => Err(BacklogTryFromError::TooNegative),
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
..=-1 => Err(BacklogTryFromError::TooNegative),
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
val if (-1..=i64::from(Backlog::MAXCONN.0)).contains(&val) => Ok(Self(i32::try_from(backlog).map_err(|_| BacklogTryFromError::TooPositive)?)),
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
val if (0..=i64::from(Backlog::MAXCONN.0)).contains(&val) => Ok(Self(i32::try_from(backlog).map_err(|_| BacklogTryFromError::TooPositive)?)),
_ => Err(BacklogTryFromError::TooPositive),
}
}
}

impl TryFrom<i32> for Backlog {
type Error = BacklogTryFromError;
fn try_from(backlog: i32) -> std::result::Result<Self, Self::Error> {
match backlog {
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
..=-2 => Err(BacklogTryFromError::TooNegative),
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
..=-1 => Err(BacklogTryFromError::TooNegative),
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
val if (-1..=Backlog::MAXCONN.0).contains(&val) => Ok(Self(backlog)),
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
val if (0..=Backlog::MAXCONN.0).contains(&val) => Ok(Self(backlog)),
_ => Err(BacklogTryFromError::TooPositive),
}
}
}

impl TryFrom<i16> for Backlog {
type Error = BacklogTryFromError;
fn try_from(backlog: i16) -> std::result::Result<Self, Self::Error> {
match backlog {
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
..=-2 => Err(BacklogTryFromError::TooNegative),
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
..=-1 => Err(BacklogTryFromError::TooNegative),
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
val if (-1..=i16::try_from(Backlog::MAXCONN.0).unwrap()).contains(&val) => Ok(Self(i32::from(backlog))),
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
val if (0..=i16::try_from(Backlog::MAXCONN.0).unwrap()).contains(&val) => Ok(Self(i32::from(backlog))),
_ => Err(BacklogTryFromError::TooPositive),
}
}
}

impl TryFrom<i8> for Backlog {
type Error = BacklogTryFromError;
fn try_from(backlog: i8) -> std::result::Result<Self, Self::Error> {
match backlog {
..=-2 => Err(BacklogTryFromError::TooNegative),
_ => Err(BacklogTryFromError::TooPositive),
}
}
}

impl<T: Into<Backlog>> From<Option<T>> for Backlog {
fn from(backlog: Option<T>) -> Self {
backlog.map_or(Self::MAXCONN, |b| b.into())
}
}

/// Listen for connections on a socket
///
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html)
pub fn listen<F: AsFd>(sock: &F, backlog: usize) -> Result<()> {
pub fn listen<F: AsFd, B: Into<Backlog>>(sock: &F, backlog: B) -> Result<()> {
let fd = sock.as_fd().as_raw_fd();
let res = unsafe { libc::listen(fd, backlog as c_int) };
let res = unsafe { libc::listen(fd, i32::from(backlog.into())) };

Errno::result(res).map(drop)
}
Expand Down
15 changes: 13 additions & 2 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
// Test creating and using named unix domain sockets
#[test]
pub fn test_named_unixdomain() {
use nix::sys::socket::{accept, bind, connect, listen, socket, UnixAddr};
use nix::sys::socket::{
accept, bind, connect, listen, socket, Backlog, UnixAddr,
};
use nix::sys::socket::{SockFlag, SockType};
use nix::unistd::{read, write};
use std::thread;
Expand All @@ -1627,7 +1629,7 @@ pub fn test_named_unixdomain() {
.expect("socket failed");
let sockaddr = UnixAddr::new(&sockname).unwrap();
bind(s1.as_raw_fd(), &sockaddr).expect("bind failed");
listen(&s1, 10).expect("listen failed");
listen(&s1, Backlog::new(10).unwrap()).expect("listen failed");

let thr = thread::spawn(move || {
let s2 = socket(
Expand All @@ -1650,6 +1652,15 @@ pub fn test_named_unixdomain() {
assert_eq!(&buf[..], b"hello");
}

#[test]
pub fn test_listen_wrongbacklog() {
use nix::sys::socket::Backlog;

assert!(Backlog::new(5012).is_err());
assert!(Backlog::new(65535).is_err());
assert!(Backlog::new(-2).is_err());
}

// Test using unnamed unix domain addresses
#[cfg(linux_android)]
#[test]
Expand Down
15 changes: 9 additions & 6 deletions test/sys/test_sockopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn test_so_buf() {
#[cfg(target_os = "freebsd")]
#[test]
fn test_so_listen_q_limit() {
use nix::sys::socket::{bind, listen, SockaddrIn};
use nix::sys::socket::{bind, listen, Backlog, SockaddrIn};
use std::net::SocketAddrV4;
use std::str::FromStr;

Expand All @@ -123,14 +123,16 @@ fn test_so_listen_q_limit() {
bind(rsock.as_raw_fd(), &sock_addr).unwrap();
let pre_limit = getsockopt(&rsock, sockopt::ListenQLimit).unwrap();
assert_eq!(pre_limit, 0);
listen(&rsock, 42).unwrap();
listen(&rsock, Backlog::new(42).unwrap()).unwrap();
let post_limit = getsockopt(&rsock, sockopt::ListenQLimit).unwrap();
assert_eq!(post_limit, 42);
}

#[test]
fn test_so_tcp_maxseg() {
use nix::sys::socket::{accept, bind, connect, listen, SockaddrIn};
use nix::sys::socket::{
accept, bind, connect, listen, Backlog, SockaddrIn,
};
use nix::unistd::write;
use std::net::SocketAddrV4;
use std::str::FromStr;
Expand All @@ -146,7 +148,7 @@ fn test_so_tcp_maxseg() {
)
.unwrap();
bind(rsock.as_raw_fd(), &sock_addr).unwrap();
listen(&rsock, 10).unwrap();
listen(&rsock, Backlog::from(10u16)).unwrap();
let initial = getsockopt(&rsock, sockopt::TcpMaxSeg).unwrap();
// Initial MSS is expected to be 536 (https://tools.ietf.org/html/rfc879#section-1) but some
// platforms keep it even lower. This might fail if you've tuned your initial MSS to be larger
Expand Down Expand Up @@ -716,7 +718,8 @@ fn is_socket_type_dgram() {
#[test]
fn can_get_listen_on_tcp_socket() {
use nix::sys::socket::{
getsockopt, listen, socket, sockopt, AddressFamily, SockFlag, SockType,
getsockopt, listen, socket, sockopt, AddressFamily, Backlog, SockFlag,
SockType,
};

let s = socket(
Expand All @@ -728,7 +731,7 @@ fn can_get_listen_on_tcp_socket() {
.unwrap();
let s_listening = getsockopt(&s, sockopt::AcceptConn).unwrap();
assert!(!s_listening);
listen(&s, 10).unwrap();
listen(&s, Backlog::new(10).unwrap()).unwrap();
let s_listening2 = getsockopt(&s, sockopt::AcceptConn).unwrap();
assert!(s_listening2);
}

0 comments on commit 792b568

Please sign in to comment.