Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions pingora-core/src/protocols/l4/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,17 @@ fn get_opt<T>(
}
}

// strict: strict validation enforces exact size checks on returned objects kernels.
// passing 'false' for strict relaxes validation to allow the kernel to send smaller objects than the size of T.
// allowing smaller size values means the tail stays zeroed and those fields may be “absent,” not real zeros.
#[cfg(target_os = "linux")]
fn get_opt_sized<T>(sock: c_int, opt: c_int, val: c_int) -> io::Result<T> {
fn get_opt_sized<T>(sock: c_int, opt: c_int, val: c_int, strict: bool) -> io::Result<T> {
let mut payload = mem::MaybeUninit::zeroed();
let expected_size = mem::size_of::<T>() as socklen_t;
let mut size = expected_size;
get_opt(sock, opt, val, &mut payload, &mut size)?;

if size != expected_size {
if size > expected_size || (strict && size != expected_size) {
return Err(std::io::Error::other("get_opt size mismatch"));
}
// Assume getsockopt() will set the value properly
Expand Down Expand Up @@ -272,7 +275,7 @@ fn set_keepalive(_sock: RawSocket, _ka: &TcpKeepalive) -> io::Result<()> {
/// Get the kernel TCP_INFO for the given FD.
#[cfg(target_os = "linux")]
pub fn get_tcp_info(fd: RawFd) -> io::Result<TCP_INFO> {
get_opt_sized(fd, libc::IPPROTO_TCP, libc::TCP_INFO)
get_opt_sized(fd, libc::IPPROTO_TCP, libc::TCP_INFO, false)
}

#[cfg(all(unix, not(target_os = "linux")))]
Expand Down Expand Up @@ -304,7 +307,7 @@ pub fn set_recv_buf(_sock: RawSocket, _: usize) -> Result<()> {

#[cfg(target_os = "linux")]
pub fn get_recv_buf(fd: RawFd) -> io::Result<usize> {
get_opt_sized::<c_int>(fd, libc::SOL_SOCKET, libc::SO_RCVBUF).map(|v| v as usize)
get_opt_sized::<c_int>(fd, libc::SOL_SOCKET, libc::SO_RCVBUF, true).map(|v| v as usize)
}

#[cfg(all(unix, not(target_os = "linux")))]
Expand All @@ -319,7 +322,7 @@ pub fn get_recv_buf(_sock: RawSocket) -> io::Result<usize> {

#[cfg(target_os = "linux")]
pub fn get_snd_buf(fd: RawFd) -> io::Result<usize> {
get_opt_sized::<c_int>(fd, libc::SOL_SOCKET, libc::SO_SNDBUF).map(|v| v as usize)
get_opt_sized::<c_int>(fd, libc::SOL_SOCKET, libc::SO_SNDBUF, true).map(|v| v as usize)
}

#[cfg(all(unix, not(target_os = "linux")))]
Expand Down Expand Up @@ -403,7 +406,7 @@ pub fn set_dscp(_sock: RawSocket, _value: u8) -> Result<()> {

#[cfg(target_os = "linux")]
pub fn get_socket_cookie(fd: RawFd) -> io::Result<u64> {
get_opt_sized::<c_ulonglong>(fd, libc::SOL_SOCKET, libc::SO_COOKIE)
get_opt_sized::<c_ulonglong>(fd, libc::SOL_SOCKET, libc::SO_COOKIE, true)
}

#[cfg(all(unix, not(target_os = "linux")))]
Expand All @@ -424,23 +427,24 @@ pub fn get_original_dest(fd: RawFd) -> Result<Option<SocketAddr>> {
.or_err(SocketError, "failed get original dest, invalid IP socket")?;

let dest = if addr.is_ipv4() {
get_opt_sized::<libc::sockaddr_in>(fd, libc::SOL_IP, libc::SO_ORIGINAL_DST).map(|addr| {
SocketAddr::V4(SocketAddrV4::new(
u32::from_be(addr.sin_addr.s_addr).into(),
u16::from_be(addr.sin_port),
))
})
} else {
get_opt_sized::<libc::sockaddr_in6>(fd, libc::SOL_IPV6, libc::IP6T_SO_ORIGINAL_DST).map(
get_opt_sized::<libc::sockaddr_in>(fd, libc::SOL_IP, libc::SO_ORIGINAL_DST, true).map(
|addr| {
SocketAddr::V4(SocketAddrV4::new(
u32::from_be(addr.sin_addr.s_addr).into(),
u16::from_be(addr.sin_port),
))
},
)
} else {
get_opt_sized::<libc::sockaddr_in6>(fd, libc::SOL_IPV6, libc::IP6T_SO_ORIGINAL_DST, true)
.map(|addr| {
SocketAddr::V6(SocketAddrV6::new(
addr.sin6_addr.s6_addr.into(),
u16::from_be(addr.sin6_port),
addr.sin6_flowinfo,
addr.sin6_scope_id,
))
},
)
})
};
dest.or_err(SocketError, "failed to get original dest")
.map(Some)
Expand Down