Skip to content

Commit

Permalink
Add support for gRPC streams. (#101)
Browse files Browse the repository at this point in the history
Signed-off-by: Shikugawa <Shikugawa@gmail.com>
  • Loading branch information
Shikugawa authored May 19, 2021
1 parent 30066a7 commit c94f6e4
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 16 deletions.
132 changes: 118 additions & 14 deletions src/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ pub(crate) fn register_grpc_callout(token_id: u32) {
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_callout(token_id));
}

pub(crate) fn register_grpc_stream(token_id: u32) {
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_stream(token_id));
}

struct NoopRoot;

impl Context for NoopRoot {}
Expand All @@ -57,6 +61,7 @@ struct Dispatcher {
active_id: Cell<u32>,
callouts: RefCell<HashMap<u32, u32>>,
grpc_callouts: RefCell<HashMap<u32, u32>>,
grpc_streams: RefCell<HashMap<u32, u32>>,
}

impl Dispatcher {
Expand All @@ -71,6 +76,7 @@ impl Dispatcher {
active_id: Cell::new(0),
callouts: RefCell::new(HashMap::new()),
grpc_callouts: RefCell::new(HashMap::new()),
grpc_streams: RefCell::new(HashMap::new()),
}
}

Expand All @@ -97,6 +103,17 @@ impl Dispatcher {
}
}

fn register_grpc_stream(&self, token_id: u32) {
if self
.grpc_streams
.borrow_mut()
.insert(token_id, self.active_id.get())
.is_some()
{
panic!("duplicate token_id")
}
}

fn register_grpc_callout(&self, token_id: u32) {
if self
.grpc_callouts
Expand Down Expand Up @@ -399,47 +416,116 @@ impl Dispatcher {
}
}

fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
let context_id = self
.grpc_callouts
fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) {
let context_id = *self
.grpc_streams
.borrow_mut()
.remove(&token_id)
.get(&token_id)
.expect("invalid token_id");

if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, 0, response_size);
http_stream.on_grpc_stream_initial_metadata(token_id, headers);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, 0, response_size);
stream.on_grpc_stream_initial_metadata(token_id, headers);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, 0, response_size);
root.on_grpc_stream_initial_metadata(token_id, headers);
}
}

fn on_grpc_close(&self, token_id: u32, status_code: u32) {
let context_id = self
.grpc_callouts
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, 0, response_size);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, 0, response_size);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, 0, response_size);
}
} else if let Some(context_id) = self.grpc_streams.borrow_mut().get(&token_id) {
let context_id = *context_id;
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_message(token_id, response_size);
}
} else {
panic!("invalid token_id")
}
}

fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) {
let context_id = *self
.grpc_streams
.borrow_mut()
.remove(&token_id)
.get(&token_id)
.expect("invalid token_id");

if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, status_code, 0);
http_stream.on_grpc_stream_trailing_metadata(token_id, trailers);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, status_code, 0);
stream.on_grpc_stream_trailing_metadata(token_id, trailers);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, status_code, 0);
root.on_grpc_stream_trailing_metadata(token_id, trailers);
}
}

fn on_grpc_close(&self, token_id: u32, status_code: u32) {
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, status_code, 0);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, status_code, 0);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, status_code, 0);
}
} else if let Some(context_id) = self.grpc_streams.borrow_mut().remove(&token_id) {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_close(token_id, status_code)
}
} else {
panic!("invalid token_id")
}
}
}
Expand Down Expand Up @@ -571,11 +657,29 @@ pub extern "C" fn proxy_on_http_call_response(
})
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_receive_initial_metadata(
_context_id: u32,
token_id: u32,
headers: u32,
) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_initial_metadata(token_id, headers))
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size))
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_receive_trailing_metadata(
_context_id: u32,
token_id: u32,
trailers: u32,
) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_trailing_metadata(token_id, trailers))
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_close(_context_id: u32, token_id: u32, status_code: u32) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_close(token_id, status_code))
Expand Down
131 changes: 129 additions & 2 deletions src/hostcalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,24 @@ pub fn get_map(map_type: MapType) -> Result<Vec<(String, String)>, Status> {
}
}

pub fn get_map_bytes(map_type: MapType) -> Result<Vec<(String, Vec<u8>)>, Status> {
unsafe {
let mut return_data: *mut u8 = null_mut();
let mut return_size: usize = 0;
match proxy_get_header_map_pairs(map_type, &mut return_data, &mut return_size) {
Status::Ok => {
if !return_data.is_null() {
let serialized_map = Vec::from_raw_parts(return_data, return_size, return_size);
Ok(utils::deserialize_bytes_map(&serialized_map))
} else {
Ok(Vec::new())
}
}
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_set_header_map_pairs(
map_type: MapType,
Expand Down Expand Up @@ -677,7 +695,7 @@ pub fn dispatch_grpc_call(
timeout: Duration,
) -> Result<u32, Status> {
let mut return_callout_id = 0;
let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata);
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
unsafe {
match proxy_grpc_call(
upstream_name.as_ptr(),
Expand All @@ -704,6 +722,80 @@ pub fn dispatch_grpc_call(
}
}

extern "C" {
fn proxy_grpc_stream(
upstream_data: *const u8,
upstream_size: usize,
service_name_data: *const u8,
service_name_size: usize,
method_name_data: *const u8,
method_name_size: usize,
initial_metadata_data: *const u8,
initial_metadata_size: usize,
return_stream_id: *mut u32,
) -> Status;
}

pub fn open_grpc_stream(
upstream_name: &str,
service_name: &str,
method_name: &str,
initial_metadata: Vec<(&str, &[u8])>,
) -> Result<u32, Status> {
let mut return_stream_id = 0;
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
unsafe {
match proxy_grpc_stream(
upstream_name.as_ptr(),
upstream_name.len(),
service_name.as_ptr(),
service_name.len(),
method_name.as_ptr(),
method_name.len(),
serialized_initial_metadata.as_ptr(),
serialized_initial_metadata.len(),
&mut return_stream_id,
) {
Status::Ok => {
dispatcher::register_grpc_stream(return_stream_id);
Ok(return_stream_id)
}
Status::ParseFailure => Err(Status::ParseFailure),
Status::InternalFailure => Err(Status::InternalFailure),
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_grpc_send(
token: u32,
message_ptr: *const u8,
message_len: usize,
end_stream: bool,
) -> Status;
}

pub fn send_grpc_stream_message(
token: u32,
message: Option<&[u8]>,
end_stream: bool,
) -> Result<(), Status> {
unsafe {
match proxy_grpc_send(
token,
message.map_or(null(), |message| message.as_ptr()),
message.map_or(0, |message| message.len()),
end_stream,
) {
Status::Ok => Ok(()),
Status::BadArgument => Err(Status::BadArgument),
Status::NotFound => Err(Status::NotFound),
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_grpc_cancel(token_id: u32) -> Status;
}
Expand All @@ -718,6 +810,20 @@ pub fn cancel_grpc_call(token_id: u32) -> Result<(), Status> {
}
}

extern "C" {
fn proxy_grpc_close(token_id: u32) -> Status;
}

pub fn close_grpc_stream(token_id: u32) -> Result<(), Status> {
unsafe {
match proxy_grpc_close(token_id) {
Status::Ok => Ok(()),
Status::NotFound => Err(Status::NotFound),
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_set_effective_context(context_id: u32) -> Status;
}
Expand Down Expand Up @@ -850,7 +956,7 @@ mod utils {
bytes
}

pub(super) fn serialize_bytes_value_map(map: Vec<(&str, &[u8])>) -> Bytes {
pub(super) fn serialize_bytes_map(map: Vec<(&str, &[u8])>) -> Bytes {
let mut size: usize = 4;
for (name, value) in &map {
size += name.len() + value.len() + 10;
Expand Down Expand Up @@ -893,4 +999,25 @@ mod utils {
}
map
}

pub(super) fn deserialize_bytes_map(bytes: &[u8]) -> Vec<(String, Vec<u8>)> {
let mut map = Vec::new();
if bytes.is_empty() {
return map;
}
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[0..4]).unwrap()) as usize;
let mut p = 4 + size * 8;
for n in 0..size {
let s = 4 + n * 8;
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s..s + 4]).unwrap()) as usize;
let key = bytes[p..p + size].to_vec();
p += size + 1;
let size =
u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s + 4..s + 8]).unwrap()) as usize;
let value = bytes[p..p + size].to_vec();
p += size + 1;
map.push((String::from_utf8(key).unwrap(), value));
}
map
}
}
Loading

0 comments on commit c94f6e4

Please sign in to comment.