Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,19 @@ fn main() {

{
let request_body = outgoing_body.write().unwrap();
request_body
let e = request_body
.blocking_write_and_flush("more than 11 bytes".as_bytes())
.expect_err("write should fail");

// TODO: show how to use http-error-code to unwrap this error
let e = match e {
test_programs::wasi::io::streams::StreamError::LastOperationFailed(e) => e,
test_programs::wasi::io::streams::StreamError::Closed => panic!("request closed"),
};

assert!(matches!(
http_types::http_error_code(&e),
Some(http_types::ErrorCode::InternalError(Some(msg)))
if msg == "too much written to output stream"));
}

let e =
Expand Down
16 changes: 8 additions & 8 deletions crates/wasi-http/src/http_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
outgoing_handler,
types::{self, Scheme},
},
http_request_error, internal_error,
types::{HostFutureIncomingResponse, HostOutgoingRequest, OutgoingRequest},
WasiHttpView,
};
Expand Down Expand Up @@ -77,22 +78,21 @@ impl<T: WasiHttpView> outgoing_handler::Host for T {
uri = uri.path_and_query(path);
}

builder = builder.uri(
uri.build()
.map_err(|_| types::ErrorCode::HttpRequestUriInvalid)?,
);
builder = builder.uri(uri.build().map_err(http_request_error)?);

for (k, v) in req.headers.iter() {
builder = builder.header(k, v);
}

let body = req
.body
.unwrap_or_else(|| Empty::<Bytes>::new().map_err(|_| todo!("thing")).boxed());
let body = req.body.unwrap_or_else(|| {
Empty::<Bytes>::new()
.map_err(|_| unreachable!("Infallible error"))
.boxed()
});

let request = builder
.body(body)
.map_err(|err| types::ErrorCode::InternalError(Some(err.to_string())))?;
.map_err(|err| internal_error(err.to_string()))?;

Ok(Ok(self.send_request(OutgoingRequest {
use_tls,
Expand Down
52 changes: 52 additions & 0 deletions crates/wasi-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod bindings {
tracing: true,
async: false,
with: {
"wasi:io/error": wasmtime_wasi::preview2::bindings::io::error,
"wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams,
"wasi:io/poll": wasmtime_wasi::preview2::bindings::io::poll,

Expand Down Expand Up @@ -47,3 +48,54 @@ pub(crate) fn dns_error(rcode: String, info_code: u16) -> bindings::http::types:
pub(crate) fn internal_error(msg: String) -> bindings::http::types::ErrorCode {
bindings::http::types::ErrorCode::InternalError(Some(msg))
}

/// Translate a [`http::Error`] to a wasi-http `ErrorCode` in the context of a request.
pub fn http_request_error(err: http::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;

if err.is::<http::uri::InvalidUri>() {
return ErrorCode::HttpRequestUriInvalid;
}

tracing::warn!("http request error: {err:?}");

ErrorCode::HttpProtocolError
}

/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
pub fn hyper_request_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;
use std::error::Error;

// If there's a source, we might be able to extract a wasi-http error from it.
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
return err.clone();
}
}

tracing::warn!("hyper request error: {err:?}");

ErrorCode::HttpProtocolError
}

/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a response.
pub fn hyper_response_error(err: hyper::Error) -> bindings::http::types::ErrorCode {
use bindings::http::types::ErrorCode;
use std::error::Error;

if err.is_timeout() {
return ErrorCode::HttpResponseTimeout;
}

// If there's a source, we might be able to extract a wasi-http error from it.
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<ErrorCode>() {
return err.clone();
}
}

tracing::warn!("hyper response error: {err:?}");

ErrorCode::HttpProtocolError
}
29 changes: 14 additions & 15 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::{
bindings::http::types::{self, Method, Scheme},
body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
dns_error,
dns_error, hyper_request_error,
};
use http_body_util::BodyExt;
use hyper::header::HeaderName;
Expand Down Expand Up @@ -156,20 +156,22 @@ async fn handler(
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = rustls::ServerName::try_from(host)
.map_err(|_| dns_error("invalid dns name".to_string(), 0))?;
let stream = connector
.connect(domain, tcp_stream)
.await
.map_err(|_| types::ErrorCode::TlsProtocolError)?;
let domain = rustls::ServerName::try_from(host).map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?;
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
tracing::warn!("tls protocol error: {e:?}");
types::ErrorCode::TlsProtocolError
})?;

let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(|_| types::ErrorCode::ConnectionTimeout)?;
.map_err(hyper_request_error)?;

let worker = preview2::spawn(async move {
match conn.await {
Expand All @@ -190,7 +192,7 @@ async fn handler(
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(|_| types::ErrorCode::HttpProtocolError)?;
.map_err(hyper_request_error)?;

let worker = preview2::spawn(async move {
match conn.await {
Expand All @@ -206,11 +208,8 @@ async fn handler(
let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
.map_err(|_| types::ErrorCode::HttpProtocolError)?
.map(|body| {
body.map_err(|_| types::ErrorCode::HttpProtocolError)
.boxed()
});
.map_err(hyper_request_error)?
.map(|body| body.map_err(hyper_request_error).boxed());

Ok(IncomingResponseInternal {
resp,
Expand Down Expand Up @@ -318,7 +317,7 @@ impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
Some(body) => builder.body(body),
None => builder.body(
Empty::<bytes::Bytes>::new()
.map_err(|_| unreachable!())
.map_err(|_| unreachable!("Infallible error"))
.boxed(),
),
}
Expand Down
5 changes: 3 additions & 2 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use wasmtime_wasi::preview2::{
impl<T: WasiHttpView> crate::bindings::http::types::Host for T {
fn http_error_code(
&mut self,
_err: wasmtime::component::Resource<types::IoError>,
err: wasmtime::component::Resource<types::IoError>,
) -> wasmtime::Result<Option<types::ErrorCode>> {
todo!()
let e = self.table().get(&err)?;
Ok(e.downcast_ref::<types::ErrorCode>().cloned())
}
}

Expand Down
15 changes: 4 additions & 11 deletions src/commands/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use wasmtime_wasi::preview2::{
self, StreamError, StreamResult, Table, WasiCtx, WasiCtxBuilder, WasiView,
};
use wasmtime_wasi_http::{
bindings::http::types as http_types, body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView,
body::HyperOutgoingBody, hyper_response_error, WasiHttpCtx, WasiHttpView,
};

#[cfg(feature = "wasi-nn")]
Expand Down Expand Up @@ -365,16 +365,9 @@ impl hyper::service::Service<Request> for ProxyHandler {

let mut store = inner.cmd.new_store(&inner.engine, req_id)?;

let req = store.data_mut().new_incoming_request(req.map(|body| {
body.map_err(|err| {
if err.is_timeout() {
http_types::ErrorCode::HttpResponseTimeout
} else {
http_types::ErrorCode::InternalError(Some(err.message().to_string()))
}
})
.boxed()
}))?;
let req = store
.data_mut()
.new_incoming_request(req.map(|body| body.map_err(hyper_response_error).boxed()))?;

let out = store.data_mut().new_response_outparam(sender)?;

Expand Down