66
77pub mod lighthouse;
88pub mod manager;
9+ pub mod router;
910mod net;
1011mod retry;
11- mod router;
1212mod timeout;
13+ mod interceptor;
1314
1415pub use crate :: router:: Router ;
1516
@@ -25,8 +26,9 @@ use structopt::StructOpt;
2526use tokio:: runtime:: Runtime ;
2627use tokio:: task:: JoinHandle ;
2728use tokio_stream:: wrappers:: TcpListenerStream ;
28- use tonic:: transport:: Channel ;
29+ use tonic:: transport:: { Channel , Endpoint } ;
2930use tonic:: Status ;
31+ use tonic:: service:: interceptor:: InterceptedService ;
3032
3133use chrono:: Local ;
3234use fern:: colors:: { Color , ColoredLevelConfig } ;
@@ -36,6 +38,7 @@ pub mod torchftpb {
3638 tonic:: include_proto!( "torchft" ) ;
3739}
3840
41+ use crate :: interceptor:: RoomIdInterceptor ;
3942use crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ;
4043use crate :: torchftpb:: lighthouse_service_server:: LighthouseServiceServer ;
4144use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
@@ -486,9 +489,10 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
486489/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
487490#[ pyclass]
488491struct LighthouseClient {
489- client : LighthouseServiceClient < Channel > ,
492+ client : LighthouseServiceClient <
493+ InterceptedService < Channel , RoomIdInterceptor >
494+ > ,
490495 runtime : Runtime ,
491- room_id : Option < String > ,
492496}
493497
494498#[ pymethods]
@@ -507,14 +511,25 @@ impl LighthouseClient {
507511 . thread_name ( "torchft-lhclnt" )
508512 . enable_all ( )
509513 . build ( ) ?;
510- let client = runtime
511- . block_on ( manager :: lighthouse_client_new ( addr, connect_timeout ) )
514+
515+ let endpoint = Endpoint :: from_shared ( addr. clone ( ) )
512516 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
513- Ok ( Self {
514- client : client,
515- runtime : runtime,
516- room_id : room_id,
517- } )
517+ let channel = runtime
518+ . block_on (
519+ endpoint
520+ . connect_timeout ( connect_timeout)
521+ . connect ( ) ,
522+ )
523+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
524+
525+ let interceptor =
526+ RoomIdInterceptor :: new ( room_id. unwrap_or_else ( || "default" . to_owned ( ) ) ) ;
527+
528+ let client =
529+ LighthouseServiceClient :: with_interceptor ( channel, interceptor) ;
530+
531+ Ok ( Self { client, runtime } )
532+
518533 } )
519534 }
520535
@@ -569,8 +584,6 @@ impl LighthouseClient {
569584 } ) ,
570585 } ) ;
571586
572- let mut request = self . add_room_header ( request) ;
573-
574587 // This timeout is processed on the server side so we also enable
575588 // keep alives to detect server health.
576589 request. set_timeout ( timeout) ;
@@ -599,29 +612,13 @@ impl LighthouseClient {
599612 ) -> Result < ( ) , StatusError > {
600613 py. allow_threads ( move || {
601614 let mut req = tonic:: Request :: new ( LighthouseHeartbeatRequest { replica_id } ) ;
602- let mut req = self . add_room_header ( req) ;
603615 req. set_timeout ( timeout) ;
604616 self . runtime . block_on ( self . client . clone ( ) . heartbeat ( req) ) ?;
605617 Ok ( ( ) )
606618 } )
607619 }
608620}
609621
610- impl LighthouseClient {
611- /// Attach `"room-id"` header if `self.room_id` is Some(_)
612- fn add_room_header < T > ( & self , mut req : tonic:: Request < T > ) -> tonic:: Request < T > {
613- if let Some ( ref id) = self . room_id {
614- use tonic:: metadata:: MetadataValue ;
615- req. metadata_mut ( ) . insert (
616- crate :: router:: ROOM_ID_HEADER ,
617- MetadataValue :: try_from ( id. as_str ( ) ) . expect ( "room-id ascii" ) ,
618- ) ;
619- }
620- req
621- }
622-
623- }
624-
625622/// LighthouseServer is a GRPC server for the lighthouse service.
626623///
627624/// It is used to coordinate the ManagerServer for each replica group.
0 commit comments