Skip to content

Commit de1f24e

Browse files
authored
Backport additional wasi-http changes to the 15.x release branch (#7540)
* wasi-http: Implement http-error-code, and centralize error conversions (#7534) * Filter out forbidden headers on incoming request and response resources (#7538)
1 parent bb8eec8 commit de1f24e

File tree

7 files changed

+156
-69
lines changed

7 files changed

+156
-69
lines changed

crates/test-programs/src/bin/api_proxy.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ impl bindings::exports::wasi::http::incoming_handler::Guest for T {
2020
let req_hdrs = request.headers();
2121

2222
assert!(
23-
!req_hdrs.get(&header).is_empty(),
24-
"missing `custom-forbidden-header` from request"
23+
req_hdrs.get(&header).is_empty(),
24+
"forbidden `custom-forbidden-header` found in request"
2525
);
2626

2727
assert!(req_hdrs.delete(&header).is_err());
28+
assert!(req_hdrs.append(&header, &b"no".to_vec()).is_err());
2829

2930
assert!(
30-
!req_hdrs.get(&header).is_empty(),
31-
"delete of forbidden header succeeded"
31+
req_hdrs.get(&header).is_empty(),
32+
"append of forbidden header succeeded"
3233
);
3334

3435
let hdrs = bindings::wasi::http::types::Headers::new();

crates/test-programs/src/bin/http_outbound_request_content_length.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,19 @@ fn main() {
7070

7171
{
7272
let request_body = outgoing_body.write().unwrap();
73-
request_body
73+
let e = request_body
7474
.blocking_write_and_flush("more than 11 bytes".as_bytes())
7575
.expect_err("write should fail");
7676

77-
// TODO: show how to use http-error-code to unwrap this error
77+
let e = match e {
78+
test_programs::wasi::io::streams::StreamError::LastOperationFailed(e) => e,
79+
test_programs::wasi::io::streams::StreamError::Closed => panic!("request closed"),
80+
};
81+
82+
assert!(matches!(
83+
http_types::http_error_code(&e),
84+
Some(http_types::ErrorCode::InternalError(Some(msg)))
85+
if msg == "too much written to output stream"));
7886
}
7987

8088
let e =

crates/wasi-http/src/http_impl.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::{
33
outgoing_handler,
44
types::{self, Scheme},
55
},
6+
http_request_error, internal_error,
67
types::{HostFutureIncomingResponse, HostOutgoingRequest, OutgoingRequest},
78
WasiHttpView,
89
};
@@ -77,22 +78,21 @@ impl<T: WasiHttpView> outgoing_handler::Host for T {
7778
uri = uri.path_and_query(path);
7879
}
7980

80-
builder = builder.uri(
81-
uri.build()
82-
.map_err(|_| types::ErrorCode::HttpRequestUriInvalid)?,
83-
);
81+
builder = builder.uri(uri.build().map_err(http_request_error)?);
8482

8583
for (k, v) in req.headers.iter() {
8684
builder = builder.header(k, v);
8785
}
8886

89-
let body = req
90-
.body
91-
.unwrap_or_else(|| Empty::<Bytes>::new().map_err(|_| todo!("thing")).boxed());
87+
let body = req.body.unwrap_or_else(|| {
88+
Empty::<Bytes>::new()
89+
.map_err(|_| unreachable!("Infallible error"))
90+
.boxed()
91+
});
9292

9393
let request = builder
9494
.body(body)
95-
.map_err(|err| types::ErrorCode::InternalError(Some(err.to_string())))?;
95+
.map_err(|err| internal_error(err.to_string()))?;
9696

9797
Ok(Ok(self.send_request(OutgoingRequest {
9898
use_tls,

crates/wasi-http/src/lib.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub mod bindings {
1717
tracing: true,
1818
async: false,
1919
with: {
20+
"wasi:io/error": wasmtime_wasi::preview2::bindings::io::error,
2021
"wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams,
2122
"wasi:io/poll": wasmtime_wasi::preview2::bindings::io::poll,
2223

@@ -47,3 +48,54 @@ pub(crate) fn dns_error(rcode: String, info_code: u16) -> bindings::http::types:
4748
pub(crate) fn internal_error(msg: String) -> bindings::http::types::ErrorCode {
4849
bindings::http::types::ErrorCode::InternalError(Some(msg))
4950
}
51+
52+
/// Translate a [`http::Error`] to a wasi-http `ErrorCode` in the context of a request.
53+
pub fn http_request_error(err: http::Error) -> bindings::http::types::ErrorCode {
54+
use bindings::http::types::ErrorCode;
55+
56+
if err.is::<http::uri::InvalidUri>() {
57+
return ErrorCode::HttpRequestUriInvalid;
58+
}
59+
60+
tracing::warn!("http request error: {err:?}");
61+
62+
ErrorCode::HttpProtocolError
63+
}
64+
65+
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
66+
pub fn hyper_request_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
67+
use bindings::http::types::ErrorCode;
68+
use std::error::Error;
69+
70+
// If there's a source, we might be able to extract a wasi-http error from it.
71+
if let Some(cause) = err.source() {
72+
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
73+
return err.clone();
74+
}
75+
}
76+
77+
tracing::warn!("hyper request error: {err:?}");
78+
79+
ErrorCode::HttpProtocolError
80+
}
81+
82+
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a response.
83+
pub fn hyper_response_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
84+
use bindings::http::types::ErrorCode;
85+
use std::error::Error;
86+
87+
if err.is_timeout() {
88+
return ErrorCode::HttpResponseTimeout;
89+
}
90+
91+
// If there's a source, we might be able to extract a wasi-http error from it.
92+
if let Some(cause) = err.source() {
93+
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
94+
return err.clone();
95+
}
96+
}
97+
98+
tracing::warn!("hyper response error: {err:?}");
99+
100+
ErrorCode::HttpProtocolError
101+
}

crates/wasi-http/src/types.rs

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use crate::{
55
bindings::http::types::{self, Method, Scheme},
66
body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
7-
dns_error,
7+
dns_error, hyper_request_error,
88
};
99
use http_body_util::BodyExt;
1010
use hyper::header::HeaderName;
@@ -35,17 +35,18 @@ pub trait WasiHttpView: Send {
3535
fn new_incoming_request(
3636
&mut self,
3737
req: hyper::Request<HyperIncomingBody>,
38-
) -> wasmtime::Result<Resource<HostIncomingRequest>> {
38+
) -> wasmtime::Result<Resource<HostIncomingRequest>>
39+
where
40+
Self: Sized,
41+
{
3942
let (parts, body) = req.into_parts();
4043
let body = HostIncomingBody::new(
4144
body,
4245
// TODO: this needs to be plumbed through
4346
std::time::Duration::from_millis(600 * 1000),
4447
);
45-
Ok(self.table().push(HostIncomingRequest {
46-
parts,
47-
body: Some(body),
48-
})?)
48+
let incoming_req = HostIncomingRequest::new(self, parts, Some(body));
49+
Ok(self.table().push(incoming_req)?)
4950
}
5051

5152
fn new_response_outparam(
@@ -73,6 +74,41 @@ pub trait WasiHttpView: Send {
7374
}
7475
}
7576

77+
/// Returns `true` when the header is forbidden according to this [`WasiHttpView`] implementation.
78+
pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
79+
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
80+
hyper::header::CONNECTION,
81+
HeaderName::from_static("keep-alive"),
82+
hyper::header::PROXY_AUTHENTICATE,
83+
hyper::header::PROXY_AUTHORIZATION,
84+
HeaderName::from_static("proxy-connection"),
85+
hyper::header::TE,
86+
hyper::header::TRANSFER_ENCODING,
87+
hyper::header::UPGRADE,
88+
HeaderName::from_static("http2-settings"),
89+
];
90+
91+
FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
92+
}
93+
94+
/// Removes forbidden headers from a [`hyper::HeaderMap`].
95+
pub(crate) fn remove_forbidden_headers(
96+
view: &mut dyn WasiHttpView,
97+
headers: &mut hyper::HeaderMap,
98+
) {
99+
let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
100+
if is_forbidden_header(view, name) {
101+
Some(name.clone())
102+
} else {
103+
None
104+
}
105+
}));
106+
107+
for name in forbidden_keys {
108+
headers.remove(name);
109+
}
110+
}
111+
76112
pub fn default_send_request(
77113
view: &mut dyn WasiHttpView,
78114
OutgoingRequest {
@@ -156,20 +192,22 @@ async fn handler(
156192
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
157193
let mut parts = authority.split(":");
158194
let host = parts.next().unwrap_or(&authority);
159-
let domain = rustls::ServerName::try_from(host)
160-
.map_err(|_| dns_error("invalid dns name".to_string(), 0))?;
161-
let stream = connector
162-
.connect(domain, tcp_stream)
163-
.await
164-
.map_err(|_| types::ErrorCode::TlsProtocolError)?;
195+
let domain = rustls::ServerName::try_from(host).map_err(|e| {
196+
tracing::warn!("dns lookup error: {e:?}");
197+
dns_error("invalid dns name".to_string(), 0)
198+
})?;
199+
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
200+
tracing::warn!("tls protocol error: {e:?}");
201+
types::ErrorCode::TlsProtocolError
202+
})?;
165203

166204
let (sender, conn) = timeout(
167205
connect_timeout,
168206
hyper::client::conn::http1::handshake(stream),
169207
)
170208
.await
171209
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
172-
.map_err(|_| types::ErrorCode::ConnectionTimeout)?;
210+
.map_err(hyper_request_error)?;
173211

174212
let worker = preview2::spawn(async move {
175213
match conn.await {
@@ -190,7 +228,7 @@ async fn handler(
190228
)
191229
.await
192230
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
193-
.map_err(|_| types::ErrorCode::HttpProtocolError)?;
231+
.map_err(hyper_request_error)?;
194232

195233
let worker = preview2::spawn(async move {
196234
match conn.await {
@@ -206,11 +244,8 @@ async fn handler(
206244
let resp = timeout(first_byte_timeout, sender.send_request(request))
207245
.await
208246
.map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
209-
.map_err(|_| types::ErrorCode::HttpProtocolError)?
210-
.map(|body| {
211-
body.map_err(|_| types::ErrorCode::HttpProtocolError)
212-
.boxed()
213-
});
247+
.map_err(hyper_request_error)?
248+
.map(|body| body.map_err(hyper_request_error).boxed());
214249

215250
Ok(IncomingResponseInternal {
216251
resp,
@@ -264,10 +299,21 @@ impl TryInto<http::Method> for types::Method {
264299
}
265300

266301
pub struct HostIncomingRequest {
267-
pub parts: http::request::Parts,
302+
pub(crate) parts: http::request::Parts,
268303
pub body: Option<HostIncomingBody>,
269304
}
270305

306+
impl HostIncomingRequest {
307+
pub fn new(
308+
view: &mut dyn WasiHttpView,
309+
mut parts: http::request::Parts,
310+
body: Option<HostIncomingBody>,
311+
) -> Self {
312+
remove_forbidden_headers(view, &mut parts.headers);
313+
Self { parts, body }
314+
}
315+
}
316+
271317
pub struct HostResponseOutparam {
272318
pub result:
273319
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
@@ -318,7 +364,7 @@ impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
318364
Some(body) => builder.body(body),
319365
None => builder.body(
320366
Empty::<bytes::Bytes>::new()
321-
.map_err(|_| unreachable!())
367+
.map_err(|_| unreachable!("Infallible error"))
322368
.boxed(),
323369
),
324370
}

crates/wasi-http/src/types_impl.rs

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use crate::{
22
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
33
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody},
44
types::{
5-
FieldMap, HostFields, HostFutureIncomingResponse, HostIncomingRequest,
6-
HostIncomingResponse, HostOutgoingRequest, HostOutgoingResponse, HostResponseOutparam,
5+
is_forbidden_header, remove_forbidden_headers, FieldMap, HostFields,
6+
HostFutureIncomingResponse, HostIncomingRequest, HostIncomingResponse, HostOutgoingRequest,
7+
HostOutgoingResponse, HostResponseOutparam,
78
},
89
WasiHttpView,
910
};
1011
use anyhow::Context;
11-
use hyper::header::HeaderName;
1212
use std::any::Any;
1313
use std::str::FromStr;
1414
use wasmtime::component::Resource;
@@ -20,9 +20,10 @@ use wasmtime_wasi::preview2::{
2020
impl<T: WasiHttpView> crate::bindings::http::types::Host for T {
2121
fn http_error_code(
2222
&mut self,
23-
_err: wasmtime::component::Resource<types::IoError>,
23+
err: wasmtime::component::Resource<types::IoError>,
2424
) -> wasmtime::Result<Option<types::ErrorCode>> {
25-
todo!()
25+
let e = self.table().get(&err)?;
26+
Ok(e.downcast_ref::<types::ErrorCode>().cloned())
2627
}
2728
}
2829

@@ -88,22 +89,6 @@ fn get_fields_mut<'a>(
8889
}
8990
}
9091

91-
fn is_forbidden_header<T: WasiHttpView>(view: &mut T, name: &HeaderName) -> bool {
92-
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
93-
hyper::header::CONNECTION,
94-
HeaderName::from_static("keep-alive"),
95-
hyper::header::PROXY_AUTHENTICATE,
96-
hyper::header::PROXY_AUTHORIZATION,
97-
HeaderName::from_static("proxy-connection"),
98-
hyper::header::TE,
99-
hyper::header::TRANSFER_ENCODING,
100-
hyper::header::UPGRADE,
101-
HeaderName::from_static("http2-settings"),
102-
];
103-
104-
FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
105-
}
106-
10792
impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
10893
fn new(&mut self) -> wasmtime::Result<Resource<HostFields>> {
10994
let id = self
@@ -833,11 +818,13 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
833818
Ok(Err(e)) => return Ok(Some(Ok(Err(e)))),
834819
};
835820

836-
let (parts, body) = resp.resp.into_parts();
821+
let (mut parts, body) = resp.resp.into_parts();
822+
823+
remove_forbidden_headers(self, &mut parts.headers);
837824

838825
let resp = self.table().push(HostIncomingResponse {
839826
status: parts.status.as_u16(),
840-
headers: FieldMap::from(parts.headers),
827+
headers: parts.headers,
841828
body: Some({
842829
let mut body = HostIncomingBody::new(body, resp.between_bytes_timeout);
843830
body.retain_worker(&resp.worker);

0 commit comments

Comments
 (0)