44use  crate :: { 
55    bindings:: http:: types:: { self ,  Method ,  Scheme } , 
66    body:: { HostIncomingBody ,  HyperIncomingBody ,  HyperOutgoingBody } , 
7-     dns_error, 
7+     dns_error,  hyper_request_error , 
88} ; 
99use  http_body_util:: BodyExt ; 
1010use  hyper:: header:: HeaderName ; 
@@ -35,17 +35,18 @@ pub trait WasiHttpView: Send {
3535    fn  new_incoming_request ( 
3636        & mut  self , 
3737        req :  hyper:: Request < HyperIncomingBody > , 
38-     )  -> wasmtime:: Result < Resource < HostIncomingRequest > >  { 
38+     )  -> wasmtime:: Result < Resource < HostIncomingRequest > > 
39+     where 
40+         Self :  Sized , 
41+     { 
3942        let  ( parts,  body)  = req. into_parts ( ) ; 
4043        let  body = HostIncomingBody :: new ( 
4144            body, 
4245            // TODO: this needs to be plumbed through 
4346            std:: time:: Duration :: from_millis ( 600  *  1000 ) , 
4447        ) ; 
45-         Ok ( self . table ( ) . push ( HostIncomingRequest  { 
46-             parts, 
47-             body :  Some ( body) , 
48-         } ) ?) 
48+         let  incoming_req = HostIncomingRequest :: new ( self ,  parts,  Some ( body) ) ; 
49+         Ok ( self . table ( ) . push ( incoming_req) ?) 
4950    } 
5051
5152    fn  new_response_outparam ( 
@@ -73,6 +74,41 @@ pub trait WasiHttpView: Send {
7374    } 
7475} 
7576
77+ /// Returns `true` when the header is forbidden according to this [`WasiHttpView`] implementation. 
78+ pub ( crate )  fn  is_forbidden_header ( view :  & mut  dyn  WasiHttpView ,  name :  & HeaderName )  -> bool  { 
79+     static  FORBIDDEN_HEADERS :  [ HeaderName ;  9 ]  = [ 
80+         hyper:: header:: CONNECTION , 
81+         HeaderName :: from_static ( "keep-alive" ) , 
82+         hyper:: header:: PROXY_AUTHENTICATE , 
83+         hyper:: header:: PROXY_AUTHORIZATION , 
84+         HeaderName :: from_static ( "proxy-connection" ) , 
85+         hyper:: header:: TE , 
86+         hyper:: header:: TRANSFER_ENCODING , 
87+         hyper:: header:: UPGRADE , 
88+         HeaderName :: from_static ( "http2-settings" ) , 
89+     ] ; 
90+ 
91+     FORBIDDEN_HEADERS . contains ( name)  || view. is_forbidden_header ( name) 
92+ } 
93+ 
94+ /// Removes forbidden headers from a [`hyper::HeaderMap`]. 
95+ pub ( crate )  fn  remove_forbidden_headers ( 
96+     view :  & mut  dyn  WasiHttpView , 
97+     headers :  & mut  hyper:: HeaderMap , 
98+ )  { 
99+     let  forbidden_keys = Vec :: from_iter ( headers. keys ( ) . filter_map ( |name| { 
100+         if  is_forbidden_header ( view,  name)  { 
101+             Some ( name. clone ( ) ) 
102+         }  else  { 
103+             None 
104+         } 
105+     } ) ) ; 
106+ 
107+     for  name in  forbidden_keys { 
108+         headers. remove ( name) ; 
109+     } 
110+ } 
111+ 
76112pub  fn  default_send_request ( 
77113    view :  & mut  dyn  WasiHttpView , 
78114    OutgoingRequest  { 
@@ -156,20 +192,22 @@ async fn handler(
156192            let  connector = tokio_rustls:: TlsConnector :: from ( std:: sync:: Arc :: new ( config) ) ; 
157193            let  mut  parts = authority. split ( ":" ) ; 
158194            let  host = parts. next ( ) . unwrap_or ( & authority) ; 
159-             let  domain = rustls:: ServerName :: try_from ( host) 
160-                 . map_err ( |_| dns_error ( "invalid dns name" . to_string ( ) ,  0 ) ) ?; 
161-             let  stream = connector
162-                 . connect ( domain,  tcp_stream) 
163-                 . await 
164-                 . map_err ( |_| types:: ErrorCode :: TlsProtocolError ) ?; 
195+             let  domain = rustls:: ServerName :: try_from ( host) . map_err ( |e| { 
196+                 tracing:: warn!( "dns lookup error: {e:?}" ) ; 
197+                 dns_error ( "invalid dns name" . to_string ( ) ,  0 ) 
198+             } ) ?; 
199+             let  stream = connector. connect ( domain,  tcp_stream) . await . map_err ( |e| { 
200+                 tracing:: warn!( "tls protocol error: {e:?}" ) ; 
201+                 types:: ErrorCode :: TlsProtocolError 
202+             } ) ?; 
165203
166204            let  ( sender,  conn)  = timeout ( 
167205                connect_timeout, 
168206                hyper:: client:: conn:: http1:: handshake ( stream) , 
169207            ) 
170208            . await 
171209            . map_err ( |_| types:: ErrorCode :: ConnectionTimeout ) ?
172-             . map_err ( |_| types :: ErrorCode :: ConnectionTimeout ) ?; 
210+             . map_err ( hyper_request_error ) ?; 
173211
174212            let  worker = preview2:: spawn ( async  move  { 
175213                match  conn. await  { 
@@ -190,7 +228,7 @@ async fn handler(
190228        ) 
191229        . await 
192230        . map_err ( |_| types:: ErrorCode :: ConnectionTimeout ) ?
193-         . map_err ( |_| types :: ErrorCode :: HttpProtocolError ) ?; 
231+         . map_err ( hyper_request_error ) ?; 
194232
195233        let  worker = preview2:: spawn ( async  move  { 
196234            match  conn. await  { 
@@ -206,11 +244,8 @@ async fn handler(
206244    let  resp = timeout ( first_byte_timeout,  sender. send_request ( request) ) 
207245        . await 
208246        . map_err ( |_| types:: ErrorCode :: ConnectionReadTimeout ) ?
209-         . map_err ( |_| types:: ErrorCode :: HttpProtocolError ) ?
210-         . map ( |body| { 
211-             body. map_err ( |_| types:: ErrorCode :: HttpProtocolError ) 
212-                 . boxed ( ) 
213-         } ) ; 
247+         . map_err ( hyper_request_error) ?
248+         . map ( |body| body. map_err ( hyper_request_error) . boxed ( ) ) ; 
214249
215250    Ok ( IncomingResponseInternal  { 
216251        resp, 
@@ -264,10 +299,21 @@ impl TryInto<http::Method> for types::Method {
264299} 
265300
266301pub  struct  HostIncomingRequest  { 
267-     pub  parts :  http:: request:: Parts , 
302+     pub ( crate )  parts :  http:: request:: Parts , 
268303    pub  body :  Option < HostIncomingBody > , 
269304} 
270305
306+ impl  HostIncomingRequest  { 
307+     pub  fn  new ( 
308+         view :  & mut  dyn  WasiHttpView , 
309+         mut  parts :  http:: request:: Parts , 
310+         body :  Option < HostIncomingBody > , 
311+     )  -> Self  { 
312+         remove_forbidden_headers ( view,  & mut  parts. headers ) ; 
313+         Self  {  parts,  body } 
314+     } 
315+ } 
316+ 
271317pub  struct  HostResponseOutparam  { 
272318    pub  result : 
273319        tokio:: sync:: oneshot:: Sender < Result < hyper:: Response < HyperOutgoingBody > ,  types:: ErrorCode > > , 
@@ -318,7 +364,7 @@ impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
318364            Some ( body)  => builder. body ( body) , 
319365            None  => builder. body ( 
320366                Empty :: < bytes:: Bytes > :: new ( ) 
321-                     . map_err ( |_| unreachable ! ( ) ) 
367+                     . map_err ( |_| unreachable ! ( "Infallible error" ) ) 
322368                    . boxed ( ) , 
323369            ) , 
324370        } 
0 commit comments