@@ -13,14 +13,15 @@ mod timeout;
1313use anyhow:: Result ;
1414use atty:: Stream ;
1515use core:: time:: Duration ;
16- use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
16+ use pyo3:: exceptions:: { PyRuntimeError , PyStopIteration , PyTimeoutError } ;
1717use std:: cmp;
1818use std:: env;
1919use std:: sync:: Arc ;
2020use std:: thread:: available_parallelism;
2121use structopt:: StructOpt ;
2222use tokio:: runtime:: Runtime ;
2323use tokio:: task:: JoinHandle ;
24+ use tokio_stream:: StreamExt ;
2425use tonic:: transport:: Channel ;
2526use tonic:: Status ;
2627
@@ -35,11 +36,13 @@ pub mod torchftpb {
3536use crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ;
3637use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
3738use crate :: torchftpb:: {
38- CheckpointMetadataRequest , LighthouseHeartbeatRequest , LighthouseQuorumRequest ,
39- ManagerQuorumRequest , ShouldCommitRequest ,
39+ CheckpointMetadataRequest , FailureNotification as ProtoFailureNotification ,
40+ LighthouseHeartbeatRequest , LighthouseQuorumRequest , ManagerQuorumRequest , ShouldCommitRequest ,
41+ SubscribeFailuresRequest ,
4042} ;
4143use pyo3:: prelude:: * ;
4244use pyo3:: types:: { PyDict , PyString } ;
45+ use pyo3:: { PyRef , PyRefMut } ;
4346
4447// Get the number of threads to use for the tokio runtime
4548fn num_threads ( ) -> usize {
@@ -290,6 +293,45 @@ struct QuorumResult {
290293 heal : bool ,
291294}
292295
296+ #[ pyclass( unsendable) ]
297+ struct FailureStream {
298+ runtime : Arc < Runtime > ,
299+ stream : tonic:: Streaming < ProtoFailureNotification > ,
300+ timeout : Duration ,
301+ }
302+
303+ #[ pymethods]
304+ impl FailureStream {
305+ fn __iter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
306+ slf
307+ }
308+ fn __next__ ( mut slf : PyRefMut < ' _ , Self > ) -> PyResult < FailureNotification > {
309+ let runtime = slf. runtime . clone ( ) ;
310+ let timeout = slf. timeout ;
311+ // borrow stream mutably for the whole async block
312+ let fut = async { tokio:: time:: timeout ( timeout, slf. stream . next ( ) ) . await } ;
313+
314+ match runtime. block_on ( fut) {
315+ Ok ( Some ( Ok ( note) ) ) => Ok ( FailureNotification {
316+ replica_id : note. replica_id ,
317+ error_message : note. error_message ,
318+ } ) ,
319+ Ok ( Some ( Err ( status) ) ) => Err ( StatusError ( status) . into ( ) ) ,
320+ Ok ( None ) => Err ( PyStopIteration :: new_err ( ( ) ) ) ,
321+ Err ( _) => Err ( PyTimeoutError :: new_err (
322+ "Timeout waiting for failure notification" ,
323+ ) ) ,
324+ }
325+ }
326+ }
327+
328+ #[ pyclass( get_all, set_all) ]
329+ #[ derive( Clone ) ]
330+ struct FailureNotification {
331+ replica_id : String ,
332+ error_message : String ,
333+ }
334+
293335#[ pymethods]
294336impl QuorumResult {
295337 #[ new]
@@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
478520#[ pyclass]
479521struct LighthouseClient {
480522 client : LighthouseServiceClient < Channel > ,
481- runtime : Runtime ,
523+ runtime : Arc < Runtime > ,
482524}
483525
484526#[ pymethods]
@@ -487,11 +529,13 @@ impl LighthouseClient {
487529 #[ new]
488530 fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
489531 py. allow_threads ( move || {
490- let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
491- . worker_threads ( num_threads ( ) )
492- . thread_name ( "torchft-lhclnt" )
493- . enable_all ( )
494- . build ( ) ?;
532+ let runtime = Arc :: new (
533+ tokio:: runtime:: Builder :: new_multi_thread ( )
534+ . worker_threads ( num_threads ( ) )
535+ . thread_name ( "torchft-lhclnt" )
536+ . enable_all ( )
537+ . build ( ) ?,
538+ ) ;
495539 let client = runtime
496540 . block_on ( manager:: lighthouse_client_new ( addr, connect_timeout) )
497541 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -586,6 +630,22 @@ impl LighthouseClient {
586630 Ok ( ( ) )
587631 } )
588632 }
633+
634+ #[ pyo3( signature = ( timeout = Duration :: from_secs( 5 ) ) ) ]
635+ fn subscribe_failures ( & self , py : Python < ' _ > , timeout : Duration ) -> PyResult < FailureStream > {
636+ py. allow_threads ( move || {
637+ let req = tonic:: Request :: new ( SubscribeFailuresRequest { } ) ;
638+ let response = self
639+ . runtime
640+ . block_on ( self . client . clone ( ) . subscribe_failures ( req) )
641+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
642+ Ok ( FailureStream {
643+ runtime : self . runtime . clone ( ) ,
644+ stream : response. into_inner ( ) ,
645+ timeout : timeout,
646+ } )
647+ } )
648+ }
589649}
590650
591651/// LighthouseServer is a GRPC server for the lighthouse service.
@@ -610,7 +670,7 @@ struct LighthouseServer {
610670
611671#[ pymethods]
612672impl LighthouseServer {
613- #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None ) ) ]
673+ #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None , failure_tick_ms= None ) ) ]
614674 #[ new]
615675 fn new (
616676 py : Python < ' _ > ,
@@ -619,10 +679,12 @@ impl LighthouseServer {
619679 join_timeout_ms : Option < u64 > ,
620680 quorum_tick_ms : Option < u64 > ,
621681 heartbeat_timeout_ms : Option < u64 > ,
682+ failure_tick_ms : Option < u64 > ,
622683 ) -> PyResult < Self > {
623684 let join_timeout_ms = join_timeout_ms. unwrap_or ( 100 ) ;
624685 let quorum_tick_ms = quorum_tick_ms. unwrap_or ( 100 ) ;
625686 let heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ;
687+ let failure_tick_ms = failure_tick_ms. unwrap_or ( 1000 ) ;
626688
627689 py. allow_threads ( move || {
628690 let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
@@ -638,6 +700,7 @@ impl LighthouseServer {
638700 join_timeout_ms : join_timeout_ms,
639701 quorum_tick_ms : quorum_tick_ms,
640702 heartbeat_timeout_ms : heartbeat_timeout_ms,
703+ failure_tick_ms : failure_tick_ms,
641704 } ) )
642705 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
643706
@@ -663,6 +726,22 @@ impl LighthouseServer {
663726 self . handle . abort ( ) ;
664727 } )
665728 }
729+
730+ /// inject_failure broadcasts a failure notification for the given replica.
731+ ///
732+ /// This helper is intended for testing `subscribe_failures` from Python.
733+ #[ pyo3( signature = ( replica_id) ) ]
734+ fn inject_failure ( & self , py : Python < ' _ > , replica_id : String ) {
735+ let lighthouse = self . lighthouse . clone ( ) ;
736+ let runtime = & self . _runtime ;
737+ py. allow_threads ( move || {
738+ let _ = runtime. block_on ( async {
739+ if let Err ( e) = lighthouse. inject_failure ( replica_id) . await {
740+ eprintln ! ( "Failed to inject failure: {}" , e) ;
741+ }
742+ } ) ;
743+ } ) ;
744+ }
666745}
667746
668747struct StatusError ( Status ) ;
@@ -750,6 +829,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
750829 m. add_class :: < LighthouseServer > ( ) ?;
751830 m. add_class :: < LighthouseClient > ( ) ?;
752831 m. add_class :: < QuorumResult > ( ) ?;
832+ m. add_class :: < FailureNotification > ( ) ?;
833+ m. add_class :: < FailureStream > ( ) ?;
753834 m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
754835
755836 Ok ( ( ) )
0 commit comments