Skip to content

Commit 0b71694

Browse files
committed
feat(server): add Expect 100-continue support
Adds a new method to `Handler`, with a default implementation of always responding with a `100 Continue` when sent an expectation. Closes #369
1 parent fe8c6d9 commit 0b71694

File tree

3 files changed

+151
-36
lines changed

3 files changed

+151
-36
lines changed

src/header/common/expect.rs

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use std::fmt;
2+
3+
use header::{Header, HeaderFormat};
4+
5+
/// The `Expect` header.
6+
///
7+
/// > The "Expect" header field in a request indicates a certain set of
8+
/// > behaviors (expectations) that need to be supported by the server in
9+
/// > order to properly handle this request. The only such expectation
10+
/// > defined by this specification is 100-continue.
11+
/// >
12+
/// > Expect = "100-continue"
13+
#[derive(Copy, Clone, PartialEq, Debug)]
14+
pub enum Expect {
15+
/// The value `100-continue`.
16+
Continue
17+
}
18+
19+
impl Header for Expect {
20+
fn header_name() -> &'static str {
21+
"Expect"
22+
}
23+
24+
fn parse_header(raw: &[Vec<u8>]) -> Option<Expect> {
25+
if &[b"100-continue"] == raw {
26+
Some(Expect::Continue)
27+
} else {
28+
None
29+
}
30+
}
31+
}
32+
33+
impl HeaderFormat for Expect {
34+
fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result {
35+
f.write_str("100-continue")
36+
}
37+
}

src/header/common/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub use self::content_type::ContentType;
2020
pub use self::cookie::Cookie;
2121
pub use self::date::Date;
2222
pub use self::etag::Etag;
23+
pub use self::expect::Expect;
2324
pub use self::expires::Expires;
2425
pub use self::host::Host;
2526
pub use self::if_match::IfMatch;
@@ -160,6 +161,7 @@ mod content_length;
160161
mod content_type;
161162
mod date;
162163
mod etag;
164+
mod expect;
163165
mod expires;
164166
mod host;
165167
mod if_match;

src/server/mod.rs

+112-36
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! HTTP Server
2-
use std::io::{BufReader, BufWriter};
2+
use std::io::{BufReader, BufWriter, Write};
33
use std::marker::PhantomData;
44
use std::net::{IpAddr, SocketAddr};
55
use std::path::Path;
@@ -14,9 +14,12 @@ pub use net::{Fresh, Streaming};
1414

1515
use HttpError::HttpIoError;
1616
use {HttpResult};
17-
use header::Connection;
17+
use header::{Headers, Connection, Expect};
1818
use header::ConnectionOption::{Close, KeepAlive};
19+
use method::Method;
1920
use net::{NetworkListener, NetworkStream, HttpListener};
21+
use status::StatusCode;
22+
use uri::RequestUri;
2023
use version::HttpVersion::{Http10, Http11};
2124

2225
use self::listener::ListenerPool;
@@ -99,7 +102,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> {
99102

100103
debug!("threads = {:?}", threads);
101104
let pool = ListenerPool::new(listener.clone());
102-
let work = move |stream| keep_alive_loop(stream, &handler);
105+
let work = move |mut stream| handle_connection(&mut stream, &handler);
103106

104107
let guard = thread::scoped(move || pool.accept(work, threads));
105108

@@ -111,7 +114,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> {
111114
}
112115

113116

114-
fn keep_alive_loop<'h, S, H>(mut stream: S, handler: &'h H)
117+
fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H)
115118
where S: NetworkStream + Clone, H: Handler {
116119
debug!("Incoming stream");
117120
let addr = match stream.peer_addr() {
@@ -128,39 +131,45 @@ where S: NetworkStream + Clone, H: Handler {
128131

129132
let mut keep_alive = true;
130133
while keep_alive {
131-
keep_alive = handle_connection(addr, &mut rdr, &mut wrt, handler);
132-
debug!("keep_alive = {:?}", keep_alive);
133-
}
134-
}
134+
let req = match Request::new(&mut rdr, addr) {
135+
Ok(req) => req,
136+
Err(e@HttpIoError(_)) => {
137+
debug!("ioerror in keepalive loop = {:?}", e);
138+
break;
139+
}
140+
Err(e) => {
141+
//TODO: send a 400 response
142+
error!("request error = {:?}", e);
143+
break;
144+
}
145+
};
135146

136-
fn handle_connection<'a, 'aa, 'h, S, H>(
137-
addr: SocketAddr,
138-
rdr: &'a mut BufReader<&'aa mut NetworkStream>,
139-
wrt: &mut BufWriter<S>,
140-
handler: &'h H
141-
) -> bool where 'aa: 'a, S: NetworkStream, H: Handler {
142-
let mut res = Response::new(wrt);
143-
let req = match Request::<'a, 'aa>::new(rdr, addr) {
144-
Ok(req) => req,
145-
Err(e@HttpIoError(_)) => {
146-
debug!("ioerror in keepalive loop = {:?}", e);
147-
return false;
148-
}
149-
Err(e) => {
150-
//TODO: send a 400 response
151-
error!("request error = {:?}", e);
152-
return false;
147+
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
148+
let status = handler.check_continue((&req.method, &req.uri, &req.headers));
149+
match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) {
150+
Ok(..) => (),
151+
Err(e) => {
152+
error!("error writing 100-continue: {:?}", e);
153+
break;
154+
}
155+
}
156+
157+
if status != StatusCode::Continue {
158+
debug!("non-100 status ({}) for Expect 100 request", status);
159+
break;
160+
}
153161
}
154-
};
155162

156-
let keep_alive = match (req.version, req.headers.get::<Connection>()) {
157-
(Http10, Some(conn)) if !conn.contains(&KeepAlive) => false,
158-
(Http11, Some(conn)) if conn.contains(&Close) => false,
159-
_ => true
160-
};
161-
res.version = req.version;
162-
handler.handle(req, res);
163-
keep_alive
163+
keep_alive = match (req.version, req.headers.get::<Connection>()) {
164+
(Http10, Some(conn)) if !conn.contains(&KeepAlive) => false,
165+
(Http11, Some(conn)) if conn.contains(&Close) => false,
166+
_ => true
167+
};
168+
let mut res = Response::new(&mut wrt);
169+
res.version = req.version;
170+
handler.handle(req, res);
171+
debug!("keep_alive = {:?}", keep_alive);
172+
}
164173
}
165174

166175
/// A listening server, which can later be closed.
@@ -184,11 +193,78 @@ pub trait Handler: Sync + Send {
184193
/// Receives a `Request`/`Response` pair, and should perform some action on them.
185194
///
186195
/// This could reading from the request, and writing to the response.
187-
fn handle<'a, 'aa, 'b, 's>(&'s self, Request<'aa, 'a>, Response<'b, Fresh>);
196+
fn handle<'a, 'k>(&'a self, Request<'a, 'k>, Response<'a, Fresh>);
197+
198+
/// Called when a Request includes a `Expect: 100-continue` header.
199+
///
200+
/// By default, this will always immediately response with a `StatusCode::Continue`,
201+
/// but can be overridden with custom behavior.
202+
fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
203+
StatusCode::Continue
204+
}
188205
}
189206

190207
impl<F> Handler for F where F: Fn(Request, Response<Fresh>), F: Sync + Send {
191-
fn handle<'a, 'aa, 'b, 's>(&'s self, req: Request<'a, 'aa>, res: Response<'b, Fresh>) {
208+
fn handle<'a, 'k>(&'a self, req: Request<'a, 'k>, res: Response<'a, Fresh>) {
192209
self(req, res)
193210
}
194211
}
212+
213+
#[cfg(test)]
214+
mod tests {
215+
use header::Headers;
216+
use method::Method;
217+
use mock::MockStream;
218+
use status::StatusCode;
219+
use uri::RequestUri;
220+
221+
use super::{Request, Response, Fresh, Handler, handle_connection};
222+
223+
#[test]
224+
fn test_check_continue_default() {
225+
let mut mock = MockStream::with_input(b"\
226+
POST /upload HTTP/1.1\r\n\
227+
Host: example.domain\r\n\
228+
Expect: 100-continue\r\n\
229+
Content-Length: 10\r\n\
230+
\r\n\
231+
1234567890\
232+
");
233+
234+
fn handle(_: Request, res: Response<Fresh>) {
235+
res.start().unwrap().end().unwrap();
236+
}
237+
238+
handle_connection(&mut mock, &handle);
239+
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
240+
assert_eq!(&mock.write[..cont.len()], cont);
241+
let res = b"HTTP/1.1 200 OK\r\n";
242+
assert_eq!(&mock.write[cont.len()..cont.len() + res.len()], res);
243+
}
244+
245+
#[test]
246+
fn test_check_continue_reject() {
247+
struct Reject;
248+
impl Handler for Reject {
249+
fn handle<'a, 'k>(&'a self, _: Request<'a, 'k>, res: Response<'a, Fresh>) {
250+
res.start().unwrap().end().unwrap();
251+
}
252+
253+
fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
254+
StatusCode::ExpectationFailed
255+
}
256+
}
257+
258+
let mut mock = MockStream::with_input(b"\
259+
POST /upload HTTP/1.1\r\n\
260+
Host: example.domain\r\n\
261+
Expect: 100-continue\r\n\
262+
Content-Length: 10\r\n\
263+
\r\n\
264+
1234567890\
265+
");
266+
267+
handle_connection(&mut mock, &Reject);
268+
assert_eq!(mock.write, b"HTTP/1.1 417 Expectation Failed\r\n\r\n");
269+
}
270+
}

0 commit comments

Comments
 (0)