Skip to content

Add support for gRPC streams. #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 118 additions & 14 deletions src/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ pub(crate) fn register_grpc_callout(token_id: u32) {
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_callout(token_id));
}

pub(crate) fn register_grpc_stream(token_id: u32) {
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_stream(token_id));
}

struct NoopRoot;

impl Context for NoopRoot {}
Expand All @@ -57,6 +61,7 @@ struct Dispatcher {
active_id: Cell<u32>,
callouts: RefCell<HashMap<u32, u32>>,
grpc_callouts: RefCell<HashMap<u32, u32>>,
grpc_streams: RefCell<HashMap<u32, u32>>,
}

impl Dispatcher {
Expand All @@ -71,6 +76,7 @@ impl Dispatcher {
active_id: Cell::new(0),
callouts: RefCell::new(HashMap::new()),
grpc_callouts: RefCell::new(HashMap::new()),
grpc_streams: RefCell::new(HashMap::new()),
}
}

Expand All @@ -97,6 +103,17 @@ impl Dispatcher {
}
}

fn register_grpc_stream(&self, token_id: u32) {
if self
.grpc_streams
.borrow_mut()
.insert(token_id, self.active_id.get())
.is_some()
{
panic!("duplicate token_id")
}
}

fn register_grpc_callout(&self, token_id: u32) {
if self
.grpc_callouts
Expand Down Expand Up @@ -399,47 +416,116 @@ impl Dispatcher {
}
}

fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
let context_id = self
.grpc_callouts
fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) {
let context_id = *self
.grpc_streams
.borrow_mut()
.remove(&token_id)
.get(&token_id)
.expect("invalid token_id");

if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, 0, response_size);
http_stream.on_grpc_stream_initial_metadata(token_id, headers);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, 0, response_size);
stream.on_grpc_stream_initial_metadata(token_id, headers);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, 0, response_size);
root.on_grpc_stream_initial_metadata(token_id, headers);
}
}

fn on_grpc_close(&self, token_id: u32, status_code: u32) {
let context_id = self
.grpc_callouts
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, 0, response_size);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, 0, response_size);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, 0, response_size);
}
} else if let Some(context_id) = self.grpc_streams.borrow_mut().get(&token_id) {
let context_id = *context_id;
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_message(token_id, response_size);
}
} else {
panic!("invalid token_id")
}
}

fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) {
let context_id = *self
.grpc_streams
.borrow_mut()
.remove(&token_id)
.get(&token_id)
.expect("invalid token_id");

if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, status_code, 0);
http_stream.on_grpc_stream_trailing_metadata(token_id, trailers);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, status_code, 0);
stream.on_grpc_stream_trailing_metadata(token_id, trailers);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, status_code, 0);
root.on_grpc_stream_trailing_metadata(token_id, trailers);
}
}

fn on_grpc_close(&self, token_id: u32, status_code: u32) {
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_call_response(token_id, status_code, 0);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_call_response(token_id, status_code, 0);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, status_code, 0);
}
} else if let Some(context_id) = self.grpc_streams.borrow_mut().remove(&token_id) {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_close(token_id, status_code)
}
} else {
panic!("invalid token_id")
}
}
}
Expand Down Expand Up @@ -571,11 +657,29 @@ pub extern "C" fn proxy_on_http_call_response(
})
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_receive_initial_metadata(
_context_id: u32,
token_id: u32,
headers: u32,
) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_initial_metadata(token_id, headers))
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size))
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_receive_trailing_metadata(
_context_id: u32,
token_id: u32,
trailers: u32,
) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_trailing_metadata(token_id, trailers))
}

#[no_mangle]
pub extern "C" fn proxy_on_grpc_close(_context_id: u32, token_id: u32, status_code: u32) {
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_close(token_id, status_code))
Expand Down
131 changes: 129 additions & 2 deletions src/hostcalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,24 @@ pub fn get_map(map_type: MapType) -> Result<Vec<(String, String)>, Status> {
}
}

pub fn get_map_bytes(map_type: MapType) -> Result<Vec<(String, Vec<u8>)>, Status> {
unsafe {
let mut return_data: *mut u8 = null_mut();
let mut return_size: usize = 0;
match proxy_get_header_map_pairs(map_type, &mut return_data, &mut return_size) {
Status::Ok => {
if !return_data.is_null() {
let serialized_map = Vec::from_raw_parts(return_data, return_size, return_size);
Ok(utils::deserialize_bytes_map(&serialized_map))
} else {
Ok(Vec::new())
}
}
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_set_header_map_pairs(
map_type: MapType,
Expand Down Expand Up @@ -677,7 +695,7 @@ pub fn dispatch_grpc_call(
timeout: Duration,
) -> Result<u32, Status> {
let mut return_callout_id = 0;
let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata);
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
unsafe {
match proxy_grpc_call(
upstream_name.as_ptr(),
Expand All @@ -704,6 +722,80 @@ pub fn dispatch_grpc_call(
}
}

extern "C" {
fn proxy_grpc_stream(
upstream_data: *const u8,
upstream_size: usize,
service_name_data: *const u8,
service_name_size: usize,
method_name_data: *const u8,
method_name_size: usize,
initial_metadata_data: *const u8,
initial_metadata_size: usize,
return_stream_id: *mut u32,
) -> Status;
}

pub fn open_grpc_stream(
upstream_name: &str,
service_name: &str,
method_name: &str,
initial_metadata: Vec<(&str, &[u8])>,
) -> Result<u32, Status> {
let mut return_stream_id = 0;
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
unsafe {
match proxy_grpc_stream(
upstream_name.as_ptr(),
upstream_name.len(),
service_name.as_ptr(),
service_name.len(),
method_name.as_ptr(),
method_name.len(),
serialized_initial_metadata.as_ptr(),
serialized_initial_metadata.len(),
&mut return_stream_id,
) {
Status::Ok => {
dispatcher::register_grpc_stream(return_stream_id);
Ok(return_stream_id)
}
Status::ParseFailure => Err(Status::ParseFailure),
Status::InternalFailure => Err(Status::InternalFailure),
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_grpc_send(
token: u32,
message_ptr: *const u8,
message_len: usize,
end_stream: bool,
) -> Status;
}

pub fn send_grpc_stream_message(
token: u32,
message: Option<&[u8]>,
end_stream: bool,
) -> Result<(), Status> {
unsafe {
match proxy_grpc_send(
token,
message.map_or(null(), |message| message.as_ptr()),
message.map_or(0, |message| message.len()),
end_stream,
) {
Status::Ok => Ok(()),
Status::BadArgument => Err(Status::BadArgument),
Status::NotFound => Err(Status::NotFound),
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_grpc_cancel(token_id: u32) -> Status;
}
Expand All @@ -718,6 +810,20 @@ pub fn cancel_grpc_call(token_id: u32) -> Result<(), Status> {
}
}

extern "C" {
fn proxy_grpc_close(token_id: u32) -> Status;
}

pub fn close_grpc_stream(token_id: u32) -> Result<(), Status> {
unsafe {
match proxy_grpc_close(token_id) {
Status::Ok => Ok(()),
Status::NotFound => Err(Status::NotFound),
status => panic!("unexpected status: {}", status as u32),
}
}
}

extern "C" {
fn proxy_set_effective_context(context_id: u32) -> Status;
}
Expand Down Expand Up @@ -850,7 +956,7 @@ mod utils {
bytes
}

pub(super) fn serialize_bytes_value_map(map: Vec<(&str, &[u8])>) -> Bytes {
pub(super) fn serialize_bytes_map(map: Vec<(&str, &[u8])>) -> Bytes {
let mut size: usize = 4;
for (name, value) in &map {
size += name.len() + value.len() + 10;
Expand Down Expand Up @@ -893,4 +999,25 @@ mod utils {
}
map
}

pub(super) fn deserialize_bytes_map(bytes: &[u8]) -> Vec<(String, Vec<u8>)> {
let mut map = Vec::new();
if bytes.is_empty() {
return map;
}
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[0..4]).unwrap()) as usize;
let mut p = 4 + size * 8;
for n in 0..size {
let s = 4 + n * 8;
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s..s + 4]).unwrap()) as usize;
let key = bytes[p..p + size].to_vec();
p += size + 1;
let size =
u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s + 4..s + 8]).unwrap()) as usize;
let value = bytes[p..p + size].to_vec();
p += size + 1;
map.push((String::from_utf8(key).unwrap(), value));
}
map
}
}
Loading