Skip to content

Commit c94f6e4

Browse files
authored
Add support for gRPC streams. (#101)
Signed-off-by: Shikugawa <Shikugawa@gmail.com>
1 parent 30066a7 commit c94f6e4

File tree

4 files changed

+292
-16
lines changed

4 files changed

+292
-16
lines changed

src/dispatcher.rs

Lines changed: 118 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ pub(crate) fn register_grpc_callout(token_id: u32) {
4242
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_callout(token_id));
4343
}
4444

45+
pub(crate) fn register_grpc_stream(token_id: u32) {
46+
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_stream(token_id));
47+
}
48+
4549
struct NoopRoot;
4650

4751
impl Context for NoopRoot {}
@@ -57,6 +61,7 @@ struct Dispatcher {
5761
active_id: Cell<u32>,
5862
callouts: RefCell<HashMap<u32, u32>>,
5963
grpc_callouts: RefCell<HashMap<u32, u32>>,
64+
grpc_streams: RefCell<HashMap<u32, u32>>,
6065
}
6166

6267
impl Dispatcher {
@@ -71,6 +76,7 @@ impl Dispatcher {
7176
active_id: Cell::new(0),
7277
callouts: RefCell::new(HashMap::new()),
7378
grpc_callouts: RefCell::new(HashMap::new()),
79+
grpc_streams: RefCell::new(HashMap::new()),
7480
}
7581
}
7682

@@ -97,6 +103,17 @@ impl Dispatcher {
97103
}
98104
}
99105

106+
fn register_grpc_stream(&self, token_id: u32) {
107+
if self
108+
.grpc_streams
109+
.borrow_mut()
110+
.insert(token_id, self.active_id.get())
111+
.is_some()
112+
{
113+
panic!("duplicate token_id")
114+
}
115+
}
116+
100117
fn register_grpc_callout(&self, token_id: u32) {
101118
if self
102119
.grpc_callouts
@@ -399,47 +416,116 @@ impl Dispatcher {
399416
}
400417
}
401418

402-
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
403-
let context_id = self
404-
.grpc_callouts
419+
fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) {
420+
let context_id = *self
421+
.grpc_streams
405422
.borrow_mut()
406-
.remove(&token_id)
423+
.get(&token_id)
407424
.expect("invalid token_id");
408425

409426
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
410427
self.active_id.set(context_id);
411428
hostcalls::set_effective_context(context_id).unwrap();
412-
http_stream.on_grpc_call_response(token_id, 0, response_size);
429+
http_stream.on_grpc_stream_initial_metadata(token_id, headers);
413430
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
414431
self.active_id.set(context_id);
415432
hostcalls::set_effective_context(context_id).unwrap();
416-
stream.on_grpc_call_response(token_id, 0, response_size);
433+
stream.on_grpc_stream_initial_metadata(token_id, headers);
417434
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
418435
self.active_id.set(context_id);
419436
hostcalls::set_effective_context(context_id).unwrap();
420-
root.on_grpc_call_response(token_id, 0, response_size);
437+
root.on_grpc_stream_initial_metadata(token_id, headers);
421438
}
422439
}
423440

424-
fn on_grpc_close(&self, token_id: u32, status_code: u32) {
425-
let context_id = self
426-
.grpc_callouts
441+
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
442+
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
443+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
444+
self.active_id.set(context_id);
445+
hostcalls::set_effective_context(context_id).unwrap();
446+
http_stream.on_grpc_call_response(token_id, 0, response_size);
447+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
448+
self.active_id.set(context_id);
449+
hostcalls::set_effective_context(context_id).unwrap();
450+
stream.on_grpc_call_response(token_id, 0, response_size);
451+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
452+
self.active_id.set(context_id);
453+
hostcalls::set_effective_context(context_id).unwrap();
454+
root.on_grpc_call_response(token_id, 0, response_size);
455+
}
456+
} else if let Some(context_id) = self.grpc_streams.borrow_mut().get(&token_id) {
457+
let context_id = *context_id;
458+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
459+
self.active_id.set(context_id);
460+
hostcalls::set_effective_context(context_id).unwrap();
461+
http_stream.on_grpc_stream_message(token_id, response_size);
462+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
463+
self.active_id.set(context_id);
464+
hostcalls::set_effective_context(context_id).unwrap();
465+
stream.on_grpc_stream_message(token_id, response_size);
466+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
467+
self.active_id.set(context_id);
468+
hostcalls::set_effective_context(context_id).unwrap();
469+
root.on_grpc_stream_message(token_id, response_size);
470+
}
471+
} else {
472+
panic!("invalid token_id")
473+
}
474+
}
475+
476+
fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) {
477+
let context_id = *self
478+
.grpc_streams
427479
.borrow_mut()
428-
.remove(&token_id)
480+
.get(&token_id)
429481
.expect("invalid token_id");
430482

431483
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
432484
self.active_id.set(context_id);
433485
hostcalls::set_effective_context(context_id).unwrap();
434-
http_stream.on_grpc_call_response(token_id, status_code, 0);
486+
http_stream.on_grpc_stream_trailing_metadata(token_id, trailers);
435487
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
436488
self.active_id.set(context_id);
437489
hostcalls::set_effective_context(context_id).unwrap();
438-
stream.on_grpc_call_response(token_id, status_code, 0);
490+
stream.on_grpc_stream_trailing_metadata(token_id, trailers);
439491
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
440492
self.active_id.set(context_id);
441493
hostcalls::set_effective_context(context_id).unwrap();
442-
root.on_grpc_call_response(token_id, status_code, 0);
494+
root.on_grpc_stream_trailing_metadata(token_id, trailers);
495+
}
496+
}
497+
498+
fn on_grpc_close(&self, token_id: u32, status_code: u32) {
499+
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
500+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
501+
self.active_id.set(context_id);
502+
hostcalls::set_effective_context(context_id).unwrap();
503+
http_stream.on_grpc_call_response(token_id, status_code, 0);
504+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
505+
self.active_id.set(context_id);
506+
hostcalls::set_effective_context(context_id).unwrap();
507+
stream.on_grpc_call_response(token_id, status_code, 0);
508+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
509+
self.active_id.set(context_id);
510+
hostcalls::set_effective_context(context_id).unwrap();
511+
root.on_grpc_call_response(token_id, status_code, 0);
512+
}
513+
} else if let Some(context_id) = self.grpc_streams.borrow_mut().remove(&token_id) {
514+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
515+
self.active_id.set(context_id);
516+
hostcalls::set_effective_context(context_id).unwrap();
517+
http_stream.on_grpc_stream_close(token_id, status_code)
518+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
519+
self.active_id.set(context_id);
520+
hostcalls::set_effective_context(context_id).unwrap();
521+
stream.on_grpc_stream_close(token_id, status_code)
522+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
523+
self.active_id.set(context_id);
524+
hostcalls::set_effective_context(context_id).unwrap();
525+
root.on_grpc_stream_close(token_id, status_code)
526+
}
527+
} else {
528+
panic!("invalid token_id")
443529
}
444530
}
445531
}
@@ -571,11 +657,29 @@ pub extern "C" fn proxy_on_http_call_response(
571657
})
572658
}
573659

660+
#[no_mangle]
661+
pub extern "C" fn proxy_on_grpc_receive_initial_metadata(
662+
_context_id: u32,
663+
token_id: u32,
664+
headers: u32,
665+
) {
666+
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_initial_metadata(token_id, headers))
667+
}
668+
574669
#[no_mangle]
575670
pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) {
576671
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size))
577672
}
578673

674+
#[no_mangle]
675+
pub extern "C" fn proxy_on_grpc_receive_trailing_metadata(
676+
_context_id: u32,
677+
token_id: u32,
678+
trailers: u32,
679+
) {
680+
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_trailing_metadata(token_id, trailers))
681+
}
682+
579683
#[no_mangle]
580684
pub extern "C" fn proxy_on_grpc_close(_context_id: u32, token_id: u32, status_code: u32) {
581685
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_close(token_id, status_code))

src/hostcalls.rs

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,24 @@ pub fn get_map(map_type: MapType) -> Result<Vec<(String, String)>, Status> {
177177
}
178178
}
179179

180+
pub fn get_map_bytes(map_type: MapType) -> Result<Vec<(String, Vec<u8>)>, Status> {
181+
unsafe {
182+
let mut return_data: *mut u8 = null_mut();
183+
let mut return_size: usize = 0;
184+
match proxy_get_header_map_pairs(map_type, &mut return_data, &mut return_size) {
185+
Status::Ok => {
186+
if !return_data.is_null() {
187+
let serialized_map = Vec::from_raw_parts(return_data, return_size, return_size);
188+
Ok(utils::deserialize_bytes_map(&serialized_map))
189+
} else {
190+
Ok(Vec::new())
191+
}
192+
}
193+
status => panic!("unexpected status: {}", status as u32),
194+
}
195+
}
196+
}
197+
180198
extern "C" {
181199
fn proxy_set_header_map_pairs(
182200
map_type: MapType,
@@ -677,7 +695,7 @@ pub fn dispatch_grpc_call(
677695
timeout: Duration,
678696
) -> Result<u32, Status> {
679697
let mut return_callout_id = 0;
680-
let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata);
698+
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
681699
unsafe {
682700
match proxy_grpc_call(
683701
upstream_name.as_ptr(),
@@ -704,6 +722,80 @@ pub fn dispatch_grpc_call(
704722
}
705723
}
706724

725+
extern "C" {
726+
fn proxy_grpc_stream(
727+
upstream_data: *const u8,
728+
upstream_size: usize,
729+
service_name_data: *const u8,
730+
service_name_size: usize,
731+
method_name_data: *const u8,
732+
method_name_size: usize,
733+
initial_metadata_data: *const u8,
734+
initial_metadata_size: usize,
735+
return_stream_id: *mut u32,
736+
) -> Status;
737+
}
738+
739+
pub fn open_grpc_stream(
740+
upstream_name: &str,
741+
service_name: &str,
742+
method_name: &str,
743+
initial_metadata: Vec<(&str, &[u8])>,
744+
) -> Result<u32, Status> {
745+
let mut return_stream_id = 0;
746+
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
747+
unsafe {
748+
match proxy_grpc_stream(
749+
upstream_name.as_ptr(),
750+
upstream_name.len(),
751+
service_name.as_ptr(),
752+
service_name.len(),
753+
method_name.as_ptr(),
754+
method_name.len(),
755+
serialized_initial_metadata.as_ptr(),
756+
serialized_initial_metadata.len(),
757+
&mut return_stream_id,
758+
) {
759+
Status::Ok => {
760+
dispatcher::register_grpc_stream(return_stream_id);
761+
Ok(return_stream_id)
762+
}
763+
Status::ParseFailure => Err(Status::ParseFailure),
764+
Status::InternalFailure => Err(Status::InternalFailure),
765+
status => panic!("unexpected status: {}", status as u32),
766+
}
767+
}
768+
}
769+
770+
extern "C" {
771+
fn proxy_grpc_send(
772+
token: u32,
773+
message_ptr: *const u8,
774+
message_len: usize,
775+
end_stream: bool,
776+
) -> Status;
777+
}
778+
779+
pub fn send_grpc_stream_message(
780+
token: u32,
781+
message: Option<&[u8]>,
782+
end_stream: bool,
783+
) -> Result<(), Status> {
784+
unsafe {
785+
match proxy_grpc_send(
786+
token,
787+
message.map_or(null(), |message| message.as_ptr()),
788+
message.map_or(0, |message| message.len()),
789+
end_stream,
790+
) {
791+
Status::Ok => Ok(()),
792+
Status::BadArgument => Err(Status::BadArgument),
793+
Status::NotFound => Err(Status::NotFound),
794+
status => panic!("unexpected status: {}", status as u32),
795+
}
796+
}
797+
}
798+
707799
extern "C" {
708800
fn proxy_grpc_cancel(token_id: u32) -> Status;
709801
}
@@ -718,6 +810,20 @@ pub fn cancel_grpc_call(token_id: u32) -> Result<(), Status> {
718810
}
719811
}
720812

813+
extern "C" {
814+
fn proxy_grpc_close(token_id: u32) -> Status;
815+
}
816+
817+
pub fn close_grpc_stream(token_id: u32) -> Result<(), Status> {
818+
unsafe {
819+
match proxy_grpc_close(token_id) {
820+
Status::Ok => Ok(()),
821+
Status::NotFound => Err(Status::NotFound),
822+
status => panic!("unexpected status: {}", status as u32),
823+
}
824+
}
825+
}
826+
721827
extern "C" {
722828
fn proxy_set_effective_context(context_id: u32) -> Status;
723829
}
@@ -850,7 +956,7 @@ mod utils {
850956
bytes
851957
}
852958

853-
pub(super) fn serialize_bytes_value_map(map: Vec<(&str, &[u8])>) -> Bytes {
959+
pub(super) fn serialize_bytes_map(map: Vec<(&str, &[u8])>) -> Bytes {
854960
let mut size: usize = 4;
855961
for (name, value) in &map {
856962
size += name.len() + value.len() + 10;
@@ -893,4 +999,25 @@ mod utils {
893999
}
8941000
map
8951001
}
1002+
1003+
pub(super) fn deserialize_bytes_map(bytes: &[u8]) -> Vec<(String, Vec<u8>)> {
1004+
let mut map = Vec::new();
1005+
if bytes.is_empty() {
1006+
return map;
1007+
}
1008+
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[0..4]).unwrap()) as usize;
1009+
let mut p = 4 + size * 8;
1010+
for n in 0..size {
1011+
let s = 4 + n * 8;
1012+
let size = u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s..s + 4]).unwrap()) as usize;
1013+
let key = bytes[p..p + size].to_vec();
1014+
p += size + 1;
1015+
let size =
1016+
u32::from_le_bytes(<[u8; 4]>::try_from(&bytes[s + 4..s + 8]).unwrap()) as usize;
1017+
let value = bytes[p..p + size].to_vec();
1018+
p += size + 1;
1019+
map.push((String::from_utf8(key).unwrap(), value));
1020+
}
1021+
map
1022+
}
8961023
}

0 commit comments

Comments
 (0)