Skip to content

Commit ee849ef

Browse files
committed
feat(socket): make AF_INET non-zero
1 parent 6131789 commit ee849ef

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

src/fd/socket/tcp.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ pub struct Socket {
3737
port: u16,
3838
is_nonblocking: bool,
3939
is_listen: bool,
40+
// TODO: remove once the ecosystem has migrated away from `AF_INET_OLD`.
41+
domain: i32,
4042
}
4143

4244
impl Socket {
43-
pub fn new(h: Handle) -> Self {
45+
pub fn new(h: Handle, domain: i32) -> Self {
4446
let mut handle = BTreeSet::new();
4547
handle.insert(h);
4648

@@ -49,6 +51,7 @@ impl Socket {
4951
port: 0,
5052
is_nonblocking: false,
5153
is_listen: false,
54+
domain,
5255
}
5356
}
5457

@@ -349,6 +352,7 @@ impl Socket {
349352
port: self.port,
350353
is_nonblocking: self.is_nonblocking,
351354
is_listen: false,
355+
domain: self.domain,
352356
};
353357

354358
Ok((socket, endpoint))
@@ -520,4 +524,9 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
520524
async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
521525
self.write().await.set_status_flags(status_flags).await
522526
}
527+
528+
async fn inet_domain(&self) -> io::Result<i32> {
529+
let domain = self.read().await.domain;
530+
Ok(domain)
531+
}
523532
}

src/fd/socket/udp.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@ pub struct Socket {
1818
handle: Handle,
1919
nonblocking: bool,
2020
endpoint: Option<IpEndpoint>,
21+
// TODO: remove once the ecosystem has migrated away from `AF_INET_OLD`.
22+
domain: i32,
2123
}
2224

2325
impl Socket {
24-
pub fn new(handle: Handle) -> Self {
26+
pub fn new(handle: Handle, domain: i32) -> Self {
2527
Self {
2628
handle,
2729
nonblocking: false,
2830
endpoint: None,
31+
domain,
2932
}
3033
}
3134

@@ -274,4 +277,9 @@ impl ObjectInterface for async_lock::RwLock<Socket> {
274277
async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
275278
self.write().await.set_status_flags(status_flags).await
276279
}
280+
281+
async fn inet_domain(&self) -> io::Result<i32> {
282+
let domain = self.read().await.domain;
283+
Ok(domain)
284+
}
277285
}

src/syscalls/socket.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ use crate::fd::{
2525
use crate::syscalls::block_on;
2626

2727
pub const AF_UNSPEC: i32 = 0;
28-
pub const AF_INET: i32 = 0;
28+
pub const AF_INET_OLD: i32 = 0;
29+
pub const AF_INET: i32 = 3;
2930
pub const AF_INET6: i32 = 1;
3031
pub const AF_VSOCK: i32 = 2;
3132
pub const IPPROTO_IP: i32 = 0;
@@ -437,7 +438,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32
437438
}
438439

439440
#[cfg(any(feature = "tcp", feature = "udp"))]
440-
if (domain == AF_INET || domain == AF_INET6)
441+
if (domain == AF_INET_OLD || domain == AF_INET || domain == AF_INET6)
441442
&& type_.intersects(SockType::SOCK_STREAM | SockType::SOCK_DGRAM)
442443
{
443444
let mut guard = NIC.lock();
@@ -447,7 +448,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32
447448
if type_.contains(SockType::SOCK_DGRAM) {
448449
let handle = nic.create_udp_handle().unwrap();
449450
drop(guard);
450-
let socket = Arc::new(async_lock::RwLock::new(udp::Socket::new(handle)));
451+
let socket = Arc::new(async_lock::RwLock::new(udp::Socket::new(handle, domain)));
451452

452453
if type_.contains(SockType::SOCK_NONBLOCK) {
453454
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();
@@ -462,7 +463,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32
462463
if type_.contains(SockType::SOCK_STREAM) {
463464
let handle = nic.create_tcp_handle().unwrap();
464465
drop(guard);
465-
let socket = Arc::new(async_lock::RwLock::new(tcp::Socket::new(handle)));
466+
let socket = Arc::new(async_lock::RwLock::new(tcp::Socket::new(handle, domain)));
466467

467468
if type_.contains(SockType::SOCK_NONBLOCK) {
468469
block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap();
@@ -500,6 +501,7 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut
500501
if *addrlen >= size_of::<sockaddr_in>().try_into().unwrap() {
501502
let addr = unsafe { &mut *addr.cast() };
502503
*addr = sockaddr_in::from(endpoint);
504+
addr.sin_family = block_on(v.inet_domain(), None).unwrap().try_into().unwrap();
503505
*addrlen = size_of::<sockaddr_in>().try_into().unwrap();
504506
}
505507
}
@@ -564,7 +566,7 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl
564566
|e| -num::ToPrimitive::to_i32(&e).unwrap(),
565567
|v| match family {
566568
#[cfg(any(feature = "tcp", feature = "udp"))]
567-
AF_INET => {
569+
AF_INET_OLD | AF_INET => {
568570
if namelen < size_of::<sockaddr_in>().try_into().unwrap() {
569571
return -crate::errno::EINVAL;
570572
}
@@ -606,7 +608,7 @@ pub unsafe extern "C" fn sys_connect(fd: i32, name: *const sockaddr, namelen: so
606608

607609
let endpoint = match sa_family {
608610
#[cfg(any(feature = "tcp", feature = "udp"))]
609-
AF_INET => {
611+
AF_INET_OLD | AF_INET => {
610612
if namelen < size_of::<sockaddr_in>().try_into().unwrap() {
611613
return -crate::errno::EINVAL;
612614
}
@@ -663,6 +665,7 @@ pub unsafe extern "C" fn sys_getsockname(
663665
if *addrlen >= size_of::<sockaddr_in>().try_into().unwrap() {
664666
let addr = unsafe { &mut *addr.cast() };
665667
*addr = sockaddr_in::from(endpoint);
668+
addr.sin_family = block_on(v.inet_domain(), None).unwrap().try_into().unwrap();
666669
*addrlen = size_of::<sockaddr_in>().try_into().unwrap();
667670
} else {
668671
return -crate::errno::EINVAL;
@@ -803,6 +806,7 @@ pub unsafe extern "C" fn sys_getpeername(
803806
if *addrlen >= size_of::<sockaddr_in>().try_into().unwrap() {
804807
let addr = unsafe { &mut *addr.cast() };
805808
*addr = sockaddr_in::from(endpoint);
809+
addr.sin_family = block_on(v.inet_domain(), None).unwrap().try_into().unwrap();
806810
*addrlen = size_of::<sockaddr_in>().try_into().unwrap();
807811
} else {
808812
return -crate::errno::EINVAL;
@@ -915,7 +919,7 @@ pub unsafe extern "C" fn sys_sendto(
915919
if #[cfg(any(feature = "tcp", feature = "udp"))] {
916920
let sa_family = unsafe { i32::from((*addr).sa_family) };
917921

918-
if sa_family == AF_INET {
922+
if sa_family == AF_INET_OLD || sa_family == AF_INET {
919923
if addr_len < size_of::<sockaddr_in>().try_into().unwrap() {
920924
return (-crate::errno::EINVAL).try_into().unwrap();
921925
}
@@ -982,6 +986,7 @@ pub unsafe extern "C" fn sys_recvfrom(
982986
if *addrlen >= size_of::<sockaddr_in>().try_into().unwrap() {
983987
let addr = unsafe { &mut *addr.cast() };
984988
*addr = sockaddr_in::from(endpoint);
989+
addr.sin_family = block_on(v.inet_domain(), None).unwrap().try_into().unwrap();
985990
*addrlen = size_of::<sockaddr_in>().try_into().unwrap();
986991
} else {
987992
return (-crate::errno::EINVAL).try_into().unwrap();

0 commit comments

Comments
 (0)