@@ -13,14 +13,15 @@ mod timeout;
1313use anyhow:: Result ;
1414use atty:: Stream ;
1515use core:: time:: Duration ;
16- use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
16+ use pyo3:: exceptions:: { PyRuntimeError , PyStopIteration , PyTimeoutError } ;
1717use std:: cmp;
1818use std:: env;
1919use std:: sync:: Arc ;
2020use std:: thread:: available_parallelism;
2121use structopt:: StructOpt ;
2222use tokio:: runtime:: Runtime ;
2323use tokio:: task:: JoinHandle ;
24+ use tokio_stream:: StreamExt ;
2425use tonic:: transport:: Channel ;
2526use tonic:: Status ;
2627
@@ -35,11 +36,13 @@ pub mod torchftpb {
3536use crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ;
3637use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
3738use crate :: torchftpb:: {
38- CheckpointMetadataRequest , LighthouseHeartbeatRequest , LighthouseQuorumRequest ,
39- ManagerQuorumRequest , ShouldCommitRequest ,
39+ CheckpointMetadataRequest , FailureNotification as ProtoFailureNotification ,
40+ LighthouseHeartbeatRequest , LighthouseQuorumRequest , ManagerQuorumRequest , ShouldCommitRequest ,
41+ SubscribeFailuresRequest ,
4042} ;
4143use pyo3:: prelude:: * ;
4244use pyo3:: types:: { PyDict , PyString } ;
45+ use pyo3:: { PyRef , PyRefMut } ;
4346
4447// Get the number of threads to use for the tokio runtime
4548fn num_threads ( ) -> usize {
@@ -290,6 +293,43 @@ struct QuorumResult {
290293 heal : bool ,
291294}
292295
296+ #[ pyclass( unsendable) ]
297+ struct FailureStream {
298+ runtime : Arc < Runtime > ,
299+ stream : tonic:: Streaming < ProtoFailureNotification > ,
300+ timeout : Duration ,
301+ }
302+
303+ #[ pymethods]
304+ impl FailureStream {
305+ fn __iter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
306+ slf
307+ }
308+ fn __next__ ( mut slf : PyRefMut < ' _ , Self > ) -> PyResult < FailureNotification > {
309+ let runtime = slf. runtime . clone ( ) ;
310+ let timeout = slf. timeout ;
311+ // borrow stream mutably for the whole async block
312+ let fut = async { tokio:: time:: timeout ( timeout, slf. stream . next ( ) ) . await } ;
313+
314+ match runtime. block_on ( fut) {
315+ Ok ( Some ( Ok ( note) ) ) => Ok ( FailureNotification {
316+ replica_id : note. replica_id ,
317+ } ) ,
318+ Ok ( Some ( Err ( status) ) ) => Err ( StatusError ( status) . into ( ) ) ,
319+ Ok ( None ) => Err ( PyStopIteration :: new_err ( ( ) ) ) ,
320+ Err ( _) => Err ( PyTimeoutError :: new_err (
321+ "Timeout waiting for failure notification" ,
322+ ) ) ,
323+ }
324+ }
325+ }
326+
327+ #[ pyclass( get_all, set_all) ]
328+ #[ derive( Clone ) ]
329+ struct FailureNotification {
330+ replica_id : String ,
331+ }
332+
293333#[ pymethods]
294334impl QuorumResult {
295335 #[ new]
@@ -396,6 +436,12 @@ pub struct Timestamp {
396436 pub nanos : i32 ,
397437}
398438
439+ #[ pyclass( get_all, set_all) ]
440+ #[ derive( Clone ) ]
441+ struct FailureNotificationPy {
442+ replica_id : String ,
443+ }
444+
399445/// quorum result.
400446///
401447/// Args:
@@ -478,7 +524,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
478524#[ pyclass]
479525struct LighthouseClient {
480526 client : LighthouseServiceClient < Channel > ,
481- runtime : Runtime ,
527+ runtime : Arc < Runtime > ,
482528}
483529
484530#[ pymethods]
@@ -487,11 +533,13 @@ impl LighthouseClient {
487533 #[ new]
488534 fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
489535 py. allow_threads ( move || {
490- let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
491- . worker_threads ( num_threads ( ) )
492- . thread_name ( "torchft-lhclnt" )
493- . enable_all ( )
494- . build ( ) ?;
536+ let runtime = Arc :: new (
537+ tokio:: runtime:: Builder :: new_multi_thread ( )
538+ . worker_threads ( num_threads ( ) )
539+ . thread_name ( "torchft-lhclnt" )
540+ . enable_all ( )
541+ . build ( ) ?,
542+ ) ;
495543 let client = runtime
496544 . block_on ( manager:: lighthouse_client_new ( addr, connect_timeout) )
497545 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -586,6 +634,22 @@ impl LighthouseClient {
586634 Ok ( ( ) )
587635 } )
588636 }
637+
638+ #[ pyo3( signature = ( timeout = Duration :: from_secs( 5 ) ) ) ]
639+ fn subscribe_failures ( & self , py : Python < ' _ > , timeout : Duration ) -> PyResult < FailureStream > {
640+ py. allow_threads ( move || {
641+ let req = tonic:: Request :: new ( SubscribeFailuresRequest { } ) ;
642+ let response = self
643+ . runtime
644+ . block_on ( self . client . clone ( ) . subscribe_failures ( req) )
645+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
646+ Ok ( FailureStream {
647+ runtime : self . runtime . clone ( ) ,
648+ stream : response. into_inner ( ) ,
649+ timeout : timeout,
650+ } )
651+ } )
652+ }
589653}
590654
591655/// LighthouseServer is a GRPC server for the lighthouse service.
@@ -610,7 +674,7 @@ struct LighthouseServer {
610674
611675#[ pymethods]
612676impl LighthouseServer {
613- #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None ) ) ]
677+ #[ pyo3( signature = ( bind, min_replicas, join_timeout_ms=None , quorum_tick_ms=None , heartbeat_timeout_ms=None , failure_tick_ms= None ) ) ]
614678 #[ new]
615679 fn new (
616680 py : Python < ' _ > ,
@@ -619,10 +683,12 @@ impl LighthouseServer {
619683 join_timeout_ms : Option < u64 > ,
620684 quorum_tick_ms : Option < u64 > ,
621685 heartbeat_timeout_ms : Option < u64 > ,
686+ failure_tick_ms : Option < u64 > ,
622687 ) -> PyResult < Self > {
623688 let join_timeout_ms = join_timeout_ms. unwrap_or ( 100 ) ;
624689 let quorum_tick_ms = quorum_tick_ms. unwrap_or ( 100 ) ;
625690 let heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ;
691+ let failure_tick_ms = failure_tick_ms. unwrap_or ( 1000 ) ;
626692
627693 py. allow_threads ( move || {
628694 let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
@@ -638,6 +704,7 @@ impl LighthouseServer {
638704 join_timeout_ms : join_timeout_ms,
639705 quorum_tick_ms : quorum_tick_ms,
640706 heartbeat_timeout_ms : heartbeat_timeout_ms,
707+ failure_tick_ms : failure_tick_ms,
641708 } ) )
642709 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
643710
@@ -663,6 +730,22 @@ impl LighthouseServer {
663730 self . handle . abort ( ) ;
664731 } )
665732 }
733+
734+ /// inject_failure broadcasts a failure notification for the given replica.
735+ ///
736+ /// This helper is intended for testing `subscribe_failures` from Python.
737+ #[ pyo3( signature = ( replica_id) ) ]
738+ fn inject_failure ( & self , py : Python < ' _ > , replica_id : String ) {
739+ let lighthouse = self . lighthouse . clone ( ) ;
740+ let runtime = & self . _runtime ;
741+ py. allow_threads ( move || {
742+ let _ = runtime. block_on ( async {
743+ if let Err ( e) = lighthouse. inject_failure ( replica_id) . await {
744+ eprintln ! ( "Failed to inject failure: {}" , e) ;
745+ }
746+ } ) ;
747+ } ) ;
748+ }
666749}
667750
668751struct StatusError ( Status ) ;
@@ -750,6 +833,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
750833 m. add_class :: < LighthouseServer > ( ) ?;
751834 m. add_class :: < LighthouseClient > ( ) ?;
752835 m. add_class :: < QuorumResult > ( ) ?;
836+ m. add_class :: < FailureNotificationPy > ( ) ?;
837+ m. add_class :: < FailureStream > ( ) ?;
753838 m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
754839
755840 Ok ( ( ) )
0 commit comments