Skip to content

Commit a096799

Browse files
sfacklerseanmonstar
authored andcommitted
feat(body): add Sender::abort
This allows a client or server to indicate that the body should be cut off in an abnormal fashion so the server doesn't simply get a "valid" but truncated body.
1 parent 1e3bc6b commit a096799

File tree

3 files changed

+71
-14
lines changed

3 files changed

+71
-14
lines changed

src/body/body.rs

+22-12
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use common::Never;
1111
use super::{Chunk, Payload};
1212
use super::internal::{FullDataArg, FullDataRet};
1313

14-
1514
type BodySender = mpsc::Sender<Result<Chunk, ::Error>>;
1615

1716
/// A stream of `Chunk`s, used when receiving bodies.
@@ -36,7 +35,7 @@ pub struct Body {
3635
enum Kind {
3736
Once(Option<Chunk>),
3837
Chan {
39-
_close_tx: oneshot::Sender<()>,
38+
abort_rx: oneshot::Receiver<()>,
4039
rx: mpsc::Receiver<Result<Chunk, ::Error>>,
4140
},
4241
H2(h2::RecvStream),
@@ -61,7 +60,7 @@ enum DelayEof {
6160
#[must_use = "Sender does nothing unless sent on"]
6261
#[derive(Debug)]
6362
pub struct Sender {
64-
close_rx: oneshot::Receiver<()>,
63+
abort_tx: oneshot::Sender<()>,
6564
tx: BodySender,
6665
}
6766

@@ -87,14 +86,14 @@ impl Body {
8786
#[inline]
8887
pub fn channel() -> (Sender, Body) {
8988
let (tx, rx) = mpsc::channel(0);
90-
let (close_tx, close_rx) = oneshot::channel();
89+
let (abort_tx, abort_rx) = oneshot::channel();
9190

9291
let tx = Sender {
93-
close_rx: close_rx,
92+
abort_tx: abort_tx,
9493
tx: tx,
9594
};
9695
let rx = Body::new(Kind::Chan {
97-
_close_tx: close_tx,
96+
abort_rx: abort_rx,
9897
rx: rx,
9998
});
10099

@@ -189,11 +188,17 @@ impl Body {
189188
fn poll_inner(&mut self) -> Poll<Option<Chunk>, ::Error> {
190189
match self.kind {
191190
Kind::Once(ref mut val) => Ok(Async::Ready(val.take())),
192-
Kind::Chan { ref mut rx, .. } => match rx.poll().expect("mpsc cannot error") {
193-
Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))),
194-
Async::Ready(Some(Err(err))) => Err(err),
195-
Async::Ready(None) => Ok(Async::Ready(None)),
196-
Async::NotReady => Ok(Async::NotReady),
191+
Kind::Chan { ref mut rx, ref mut abort_rx } => {
192+
if let Ok(Async::Ready(())) = abort_rx.poll() {
193+
return Err(::Error::new_body_write("body write aborted"));
194+
}
195+
196+
match rx.poll().expect("mpsc cannot error") {
197+
Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))),
198+
Async::Ready(Some(Err(err))) => Err(err),
199+
Async::Ready(None) => Ok(Async::Ready(None)),
200+
Async::NotReady => Ok(Async::NotReady),
201+
}
197202
},
198203
Kind::H2(ref mut h2) => {
199204
h2.poll()
@@ -283,7 +288,7 @@ impl fmt::Debug for Body {
283288
impl Sender {
284289
/// Check to see if this `Sender` can send more data.
285290
pub fn poll_ready(&mut self) -> Poll<(), ::Error> {
286-
match self.close_rx.poll() {
291+
match self.abort_tx.poll_cancel() {
287292
Ok(Async::Ready(())) | Err(_) => return Err(::Error::new_closed()),
288293
Ok(Async::NotReady) => (),
289294
}
@@ -303,6 +308,11 @@ impl Sender {
303308
.map_err(|err| err.into_inner().expect("just sent Ok"))
304309
}
305310

311+
/// Aborts the body in an abnormal fashion.
312+
pub fn abort(self) {
313+
let _ = self.abort_tx.send(());
314+
}
315+
306316
pub(crate) fn send_error(&mut self, err: ::Error) {
307317
let _ = self.tx.try_send(Err(err));
308318
}

src/error.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ impl StdError for Error {
265265
Kind::NewService => "calling user's new_service failed",
266266
Kind::Service => "error from user's server service",
267267
Kind::Body => "error reading a body from connection",
268-
Kind::BodyWrite => "error write a body to connection",
268+
Kind::BodyWrite => "error writing a body to connection",
269269
Kind::BodyUser => "error from user's Payload stream",
270270
Kind::Shutdown => "error shutting down connection",
271271
Kind::Http2 => "http2 general error",

tests/client.rs

+48-1
Original file line numberDiff line numberDiff line change
@@ -1373,7 +1373,7 @@ mod conn {
13731373
use tokio::net::TcpStream;
13741374
use tokio_io::{AsyncRead, AsyncWrite};
13751375

1376-
use hyper::{self, Request};
1376+
use hyper::{self, Request, Body, Method};
13771377
use hyper::client::conn;
13781378

13791379
use super::{s, tcp_connect, FutureHyperExt};
@@ -1424,6 +1424,53 @@ mod conn {
14241424
res.join(rx).map(|r| r.0).wait().unwrap();
14251425
}
14261426

1427+
#[test]
1428+
fn aborted_body_isnt_completed() {
1429+
let _ = ::pretty_env_logger::try_init();
1430+
let server = TcpListener::bind("127.0.0.1:0").unwrap();
1431+
let addr = server.local_addr().unwrap();
1432+
let mut runtime = Runtime::new().unwrap();
1433+
1434+
let (tx, rx) = oneshot::channel();
1435+
let server = thread::spawn(move || {
1436+
let mut sock = server.accept().unwrap().0;
1437+
sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
1438+
sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
1439+
let expected = "POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n5\r\nhello\r\n";
1440+
let mut buf = vec![0; expected.len()];
1441+
sock.read_exact(&mut buf).expect("read 1");
1442+
assert_eq!(s(&buf), expected);
1443+
1444+
let _ = tx.send(());
1445+
1446+
assert_eq!(sock.read(&mut buf).expect("read 2"), 0);
1447+
});
1448+
1449+
let tcp = tcp_connect(&addr).wait().unwrap();
1450+
1451+
let (mut client, conn) = conn::handshake(tcp).wait().unwrap();
1452+
1453+
runtime.spawn(conn.map(|_| ()).map_err(|e| panic!("conn error: {}", e)));
1454+
1455+
let (mut sender, body) = Body::channel();
1456+
let sender = thread::spawn(move || {
1457+
sender.send_data("hello".into()).ok().unwrap();
1458+
rx.wait().unwrap();
1459+
sender.abort();
1460+
});
1461+
1462+
let req = Request::builder()
1463+
.method(Method::POST)
1464+
.uri("/")
1465+
.body(body)
1466+
.unwrap();
1467+
let res = client.send_request(req);
1468+
res.wait().unwrap_err();
1469+
1470+
server.join().expect("server thread panicked");
1471+
sender.join().expect("sender thread panicked");
1472+
}
1473+
14271474
#[test]
14281475
fn uri_absolute_form() {
14291476
let server = TcpListener::bind("127.0.0.1:0").unwrap();

0 commit comments

Comments
 (0)