Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ SANDBOX
vendor
perf.data*
flamegraph*.svg
Cargo.lock
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ categories = [
bitflags = "1.3.2"
tokio = { version = "1.21.2", features = ["net", "macros", "rt-multi-thread", "time", "io-util", "sync"] }
futures = "0.3.25"
derivative = "2.2.0"
tracing = "0.1.37"
bytes = "1.2.1"
log = "0.4.17"
Expand Down
4 changes: 2 additions & 2 deletions examples/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ async fn main() {
}

async fn run(total: usize, num_streams: usize) -> io::Result<()> {
let socka = UdxSocket::bind("127.0.0.1:0").await?;
let sockb = UdxSocket::bind("127.0.0.1:0").await?;
let socka = UdxSocket::bind("127.0.0.1:0")?;
let sockb = UdxSocket::bind("127.0.0.1:0")?;
let addra = socka.local_addr()?;
let addrb = sockb.local_addr()?;

Expand Down
4 changes: 2 additions & 2 deletions examples/multi_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ async fn main() {
);

let host = "127.0.0.1";
let socka = UdxSocket::bind(format!("{host}:0")).await.unwrap();
let sockb = UdxSocket::bind(format!("{host}:0")).await.unwrap();
let socka = UdxSocket::bind(format!("{host}:0")).unwrap();
let sockb = UdxSocket::bind(format!("{host}:0")).unwrap();
let addra = socka.local_addr().unwrap();
let addrb = sockb.local_addr().unwrap();
eprintln!("addra {}", addra);
Expand Down
2 changes: 1 addition & 1 deletion examples/rw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async fn main() -> io::Result<()> {
.next()
.expect("invalid connect addr");
eprintln!("{} -> {}", listen_addr, connect_addr);
let sock = UdxSocket::bind(listen_addr).await?;
let sock = UdxSocket::bind(listen_addr)?;
let stream = sock.connect(connect_addr, 1, 1)?;
let max_len = UDX_DATA_MTU * 64;
let read = spawn("read", read_loop(stream.clone(), max_len));
Expand Down
4 changes: 2 additions & 2 deletions examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ async fn main() -> io::Result<()> {
tracing_subscriber::fmt::init();

// Bind two sockets
let socka = UdxSocket::bind("127.0.0.1:20004").await?;
let socka = UdxSocket::bind("127.0.0.1:20004")?;
let addra = socka.local_addr()?;
eprintln!("Socket A bound to {addra}");
let sockb = UdxSocket::bind("127.0.0.1:20005").await?;
let sockb = UdxSocket::bind("127.0.0.1:20005")?;
let addrb = sockb.local_addr()?;
eprintln!("Socket B bound to {addrb}");

Expand Down
26 changes: 20 additions & 6 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::fmt::Debug;
use std::io;
use std::io::IoSliceMut;
use std::mem::MaybeUninit;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
Expand Down Expand Up @@ -59,8 +61,20 @@ impl std::ops::Deref for UdxSocket {
}

impl UdxSocket {
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let inner = UdxSocketInner::bind(addr).await?;
pub fn bind_rnd() -> io::Result<Self> {
Self::bind_port(0)
}
pub fn bind_port(port: u16) -> io::Result<Self> {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
Self::bind(addr)
}
// TODO FIXME this is not async but requires tokio running. Which will cause a runtime failure.
// rm this depndence
/// Create a socket on the given `addr`. Note `addr` is a *local* address normally it would
/// look like `127.0.0.1:8080` which creates a socket on port `8080`. To connect to any random
/// port pass `:0` as the port.
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let inner = UdxSocketInner::bind(addr)?;
let socket = Self(Arc::new(Mutex::new(inner)));
let driver = SocketDriver(socket.clone());
tokio::spawn(async {
Expand Down Expand Up @@ -236,7 +250,7 @@ impl SocketStats {
}

impl UdxSocketInner {
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let socket = std::net::UdpSocket::bind(addr)?;
let socket = UdpSocket::from_std(socket)?;
let (send_tx, send_rx) = mpsc::unbounded_channel();
Expand Down Expand Up @@ -457,15 +471,15 @@ pub enum SocketEvent {
// Create an array of IO vectors from a buffer.
// Safety: buf has to be longer than N. You may only read from slices that have been written to.
// Taken from: quinn/src/endpoint.rs
unsafe fn iovectors_from_buf<'a, const N: usize>(buf: &'a mut [u8]) -> [IoSliceMut; N] {
let mut iovs = MaybeUninit::<[IoSliceMut<'a>; N]>::uninit();
unsafe fn iovectors_from_buf<const N: usize>(buf: &mut [u8]) -> [IoSliceMut; N] {
let mut iovs = MaybeUninit::<[IoSliceMut; N]>::uninit();
buf.chunks_mut(buf.len() / N)
.enumerate()
.for_each(|(i, buf)| {
iovs.as_mut_ptr()
.cast::<IoSliceMut>()
.add(i)
.write(IoSliceMut::<'a>::new(buf));
.write(IoSliceMut::new(buf));
});
iovs.assume_init()
}
4 changes: 2 additions & 2 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::io;

#[tokio::test]
async fn socket_dgrams() -> io::Result<()> {
let socka = UdxSocket::bind("127.0.0.1:0").await?;
let sockb = UdxSocket::bind("127.0.0.1:0").await?;
let socka = UdxSocket::bind("127.0.0.1:0")?;
let sockb = UdxSocket::bind("127.0.0.1:0")?;
let addra = socka.local_addr()?;
let addrb = sockb.local_addr()?;

Expand Down
4 changes: 2 additions & 2 deletions tests/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ async fn stream_close() -> io::Result<()> {
}

async fn create_pair() -> io::Result<((UdxSocket, UdxSocket), (UdxStream, UdxStream))> {
let socka = UdxSocket::bind("127.0.0.1:0").await?;
let sockb = UdxSocket::bind("127.0.0.1:0").await?;
let socka = UdxSocket::bind("127.0.0.1:0")?;
let sockb = UdxSocket::bind("127.0.0.1:0")?;
let addra = socka.local_addr()?;
let addrb = sockb.local_addr()?;
let streama = socka.connect(addrb, 1, 2)?;
Expand Down
23 changes: 18 additions & 5 deletions udx-udp/src/unix.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::{
io,
io::IoSliceMut,
io::{self, IoSliceMut},
mem::{self, MaybeUninit},
net::{IpAddr, SocketAddr},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
os::unix::io::AsRawFd,
ptr,
sync::atomic::AtomicUsize,
Expand Down Expand Up @@ -550,8 +549,22 @@ fn decode_recv(
}

let addr = match libc::c_int::from(name.ss_family) {
libc::AF_INET => unsafe { SocketAddr::V4(ptr::read(&name as *const _ as _)) },
libc::AF_INET6 => unsafe { SocketAddr::V6(ptr::read(&name as *const _ as _)) },
libc::AF_INET => unsafe {
// Cast to sockaddr_in first to get correct memory layout
let addr: &libc::sockaddr_in = &*(&name as *const _ as *const libc::sockaddr_in);
SocketAddr::new(
IpAddr::V4(Ipv4Addr::from(u32::from_be(addr.sin_addr.s_addr))),
u16::from_be(addr.sin_port),
)
},
libc::AF_INET6 => unsafe {
// Cast to sockaddr_in6 first to get correct memory layout
let addr: &libc::sockaddr_in6 = &*(&name as *const _ as *const libc::sockaddr_in6);
SocketAddr::new(
IpAddr::V6(Ipv6Addr::from(addr.sin6_addr.s6_addr)),
u16::from_be(addr.sin6_port),
)
},
_ => unreachable!(),
};

Expand Down