diff --git a/src/dispatcher.rs b/src/dispatcher.rs index b98e32fe..58fb7a65 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -42,6 +42,10 @@ pub(crate) fn register_grpc_callout(token_id: u32) { DISPATCHER.with(|dispatcher| dispatcher.register_grpc_callout(token_id)); } +pub(crate) fn register_grpc_stream(token_id: u32) { + DISPATCHER.with(|dispatcher| dispatcher.register_grpc_stream(token_id)); +} + struct NoopRoot; impl Context for NoopRoot {} @@ -57,6 +61,7 @@ struct Dispatcher { active_id: Cell, callouts: RefCell>, grpc_callouts: RefCell>, + grpc_streams: RefCell>, } impl Dispatcher { @@ -71,6 +76,7 @@ impl Dispatcher { active_id: Cell::new(0), callouts: RefCell::new(HashMap::new()), grpc_callouts: RefCell::new(HashMap::new()), + grpc_streams: RefCell::new(HashMap::new()), } } @@ -97,6 +103,17 @@ impl Dispatcher { } } + fn register_grpc_stream(&self, token_id: u32) { + if self + .grpc_streams + .borrow_mut() + .insert(token_id, self.active_id.get()) + .is_some() + { + panic!("duplicate token_id") + } + } + fn register_grpc_callout(&self, token_id: u32) { if self .grpc_callouts @@ -399,47 +416,116 @@ impl Dispatcher { } } - fn on_grpc_receive(&self, token_id: u32, response_size: usize) { - let context_id = self - .grpc_callouts + fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) { + let context_id = *self + .grpc_streams .borrow_mut() - .remove(&token_id) + .get(&token_id) .expect("invalid token_id"); if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); - http_stream.on_grpc_call_response(token_id, 0, response_size); + http_stream.on_grpc_stream_initial_metadata(token_id, headers); } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); - stream.on_grpc_call_response(token_id, 0, response_size); + stream.on_grpc_stream_initial_metadata(token_id, headers); } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); - root.on_grpc_call_response(token_id, 0, response_size); + root.on_grpc_stream_initial_metadata(token_id, headers); } } - fn on_grpc_close(&self, token_id: u32, status_code: u32) { - let context_id = self - .grpc_callouts + fn on_grpc_receive(&self, token_id: u32, response_size: usize) { + if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) { + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_call_response(token_id, 0, response_size); + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_call_response(token_id, 0, response_size); + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_call_response(token_id, 0, response_size); + } + } else if let Some(context_id) = self.grpc_streams.borrow_mut().get(&token_id) { + let context_id = *context_id; + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_stream_message(token_id, response_size); + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_stream_message(token_id, response_size); + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_stream_message(token_id, response_size); + } + } else { + panic!("invalid token_id") + } + } + + fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) { + let context_id = *self + .grpc_streams .borrow_mut() - .remove(&token_id) + .get(&token_id) .expect("invalid token_id"); if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); - http_stream.on_grpc_call_response(token_id, status_code, 0); + http_stream.on_grpc_stream_trailing_metadata(token_id, trailers); } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); - stream.on_grpc_call_response(token_id, status_code, 0); + stream.on_grpc_stream_trailing_metadata(token_id, trailers); } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); - root.on_grpc_call_response(token_id, status_code, 0); + root.on_grpc_stream_trailing_metadata(token_id, trailers); + } + } + + fn on_grpc_close(&self, token_id: u32, status_code: u32) { + if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) { + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_call_response(token_id, status_code, 0); + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_call_response(token_id, status_code, 0); + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_call_response(token_id, status_code, 0); + } + } else if let Some(context_id) = self.grpc_streams.borrow_mut().remove(&token_id) { + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_stream_close(token_id, status_code) + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_stream_close(token_id, status_code) + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_stream_close(token_id, status_code) + } + } else { + panic!("invalid token_id") } } } @@ -571,11 +657,29 @@ pub extern "C" fn proxy_on_http_call_response( }) } +#[no_mangle] +pub extern "C" fn proxy_on_grpc_receive_initial_metadata( + _context_id: u32, + token_id: u32, + headers: u32, +) { + DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_initial_metadata(token_id, headers)) +} + #[no_mangle] pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) { DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size)) } +#[no_mangle] +pub extern "C" fn proxy_on_grpc_receive_trailing_metadata( + _context_id: u32, + token_id: u32, + trailers: u32, +) { + DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_trailing_metadata(token_id, trailers)) +} + #[no_mangle] pub extern "C" fn proxy_on_grpc_close(_context_id: u32, token_id: u32, status_code: u32) { DISPATCHER.with(|dispatcher| dispatcher.on_grpc_close(token_id, status_code)) diff --git a/src/hostcalls.rs b/src/hostcalls.rs index 6b4e960d..947d0559 100644 --- a/src/hostcalls.rs +++ b/src/hostcalls.rs @@ -177,6 +177,24 @@ pub fn get_map(map_type: MapType) -> Result, Status> { } } +pub fn get_map_bytes(map_type: MapType) -> Result)>, Status> { + unsafe { + let mut return_data: *mut u8 = null_mut(); + let mut return_size: usize = 0; + match proxy_get_header_map_pairs(map_type, &mut return_data, &mut return_size) { + Status::Ok => { + if !return_data.is_null() { + let serialized_map = Vec::from_raw_parts(return_data, return_size, return_size); + Ok(utils::deserialize_bytes_map(&serialized_map)) + } else { + Ok(Vec::new()) + } + } + status => panic!("unexpected status: {}", status as u32), + } + } +} + extern "C" { fn proxy_set_header_map_pairs( map_type: MapType, @@ -677,7 +695,7 @@ pub fn dispatch_grpc_call( timeout: Duration, ) -> Result { let mut return_callout_id = 0; - let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata); + let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata); unsafe { match proxy_grpc_call( upstream_name.as_ptr(), @@ -704,6 +722,80 @@ pub fn dispatch_grpc_call( } } +extern "C" { + fn proxy_grpc_stream( + upstream_data: *const u8, + upstream_size: usize, + service_name_data: *const u8, + service_name_size: usize, + method_name_data: *const u8, + method_name_size: usize, + initial_metadata_data: *const u8, + initial_metadata_size: usize, + return_stream_id: *mut u32, + ) -> Status; +} + +pub fn open_grpc_stream( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, +) -> Result { + let mut return_stream_id = 0; + let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata); + unsafe { + match proxy_grpc_stream( + upstream_name.as_ptr(), + upstream_name.len(), + service_name.as_ptr(), + service_name.len(), + method_name.as_ptr(), + method_name.len(), + serialized_initial_metadata.as_ptr(), + serialized_initial_metadata.len(), + &mut return_stream_id, + ) { + Status::Ok => { + dispatcher::register_grpc_stream(return_stream_id); + Ok(return_stream_id) + } + Status::ParseFailure => Err(Status::ParseFailure), + Status::InternalFailure => Err(Status::InternalFailure), + status => panic!("unexpected status: {}", status as u32), + } + } +} + +extern "C" { + fn proxy_grpc_send( + token: u32, + message_ptr: *const u8, + message_len: usize, + end_stream: bool, + ) -> Status; +} + +pub fn send_grpc_stream_message( + token: u32, + message: Option<&[u8]>, + end_stream: bool, +) -> Result<(), Status> { + unsafe { + match proxy_grpc_send( + token, + message.map_or(null(), |message| message.as_ptr()), + message.map_or(0, |message| message.len()), + end_stream, + ) { + Status::Ok => Ok(()), + Status::BadArgument => Err(Status::BadArgument), + Status::NotFound => Err(Status::NotFound), + status => panic!("unexpected status: {}", status as u32), + } + } +} + extern "C" { fn proxy_grpc_cancel(token_id: u32) -> Status; } @@ -718,6 +810,20 @@ pub fn cancel_grpc_call(token_id: u32) -> Result<(), Status> { } } +extern "C" { + fn proxy_grpc_close(token_id: u32) -> Status; +} + +pub fn close_grpc_stream(token_id: u32) -> Result<(), Status> { + unsafe { + match proxy_grpc_close(token_id) { + Status::Ok => Ok(()), + Status::NotFound => Err(Status::NotFound), + status => panic!("unexpected status: {}", status as u32), + } + } +} + extern "C" { fn proxy_set_effective_context(context_id: u32) -> Status; } @@ -850,7 +956,7 @@ mod utils { bytes } - pub(super) fn serialize_bytes_value_map(map: Vec<(&str, &[u8])>) -> Bytes { + pub(super) fn serialize_bytes_map(map: Vec<(&str, &[u8])>) -> Bytes { let mut size: usize = 4; for (name, value) in &map { size += name.len() + value.len() + 10; @@ -893,4 +999,25 @@ mod utils { } map } + + pub(super) fn deserialize_bytes_map(bytes: &[u8]) -> Vec<(String, Vec)> { + let mut map = Vec::new(); + if bytes.is_empty() { + return map; + } + let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[0..4]).unwrap()) as usize; + let mut p = 4 + size * 8; + for n in 0..size { + let s = 4 + n * 8; + let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s..s + 4]).unwrap()) as usize; + let key = bytes[p..p + size].to_vec(); + p += size + 1; + let size = + u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s + 4..s + 8]).unwrap()) as usize; + let value = bytes[p..p + size].to_vec(); + p += size + 1; + map.push((String::from_utf8(key).unwrap(), value)); + } + map + } } diff --git a/src/traits.rs b/src/traits.rs index 6baed9f1..b6125bd9 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -119,6 +119,49 @@ pub trait Context { hostcalls::cancel_grpc_call(token_id) } + fn open_grpc_stream( + &self, + cluster_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + ) -> Result { + hostcalls::open_grpc_stream(cluster_name, service_name, method_name, initial_metadata) + } + + fn on_grpc_stream_initial_metadata(&mut self, _token_id: u32, _num_elements: u32) {} + + fn get_grpc_stream_initial_metadata(&self) -> Vec<(String, Vec)> { + hostcalls::get_map_bytes(MapType::GrpcReceiveInitialMetadata).unwrap() + } + + fn send_grpc_stream_message( + &self, + token_id: u32, + message: Option<&[u8]>, + end_stream: bool, + ) -> Result<(), Status> { + hostcalls::send_grpc_stream_message(token_id, message, end_stream) + } + + fn on_grpc_stream_message(&mut self, _token_id: u32, _message_size: usize) {} + + fn get_grpc_stream_message(&mut self, start: usize, max_size: usize) -> Option { + hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, start, max_size).unwrap() + } + + fn on_grpc_stream_trailing_metadata(&mut self, _token_id: u32, _num_elements: u32) {} + + fn get_grpc_stream_trailing_metadata(&self) -> Vec<(String, Vec)> { + hostcalls::get_map_bytes(MapType::GrpcReceiveTrailingMetadata).unwrap() + } + + fn close_grpc_stream(&self, token_id: u32) -> Result<(), Status> { + hostcalls::close_grpc_stream(token_id) + } + + fn on_grpc_stream_close(&mut self, _token_id: u32, _status_code: u32) {} + fn on_done(&mut self) -> bool { true } diff --git a/src/types.rs b/src/types.rs index efdce127..b1e6d55a 100644 --- a/src/types.rs +++ b/src/types.rs @@ -73,6 +73,8 @@ pub enum MapType { HttpRequestTrailers = 1, HttpResponseHeaders = 2, HttpResponseTrailers = 3, + GrpcReceiveInitialMetadata = 4, + GrpcReceiveTrailingMetadata = 5, HttpCallResponseHeaders = 6, HttpCallResponseTrailers = 7, }