@@ -42,6 +42,10 @@ pub(crate) fn register_grpc_callout(token_id: u32) {
42
42
DISPATCHER . with ( |dispatcher| dispatcher. register_grpc_callout ( token_id) ) ;
43
43
}
44
44
45
+ pub ( crate ) fn register_grpc_stream ( token_id : u32 ) {
46
+ DISPATCHER . with ( |dispatcher| dispatcher. register_grpc_stream ( token_id) ) ;
47
+ }
48
+
45
49
struct NoopRoot ;
46
50
47
51
impl Context for NoopRoot { }
@@ -57,6 +61,7 @@ struct Dispatcher {
57
61
active_id : Cell < u32 > ,
58
62
callouts : RefCell < HashMap < u32 , u32 > > ,
59
63
grpc_callouts : RefCell < HashMap < u32 , u32 > > ,
64
+ grpc_streams : RefCell < HashMap < u32 , u32 > > ,
60
65
}
61
66
62
67
impl Dispatcher {
@@ -71,6 +76,7 @@ impl Dispatcher {
71
76
active_id : Cell :: new ( 0 ) ,
72
77
callouts : RefCell :: new ( HashMap :: new ( ) ) ,
73
78
grpc_callouts : RefCell :: new ( HashMap :: new ( ) ) ,
79
+ grpc_streams : RefCell :: new ( HashMap :: new ( ) ) ,
74
80
}
75
81
}
76
82
@@ -97,6 +103,17 @@ impl Dispatcher {
97
103
}
98
104
}
99
105
106
+ fn register_grpc_stream ( & self , token_id : u32 ) {
107
+ if self
108
+ . grpc_streams
109
+ . borrow_mut ( )
110
+ . insert ( token_id, self . active_id . get ( ) )
111
+ . is_some ( )
112
+ {
113
+ panic ! ( "duplicate token_id" )
114
+ }
115
+ }
116
+
100
117
fn register_grpc_callout ( & self , token_id : u32 ) {
101
118
if self
102
119
. grpc_callouts
@@ -399,47 +416,116 @@ impl Dispatcher {
399
416
}
400
417
}
401
418
402
- fn on_grpc_receive ( & self , token_id : u32 , response_size : usize ) {
403
- let context_id = self
404
- . grpc_callouts
419
+ fn on_grpc_receive_initial_metadata ( & self , token_id : u32 , headers : u32 ) {
420
+ let context_id = * self
421
+ . grpc_streams
405
422
. borrow_mut ( )
406
- . remove ( & token_id)
423
+ . get ( & token_id)
407
424
. expect ( "invalid token_id" ) ;
408
425
409
426
if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
410
427
self . active_id . set ( context_id) ;
411
428
hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
412
- http_stream. on_grpc_call_response ( token_id, 0 , response_size ) ;
429
+ http_stream. on_grpc_stream_initial_metadata ( token_id, headers ) ;
413
430
} else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
414
431
self . active_id . set ( context_id) ;
415
432
hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
416
- stream. on_grpc_call_response ( token_id, 0 , response_size ) ;
433
+ stream. on_grpc_stream_initial_metadata ( token_id, headers ) ;
417
434
} else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
418
435
self . active_id . set ( context_id) ;
419
436
hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
420
- root. on_grpc_call_response ( token_id, 0 , response_size ) ;
437
+ root. on_grpc_stream_initial_metadata ( token_id, headers ) ;
421
438
}
422
439
}
423
440
424
- fn on_grpc_close ( & self , token_id : u32 , status_code : u32 ) {
425
- let context_id = self
426
- . grpc_callouts
441
+ fn on_grpc_receive ( & self , token_id : u32 , response_size : usize ) {
442
+ if let Some ( context_id) = self . grpc_callouts . borrow_mut ( ) . remove ( & token_id) {
443
+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
444
+ self . active_id . set ( context_id) ;
445
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
446
+ http_stream. on_grpc_call_response ( token_id, 0 , response_size) ;
447
+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
448
+ self . active_id . set ( context_id) ;
449
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
450
+ stream. on_grpc_call_response ( token_id, 0 , response_size) ;
451
+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
452
+ self . active_id . set ( context_id) ;
453
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
454
+ root. on_grpc_call_response ( token_id, 0 , response_size) ;
455
+ }
456
+ } else if let Some ( context_id) = self . grpc_streams . borrow_mut ( ) . get ( & token_id) {
457
+ let context_id = * context_id;
458
+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
459
+ self . active_id . set ( context_id) ;
460
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
461
+ http_stream. on_grpc_stream_message ( token_id, response_size) ;
462
+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
463
+ self . active_id . set ( context_id) ;
464
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
465
+ stream. on_grpc_stream_message ( token_id, response_size) ;
466
+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
467
+ self . active_id . set ( context_id) ;
468
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
469
+ root. on_grpc_stream_message ( token_id, response_size) ;
470
+ }
471
+ } else {
472
+ panic ! ( "invalid token_id" )
473
+ }
474
+ }
475
+
476
+ fn on_grpc_receive_trailing_metadata ( & self , token_id : u32 , trailers : u32 ) {
477
+ let context_id = * self
478
+ . grpc_streams
427
479
. borrow_mut ( )
428
- . remove ( & token_id)
480
+ . get ( & token_id)
429
481
. expect ( "invalid token_id" ) ;
430
482
431
483
if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
432
484
self . active_id . set ( context_id) ;
433
485
hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
434
- http_stream. on_grpc_call_response ( token_id, status_code , 0 ) ;
486
+ http_stream. on_grpc_stream_trailing_metadata ( token_id, trailers ) ;
435
487
} else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
436
488
self . active_id . set ( context_id) ;
437
489
hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
438
- stream. on_grpc_call_response ( token_id, status_code , 0 ) ;
490
+ stream. on_grpc_stream_trailing_metadata ( token_id, trailers ) ;
439
491
} else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
440
492
self . active_id . set ( context_id) ;
441
493
hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
442
- root. on_grpc_call_response ( token_id, status_code, 0 ) ;
494
+ root. on_grpc_stream_trailing_metadata ( token_id, trailers) ;
495
+ }
496
+ }
497
+
498
+ fn on_grpc_close ( & self , token_id : u32 , status_code : u32 ) {
499
+ if let Some ( context_id) = self . grpc_callouts . borrow_mut ( ) . remove ( & token_id) {
500
+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
501
+ self . active_id . set ( context_id) ;
502
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
503
+ http_stream. on_grpc_call_response ( token_id, status_code, 0 ) ;
504
+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
505
+ self . active_id . set ( context_id) ;
506
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
507
+ stream. on_grpc_call_response ( token_id, status_code, 0 ) ;
508
+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
509
+ self . active_id . set ( context_id) ;
510
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
511
+ root. on_grpc_call_response ( token_id, status_code, 0 ) ;
512
+ }
513
+ } else if let Some ( context_id) = self . grpc_streams . borrow_mut ( ) . remove ( & token_id) {
514
+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
515
+ self . active_id . set ( context_id) ;
516
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
517
+ http_stream. on_grpc_stream_close ( token_id, status_code)
518
+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
519
+ self . active_id . set ( context_id) ;
520
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
521
+ stream. on_grpc_stream_close ( token_id, status_code)
522
+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
523
+ self . active_id . set ( context_id) ;
524
+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
525
+ root. on_grpc_stream_close ( token_id, status_code)
526
+ }
527
+ } else {
528
+ panic ! ( "invalid token_id" )
443
529
}
444
530
}
445
531
}
@@ -571,11 +657,29 @@ pub extern "C" fn proxy_on_http_call_response(
571
657
} )
572
658
}
573
659
660
+ #[ no_mangle]
661
+ pub extern "C" fn proxy_on_grpc_receive_initial_metadata (
662
+ _context_id : u32 ,
663
+ token_id : u32 ,
664
+ headers : u32 ,
665
+ ) {
666
+ DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_receive_initial_metadata ( token_id, headers) )
667
+ }
668
+
574
669
#[ no_mangle]
575
670
pub extern "C" fn proxy_on_grpc_receive ( _context_id : u32 , token_id : u32 , response_size : usize ) {
576
671
DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_receive ( token_id, response_size) )
577
672
}
578
673
674
+ #[ no_mangle]
675
+ pub extern "C" fn proxy_on_grpc_receive_trailing_metadata (
676
+ _context_id : u32 ,
677
+ token_id : u32 ,
678
+ trailers : u32 ,
679
+ ) {
680
+ DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_receive_trailing_metadata ( token_id, trailers) )
681
+ }
682
+
579
683
#[ no_mangle]
580
684
pub extern "C" fn proxy_on_grpc_close ( _context_id : u32 , token_id : u32 , status_code : u32 ) {
581
685
DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_close ( token_id, status_code) )
0 commit comments