Skip to content

Commit

Permalink
Merge pull request #1442 from microsoft/enhancement-runtime-move-ephemeral-port
Browse files Browse the repository at this point in the history

[inetstack] Enhancement: Move ephemeral port allocator
  • Loading branch information
iyzhang authored Oct 18, 2024
2 parents c107940 + 8ff1f84 commit aa45d62
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 165 deletions.
7 changes: 2 additions & 5 deletions src/rust/demikernel/libos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ impl LibOS {
LibOSName::Catnap => Self::NetworkLibOS(NetworkLibOSWrapper::Catnap(SharedNetworkLibOS::<
SharedCatnapTransport,
>::new(
config.local_ipv4_addr()?,
runtime.clone(),
SharedCatnapTransport::new(&config, &mut runtime)?,
))),
Expand All @@ -98,7 +97,7 @@ impl LibOS {
let inetstack: SharedInetStack =
SharedInetStack::new(&config, runtime.clone(), layer1_endpoint).unwrap();
Self::NetworkLibOS(NetworkLibOSWrapper::Catpowder(
SharedNetworkLibOS::<SharedInetStack>::new(config.local_ipv4_addr()?, runtime, inetstack),
SharedNetworkLibOS::<SharedInetStack>::new(runtime, inetstack),
))
},
#[cfg(feature = "catnip-libos")]
Expand All @@ -109,9 +108,7 @@ impl LibOS {
SharedInetStack::new(&config, runtime.clone(), layer1_endpoint).unwrap();

Self::NetworkLibOS(NetworkLibOSWrapper::Catnip(SharedNetworkLibOS::<SharedInetStack>::new(
config.local_ipv4_addr()?,
runtime,
inetstack,
runtime, inetstack,
)))
},
_ => panic!("unsupported libos"),
Expand Down
68 changes: 15 additions & 53 deletions src/rust/demikernel/libos/network/libos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ use ::std::{
/// Catnap libOS. All state is kept in the [runtime] and [qtable].
/// TODO: Move [qtable] into [runtime] so all state is contained in the PosixRuntime.
pub struct NetworkLibOS<T: NetworkTransport> {
local_ipv4_addr: Ipv4Addr,
runtime: SharedDemiRuntime,
transport: T,
}
Expand All @@ -55,9 +54,8 @@ pub struct SharedNetworkLibOS<T: NetworkTransport>(SharedObject<NetworkLibOS<T>>
//======================================================================================================================

impl<T: NetworkTransport> SharedNetworkLibOS<T> {
pub fn new(local_ipv4_addr: Ipv4Addr, runtime: SharedDemiRuntime, transport: T) -> Self {
pub fn new(runtime: SharedDemiRuntime, transport: T) -> Self {
Self(SharedObject::new(NetworkLibOS::<T> {
local_ipv4_addr,
runtime: runtime.clone(),
transport,
}))
Expand Down Expand Up @@ -99,66 +97,39 @@ impl<T: NetworkTransport> SharedNetworkLibOS<T> {
}

/// This function contains the LibOS-level functionality needed to bind a SharedNetworkQueue to a local address.
pub fn bind(&mut self, qd: QDesc, mut socket_addr: SocketAddr) -> Result<(), Fail> {
pub fn bind(&mut self, qd: QDesc, socket_addr: SocketAddr) -> Result<(), Fail> {
trace!("bind() qd={:?}, local={:?}", qd, socket_addr);

// We only support IPv4 addresses.
let socket_addrv4: SocketAddrV4 = unwrap_socketaddr(socket_addr)?;

// We only support the wildcard address for UDP sockets.
// FIXME: https://github.com/demikernel/demikernel/issues/189
match *socket_addrv4.ip() {
Ipv4Addr::UNSPECIFIED if self.get_shared_queue(&qd)?.get_qtype() == QType::UdpSocket => (),
Ipv4Addr::UNSPECIFIED => {
let cause: String = format!("cannot bind to wildcard address (qd={:?})", qd);
error!("bind(): {}", cause);
return Err(Fail::new(libc::ENOTSUP, &cause));
},
addrv4 if addrv4 != self.local_ipv4_addr => {
let cause: String = format!("cannot bind to non-local address: {:?}", addrv4);
error!("bind(): {}", &cause);
return Err(Fail::new(libc::EADDRNOTAVAIL, &cause));
},
_ => (),
}

if SharedDemiRuntime::is_private_ephemeral_port(socket_addr.port()) {
self.runtime.reserve_ephemeral_port(socket_addr.port())?
if *socket_addrv4.ip() == Ipv4Addr::UNSPECIFIED && self.get_shared_queue(&qd)?.get_qtype() != QType::UdpSocket {
let cause: String = format!("cannot bind to wildcard address (qd={:?})", qd);
error!("bind(): {}", cause);
return Err(Fail::new(libc::ENOTSUP, &cause));
}

// We only support binding to port 0 for UDP sockets.
// FIXME: https://github.com/demikernel/demikernel/issues/582
if socket_addr.port() == 0 {
if self.get_shared_queue(&qd)?.get_qtype() != QType::UdpSocket {
let cause: String = format!("cannot bind to port 0 (qd={:?})", qd);
error!("bind(): {}", cause);
return Err(Fail::new(libc::ENOTSUP, &cause));
} else {
// Allocate an ephemeral port.
let new_port: u16 = self.runtime.alloc_ephemeral_port()?;
socket_addr.set_port(new_port);
}
if socket_addr.port() == 0 && self.get_shared_queue(&qd)?.get_qtype() != QType::UdpSocket {
let cause: String = format!("cannot bind to port 0 (qd={:?})", qd);
error!("bind(): {}", cause);
return Err(Fail::new(libc::ENOTSUP, &cause));
}

if self.runtime.is_addr_in_use(socket_addrv4) {
let cause: String = format!("address is already bound to a socket (qd={:?}", qd);
error!("bind(): {}", &cause);
return Err(Fail::new(libc::EADDRINUSE, &cause));
}
self.get_shared_queue(&qd)?.bind(socket_addr)?;
// Insert into address to queue descriptor table.
self.runtime
.insert_socket_id_to_qd(SocketId::Passive(socket_addrv4.clone()), qd);

if let Err(e) = self.get_shared_queue(&qd)?.bind(socket_addr) {
if SharedDemiRuntime::is_private_ephemeral_port(socket_addr.port()) {
if self.runtime.free_ephemeral_port(socket_addr.port()).is_err() {
warn!("bind(): leaking ephemeral port (port={})", socket_addr.port());
}
}
Err(e)
} else {
// Insert into address to queue descriptor table.
self.runtime
.insert_socket_id_to_qd(SocketId::Passive(socket_addrv4.clone()), qd);
Ok(())
}
Ok(())
}

/// Sets a SharedNetworkQueue and its underlying socket as a passive one. This function contains the LibOS-level
Expand Down Expand Up @@ -306,15 +277,6 @@ impl<T: NetworkTransport> SharedNetworkLibOS<T> {
unwrap_socketaddr(local),
"we only support IPv4"
)));

// Check if this is an ephemeral port.
if SharedDemiRuntime::is_private_ephemeral_port(local.port()) {
// Allocate ephemeral port from the pool, to leave ephemeral port allocator in a consistent state.
if let Err(e) = self.runtime.free_ephemeral_port(local.port()) {
let cause: String = format!("close(): Could not free ephemeral port");
warn!("{}: {:?}", cause, e);
}
}
}
// Remove the queue from the queue table. Expect is safe here because we looked up the queue to
// schedule this coroutine and no other close coroutine should be able to run due to state machine
Expand Down
1 change: 0 additions & 1 deletion src/rust/inetstack/protocols/layer3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ impl SharedLayer3Endpoint {
self.layer2_endpoint.transmit_ipv4_packet(remote_link_addr, pkt)
}

#[cfg(test)]
pub fn get_local_addr(&self) -> Ipv4Addr {
self.local_ipv4_addr
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ impl Default for EphemeralPorts {

#[cfg(test)]
mod test {
use crate::runtime::network::ephemeral::{EphemeralPorts, FIRST_PRIVATE_PORT_NUMBER, LAST_PRIVATE_PORT_NUMBER};
use crate::inetstack::protocols::layer4::ephemeral::{
EphemeralPorts, FIRST_PRIVATE_PORT_NUMBER, LAST_PRIVATE_PORT_NUMBER,
};
use ::anyhow::Result;

#[test]
Expand Down
91 changes: 77 additions & 14 deletions src/rust/inetstack/protocols/layer4/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// Exports
//======================================================================================================================

pub mod ephemeral;
pub mod tcp;
pub mod udp;

Expand All @@ -21,6 +22,7 @@ use crate::{
inetstack::protocols::{
layer3::{ip::IpProtocol, SharedLayer3Endpoint},
layer4::{
ephemeral::EphemeralPorts,
tcp::{SharedTcpPeer, SharedTcpSocket},
udp::{SharedUdpPeer, SharedUdpSocket},
},
Expand Down Expand Up @@ -48,6 +50,7 @@ pub struct Peer {
tcp: SharedTcpPeer,
udp: SharedUdpPeer,
layer3_endpoint: SharedLayer3Endpoint,
ephemeral_ports: EphemeralPorts,
}

/// Socket Representation.
Expand Down Expand Up @@ -75,6 +78,7 @@ impl Peer {
tcp,
udp,
layer3_endpoint,
ephemeral_ports: EphemeralPorts::default(),
})
}

Expand Down Expand Up @@ -161,14 +165,28 @@ impl Peer {
/// Upon successful completion, `Ok(())` is returned. Upon failure, `Fail` is
/// returned instead.
///
pub fn bind(&mut self, sd: &mut Socket, local: SocketAddr) -> Result<(), Fail> {
pub fn bind(&mut self, sd: &mut Socket, socket_addr: SocketAddr) -> Result<(), Fail> {
// FIXME: add IPv6 support; https://github.com/microsoft/demikernel/issues/935
let local: SocketAddrV4 = unwrap_socketaddr(local)?;
let socket_addr_v4: SocketAddrV4 = unwrap_socketaddr(socket_addr)?;
// Check if we are allowed to bind to this address.
if *socket_addr_v4.ip() != self.layer3_endpoint.get_local_addr()
&& *socket_addr_v4.ip() != Ipv4Addr::UNSPECIFIED
{
let cause: String = format!("cannot bind to non-local address: {:?}", socket_addr_v4);
error!("bind(): {}", &cause);
return Err(Fail::new(libc::EADDRNOTAVAIL, &cause));
}

match sd {
Socket::Tcp(socket) => self.tcp.bind(socket, local),
Socket::Udp(socket) => self.udp.bind(socket, local),
Socket::Tcp(socket) => self.tcp.bind(socket, socket_addr_v4),
Socket::Udp(socket) => self.udp.bind(socket, socket_addr_v4),
}?;

if EphemeralPorts::is_private(socket_addr_v4.port()) {
self.ephemeral_ports.reserve(socket_addr_v4.port())?;
}

Ok(())
}

///
Expand Down Expand Up @@ -251,11 +269,18 @@ impl Peer {
pub async fn connect(&mut self, sd: &mut Socket, remote: SocketAddr) -> Result<(), Fail> {
trace!("connect(): remote={:?}", remote);

// FIXME: add IPv6 support; https://github.com/microsoft/demikernel/issues/935
let remote: SocketAddrV4 = unwrap_socketaddr(remote)?;

match sd {
Socket::Tcp(socket) => self.tcp.connect(socket, remote).await,
Socket::Tcp(socket) => {
// FIXME: add IPv6 support; https://github.com/microsoft/demikernel/issues/935
let remote: SocketAddrV4 = unwrap_socketaddr(remote)?;
// If not bound, allocate an ephemeral port.
let local: SocketAddrV4 = match socket.local() {
Some(local) => local,
None => SocketAddrV4::new(self.layer3_endpoint.get_local_addr(), self.ephemeral_ports.alloc()?),
};

self.tcp.connect(socket, local, remote).await
},
_ => Err(Fail::new(libc::EINVAL, "invalid queue type")),
}
}
Expand All @@ -271,17 +296,55 @@ impl Peer {
/// completes shutting down the connection. Upon failure, `Fail` is returned instead.
///
pub async fn close(&mut self, sd: &mut Socket) -> Result<(), Fail> {
match sd {
Socket::Tcp(socket) => self.tcp.close(socket).await,
Socket::Udp(socket) => self.udp.close(socket).await,
let local_port: Option<u16> = match sd {
Socket::Tcp(socket) => {
let local_port: Option<u16> = match socket.local() {
Some(socket_addr_v4) => Some(socket_addr_v4.port()),
None => None,
};

self.tcp.close(socket).await?;
local_port
},
Socket::Udp(socket) => {
let local_port: Option<u16> = match socket.local() {
Some(socket_addr_v4) => Some(socket_addr_v4.port()),
None => None,
};
self.udp.close(socket).await?;
local_port
},
};
match local_port {
Some(port) if EphemeralPorts::is_private(port) => self.ephemeral_ports.free(port),
_ => Ok(()),
}
}

/// Forcibly close a socket. This should only be used on clean up.
pub fn hard_close(&mut self, sd: &mut Socket) -> Result<(), Fail> {
match sd {
Socket::Tcp(socket) => self.tcp.hard_close(socket),
Socket::Udp(socket) => self.udp.hard_close(socket),
let local_port: Option<u16> = match sd {
Socket::Tcp(socket) => {
let local_port: Option<u16> = match socket.local() {
Some(socket_addr_v4) => Some(socket_addr_v4.port()),
None => None,
};

self.tcp.hard_close(socket)?;
local_port
},
Socket::Udp(socket) => {
let local_port: Option<u16> = match socket.local() {
Some(socket_addr_v4) => Some(socket_addr_v4.port()),
None => None,
};
self.udp.hard_close(socket)?;
local_port
},
};
match local_port {
Some(port) if EphemeralPorts::is_private(port) => self.ephemeral_ports.free(port),
_ => Ok(()),
}
}

Expand Down
40 changes: 8 additions & 32 deletions src/rust/inetstack/protocols/layer4/tcp/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,14 @@ impl SharedTcpPeer {
}

/// Runs until the connect to remote is made or times out.
pub async fn connect(&mut self, socket: &mut SharedTcpSocket, remote: SocketAddrV4) -> Result<(), Fail> {
// Check whether we need to allocate an ephemeral port.
let local: SocketAddrV4 = match socket.local() {
Some(addr) => {
// If socket is already bound to a local address, use it but remove the old binding.
self.addresses.remove(&SocketId::Passive(addr));
addr
},
None => {
let local_port: u16 = self.runtime.alloc_ephemeral_port()?;
SocketAddrV4::new(self.local_ipv4_addr, local_port)
},
};
pub async fn connect(
&mut self,
socket: &mut SharedTcpSocket,
local: SocketAddrV4,
remote: SocketAddrV4,
) -> Result<(), Fail> {
// If socket is already bound to a local address, use it but remove the old binding.
self.addresses.remove(&SocketId::Passive(local));
// Insert the connection to receive incoming packets for this address pair.
// Should we remove the passive entry for the local address if the socket was previously bound?
if self
Expand Down Expand Up @@ -204,38 +199,19 @@ impl SharedTcpPeer {
Ok((None, incoming))
}

/// Frees an ephemeral port (if any) allocated to a given socket.
fn free_ephemeral_port(&mut self, socket_id: &SocketId) {
let local: &SocketAddrV4 = match socket_id {
SocketId::Active(local, _) => local,
SocketId::Passive(local) => local,
};
// Rollback ephemeral port allocation.
if SharedDemiRuntime::is_private_ephemeral_port(local.port()) {
if self.runtime.free_ephemeral_port(local.port()).is_err() {
// We fail if and only if we attempted to free a port that was not allocated.
// This is unexpected, but if it happens, issue a warning and keep going,
// otherwise we would leave the queue in a dangling state.
warn!("bind(): leaking ephemeral port (port={})", local.port());
}
}
}

/// Closes a TCP socket.
pub async fn close(&mut self, socket: &mut SharedTcpSocket) -> Result<(), Fail> {
// Wait for close to complete.
// Handle result: If unsuccessful, free the new queue descriptor.
if let Some(socket_id) = socket.close().await? {
self.addresses.remove(&socket_id);
self.free_ephemeral_port(&socket_id);
}
Ok(())
}

pub fn hard_close(&mut self, socket: &mut SharedTcpSocket) -> Result<(), Fail> {
if let Some(socket_id) = socket.hard_close()? {
self.addresses.remove(&socket_id);
self.free_ephemeral_port(&socket_id);
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/rust/inetstack/test_helpers/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl SharedEngine {
let transport: SharedInetStack = SharedInetStack::new_test(&config, runtime.clone(), layer1_endpoint.clone())?;

Ok(Self(SharedObject::new(Engine {
libos: SharedNetworkLibOS::<SharedInetStack>::new(config.local_ipv4_addr()?, runtime, transport),
libos: SharedNetworkLibOS::<SharedInetStack>::new(runtime, transport),
layer1_endpoint,
})))
}
Expand Down
Loading

0 comments on commit aa45d62

Please sign in to comment.