Skip to content

Commit

Permalink
udp: UdpSocket split support (#1226)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yin Guanhao authored and carllerche committed Jul 8, 2019
1 parent 8b49a1e commit 88e775d
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 13 deletions.
1 change: 1 addition & 0 deletions tokio-udp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod recv_from;
mod send;
mod send_to;
mod socket;
pub mod split;

// pub use self::frame::UdpFramed;
pub use self::recv::Recv;
Expand Down
6 changes: 3 additions & 3 deletions tokio-udp/src/recv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ use std::task::{Context, Poll};
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Recv<'a, 'b> {
socket: &'a mut UdpSocket,
socket: &'a UdpSocket,
buf: &'b mut [u8],
}

impl<'a, 'b> Recv<'a, 'b> {
pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b mut [u8]) -> Self {
pub(super) fn new(socket: &'a UdpSocket, buf: &'b mut [u8]) -> Self {
Self { socket, buf }
}
}
Expand All @@ -25,6 +25,6 @@ impl<'a, 'b> Future for Recv<'a, 'b> {

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Recv { socket, buf } = self.get_mut();
Pin::new(&mut **socket).poll_recv(cx, buf)
socket.poll_recv_priv(cx, buf)
}
}
6 changes: 3 additions & 3 deletions tokio-udp/src/recv_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ use std::task::{Context, Poll};
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct RecvFrom<'a, 'b> {
socket: &'a mut UdpSocket,
socket: &'a UdpSocket,
buf: &'b mut [u8],
}

impl<'a, 'b> RecvFrom<'a, 'b> {
pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b mut [u8]) -> Self {
pub(super) fn new(socket: &'a UdpSocket, buf: &'b mut [u8]) -> Self {
Self { socket, buf }
}
}
Expand All @@ -26,6 +26,6 @@ impl<'a, 'b> Future for RecvFrom<'a, 'b> {

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let RecvFrom { socket, buf } = self.get_mut();
Pin::new(&mut **socket).poll_recv_from(cx, buf)
socket.poll_recv_from_priv(cx, buf)
}
}
6 changes: 3 additions & 3 deletions tokio-udp/src/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ use std::task::{Context, Poll};
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Send<'a, 'b> {
socket: &'a mut UdpSocket,
socket: &'a UdpSocket,
buf: &'b [u8],
}

impl<'a, 'b> Send<'a, 'b> {
pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b [u8]) -> Self {
pub(super) fn new(socket: &'a UdpSocket, buf: &'b [u8]) -> Self {
Self { socket, buf }
}
}
Expand All @@ -25,6 +25,6 @@ impl<'a, 'b> Future for Send<'a, 'b> {

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Send { socket, buf } = self.get_mut();
Pin::new(&mut **socket).poll_send(cx, buf)
socket.poll_send_priv(cx, buf)
}
}
6 changes: 3 additions & 3 deletions tokio-udp/src/send_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use std::task::{Context, Poll};
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct SendTo<'a, 'b> {
socket: &'a mut UdpSocket,
socket: &'a UdpSocket,
buf: &'b [u8],
target: &'b SocketAddr,
}

impl<'a, 'b> SendTo<'a, 'b> {
pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b [u8], target: &'b SocketAddr) -> Self {
pub(super) fn new(socket: &'a UdpSocket, buf: &'b [u8], target: &'b SocketAddr) -> Self {
Self {
socket,
buf,
Expand All @@ -35,6 +35,6 @@ impl<'a, 'b> Future for SendTo<'a, 'b> {
buf,
target,
} = self.get_mut();
Pin::new(&mut **socket).poll_send_to(cx, buf, target)
socket.poll_send_to_priv(cx, buf, target)
}
}
54 changes: 54 additions & 0 deletions tokio-udp/src/socket.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::split::{split, UdpSocketRecvHalf, UdpSocketSendHalf};
use super::{Recv, RecvFrom, Send, SendTo};
use mio;
use std::convert::TryFrom;
Expand Down Expand Up @@ -42,6 +43,16 @@ impl UdpSocket {
Ok(UdpSocket { io })
}

/// Split the `UdpSocket` into a receive half and a send half. The two parts
/// can be used to receive and send datagrams concurrently, even from two
/// different tasks.
///
/// See the module level documenation of [`split`](super::split) for more
/// details.
pub fn split(self) -> (UdpSocketRecvHalf, UdpSocketSendHalf) {
split(self)
}

/// Returns the local address that this socket is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io.get_ref().local_addr()
Expand Down Expand Up @@ -83,6 +94,24 @@ impl UdpSocket {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_send_priv(cx, buf)
}

// Poll IO functions that takes `&self` are provided for the split API.
//
// They are not public because (taken from the doc of `PollEvented`):
//
// While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
// caller must ensure that there are at most two tasks that use a
// `PollEvented` instance concurrently. One for reading and one for writing.
// While violating this requirement is "safe" from a Rust memory model point
// of view, it will result in unexpected behavior in the form of lost
// notifications and tasks hanging.
pub(crate) fn poll_send_priv(
&self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
ready!(self.io.poll_write_ready(cx))?;

Expand Down Expand Up @@ -134,6 +163,14 @@ impl UdpSocket {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.poll_recv_priv(cx, buf)
}

pub(crate) fn poll_recv_priv(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;

Expand Down Expand Up @@ -173,6 +210,15 @@ impl UdpSocket {
cx: &mut Context<'_>,
buf: &[u8],
target: &SocketAddr,
) -> Poll<io::Result<usize>> {
self.poll_send_to_priv(cx, buf, target)
}

pub(crate) fn poll_send_to_priv(
&self,
cx: &mut Context<'_>,
buf: &[u8],
target: &SocketAddr,
) -> Poll<io::Result<usize>> {
ready!(self.io.poll_write_ready(cx))?;

Expand Down Expand Up @@ -201,6 +247,14 @@ impl UdpSocket {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, SocketAddr), io::Error>> {
self.poll_recv_from_priv(cx, buf)
}

pub(crate) fn poll_recv_from_priv(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<(usize, SocketAddr), io::Error>> {
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;

Expand Down
145 changes: 145 additions & 0 deletions tokio-udp/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//! [`UdpSocket`](../struct.UdpSocket.html) split support.
//!
//! The [`split`](../struct.UdpSocket.html#method.split) method splits a
//! `UdpSocket` into a receive half and a send half, which can be used to
//! receive and send datagrams concurrently, even from two different tasks.
//!
//! The halves provide access to the underlying socket, implementing
//! `AsRef<UdpSocket>`. This allows you to call `UdpSocket` methods that takes
//! `&self`, e.g., to get local address, to get and set socket options, to join
//! or leave multicast groups, etc.
//!
//! The halves can be reunited to the original socket with their `reunite`
//! methods.
use super::{Recv, RecvFrom, Send, SendTo, UdpSocket};
use std::error::Error;
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;

/// The send half after [`split`](super::UdpSocket::split).
///
/// Use [`send_to`](#method.send_to) or [`send`](#method.send) to send
/// datagrams.
#[derive(Debug)]
pub struct UdpSocketSendHalf(Arc<UdpSocket>);

/// The recv half after [`split`](super::UdpSocket::split).
///
/// Use [`recv_from`](#method.recv_from) or [`recv`](#method.recv) to receive
/// datagrams.
#[derive(Debug)]
pub struct UdpSocketRecvHalf(Arc<UdpSocket>);

pub(crate) fn split(socket: UdpSocket) -> (UdpSocketRecvHalf, UdpSocketSendHalf) {
let shared = Arc::new(socket);
let send = shared.clone();
let recv = shared;
(UdpSocketRecvHalf(recv), UdpSocketSendHalf(send))
}

/// Error indicating two halves were not from the same socket, and thus could
/// not be `reunite`d.
#[derive(Debug)]
pub struct ReuniteError(pub UdpSocketSendHalf, pub UdpSocketRecvHalf);

impl fmt::Display for ReuniteError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}

impl Error for ReuniteError {}

fn reunite(s: UdpSocketSendHalf, r: UdpSocketRecvHalf) -> Result<UdpSocket, ReuniteError> {
if Arc::ptr_eq(&s.0, &r.0) {
drop(r);
// Only two instances of the `Arc` are ever created, one for the
// receiver and one for the sender, and those `Arc`s are never exposed
// externally. And so when we drop one here, the other one must be the
// only remaining one.
Ok(Arc::try_unwrap(s.0).expect("tokio_udp: try_unwrap failed in reunite"))
} else {
Err(ReuniteError(s, r))
}
}

impl UdpSocketRecvHalf {
/// Attempts to put the two "halves" of a `UdpSocket` back together and
/// recover the original socket. Succeeds only if the two "halves"
/// originated from the same call to `UdpSocket::split`.
pub fn reunite(self, other: UdpSocketSendHalf) -> Result<UdpSocket, ReuniteError> {
reunite(other, self)
}

/// Returns a future that receives a single datagram on the socket. On success,
/// the future resolves to the number of bytes read and the origin.
///
/// The function must be called with valid byte array `buf` of sufficient size
/// to hold the message bytes. If a message is too long to fit in the supplied
/// buffer, excess bytes may be discarded.
pub fn recv_from<'a, 'b>(&'a mut self, buf: &'b mut [u8]) -> RecvFrom<'a, 'b> {
RecvFrom::new(&self.0, buf)
}

/// Returns a future that receives a single datagram message on the socket from
/// the remote address to which it is connected. On success, the future will resolve
/// to the number of bytes read.
///
/// The function must be called with valid byte array `buf` of sufficient size to
/// hold the message bytes. If a message is too long to fit in the supplied buffer,
/// excess bytes may be discarded.
///
/// The [`connect`] method will connect this socket to a remote address. The future
/// will fail if the socket is not connected.
///
/// [`connect`]: super::UdpSocket::connect
pub fn recv<'a, 'b>(&'a mut self, buf: &'b mut [u8]) -> Recv<'a, 'b> {
Recv::new(&self.0, buf)
}
}

impl UdpSocketSendHalf {
/// Attempts to put the two "halves" of a `UdpSocket` back together and
/// recover the original socket. Succeeds only if the two "halves"
/// originated from the same call to `UdpSocket::split`.
pub fn reunite(self, other: UdpSocketRecvHalf) -> Result<UdpSocket, ReuniteError> {
reunite(self, other)
}

/// Returns a future that sends data on the socket to the given address.
/// On success, the future will resolve to the number of bytes written.
///
/// The future will resolve to an error if the IP version of the socket does
/// not match that of `target`.
pub fn send_to<'a, 'b>(&'a mut self, buf: &'b [u8], target: &'b SocketAddr) -> SendTo<'a, 'b> {
SendTo::new(&self.0, buf, target)
}

/// Returns a future that sends data on the socket to the remote address to which it is connected.
/// On success, the future will resolve to the number of bytes written.
///
/// The [`connect`] method will connect this socket to a remote address. The future
/// will resolve to an error if the socket is not connected.
///
/// [`connect`]: super::UdpSocket::connect
pub fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> Send<'a, 'b> {
Send::new(&self.0, buf)
}
}

impl AsRef<UdpSocket> for UdpSocketSendHalf {
fn as_ref(&self) -> &UdpSocket {
&self.0
}
}

impl AsRef<UdpSocket> for UdpSocketRecvHalf {
fn as_ref(&self) -> &UdpSocket {
&self.0
}
}
34 changes: 34 additions & 0 deletions tokio-udp/tests/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,40 @@ async fn send_to_recv_from() -> std::io::Result<()> {
Ok(())
}

#[tokio::test]
async fn split() -> std::io::Result<()> {
let socket = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?;
let (mut r, mut s) = socket.split();

let msg = b"hello";
let addr = s.as_ref().local_addr()?;
tokio::spawn(async move {
s.send_to(msg, &addr).await.unwrap();
});
let mut recv_buf = [0u8; 32];
let (len, _) = r.recv_from(&mut recv_buf[..]).await?;
assert_eq!(&recv_buf[..len], msg);
Ok(())
}

#[tokio::test]
async fn reunite() -> std::io::Result<()> {
let socket = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?;
let (s, r) = socket.split();
assert!(s.reunite(r).is_ok());
Ok(())
}

#[tokio::test]
async fn reunite_error() -> std::io::Result<()> {
let socket = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?;
let socket1 = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?;
let (s, _) = socket.split();
let (_, r1) = socket1.split();
assert!(s.reunite(r1).is_err());
Ok(())
}

// pub struct ByteCodec;

// impl Decoder for ByteCodec {
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub mod udp {
//! [`Send`]: struct.Send.html
//! [`RecvFrom`]: struct.RecvFrom.html
//! [`SendTo`]: struct.SendTo.html
pub use tokio_udp::{Recv, RecvFrom, Send, SendTo, UdpSocket};
pub use tokio_udp::{split, Recv, RecvFrom, Send, SendTo, UdpSocket};
}
#[cfg(feature = "udp")]
pub use self::udp::UdpSocket;
Expand Down

0 comments on commit 88e775d

Please sign in to comment.