@@ -32,9 +32,13 @@ pub mod torchftpb {
32
32
tonic:: include_proto!( "torchft" ) ;
33
33
}
34
34
35
+ use crate :: torchftpb:: lighthouse_service_client:: LighthouseServiceClient ;
35
36
use crate :: torchftpb:: manager_service_client:: ManagerServiceClient ;
36
- use crate :: torchftpb:: { CheckpointMetadataRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
37
+ use crate :: torchftpb:: {
38
+ CheckpointMetadataRequest , LighthouseQuorumRequest , ManagerQuorumRequest , ShouldCommitRequest ,
39
+ } ;
37
40
use pyo3:: prelude:: * ;
41
+ use pyo3:: types:: { PyDict , PyString } ;
38
42
39
43
// Get the number of threads to use for the tokio runtime
40
44
fn num_threads ( ) -> usize {
@@ -304,7 +308,7 @@ impl QuorumResult {
304
308
fn reset_python_signals ( py : Python < ' _ > ) -> PyResult < ( ) > {
305
309
// clear python signal handlers
306
310
// signal.signal(signal.SIGINT, signal.SIG_DFL)
307
- let signal = py. import_bound ( "signal" ) ?;
311
+ let signal = py. import ( "signal" ) ?;
308
312
let set_signal = signal. getattr ( "signal" ) ?;
309
313
let args = ( signal. getattr ( "SIGINT" ) ?, signal. getattr ( "SIG_DFL" ) ?) ;
310
314
set_signal. call1 ( args) ?;
@@ -337,6 +341,217 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
337
341
Ok ( ( ) )
338
342
}
339
343
344
+ /// quorum member of one quorum.
345
+ ///
346
+ /// Args:
347
+ /// replica_id (str): The string id of the replica calling quorum.
348
+ /// address (str): The address of the replica calling quorum.
349
+ /// store_address (str): The address of the store.
350
+ /// step (int): The step of the replica calling quorum.
351
+ /// world_size (int): The world size of the replica calling quorum.
352
+ /// shrink_only (bool): Whether the quorum is for shrinking only.
353
+ /// timeout (timedelta): The timeout for quorum.
354
+ /// data (dict or None): The data to be passed with quorum.
355
+ #[ pyclass( get_all, set_all) ]
356
+ pub struct QuorumMember {
357
+ replica_id : String ,
358
+ address : String ,
359
+ store_address : String ,
360
+ step : i64 ,
361
+ world_size : u64 ,
362
+ shrink_only : bool ,
363
+ data : Option < Py < PyDict > > ,
364
+ }
365
+
366
+ impl QuorumMember {
367
+ // PyDict has not implemeted Clone, so we need to implement it manually
368
+ pub fn clone_with_py ( & self , py : Python ) -> Self {
369
+ QuorumMember {
370
+ replica_id : self . replica_id . clone ( ) ,
371
+ address : self . address . clone ( ) ,
372
+ store_address : self . store_address . clone ( ) ,
373
+ step : self . step ,
374
+ world_size : self . world_size ,
375
+ shrink_only : self . shrink_only ,
376
+ data : self . data . as_ref ( ) . map ( |d| d. clone_ref ( py) ) ,
377
+ }
378
+ }
379
+ }
380
+
381
+ impl Clone for QuorumMember {
382
+ fn clone ( & self ) -> Self {
383
+ Python :: with_gil ( |py| self . clone_with_py ( py) )
384
+ }
385
+ }
386
+
387
+ #[ pyclass( get_all, set_all) ]
388
+ #[ derive( Clone ) ]
389
+ pub struct Timestamp {
390
+ pub seconds : i64 ,
391
+ pub nanos : i32 ,
392
+ }
393
+
394
+ /// quorum result.
395
+ ///
396
+ /// Args:
397
+ /// quorum_id (int): The id of current quorum.
398
+ /// participants (list[QuorumMember]): All members within the quorum.
399
+ /// created (timedelta): Time of quorum created in server.
400
+ #[ pyclass( get_all, set_all) ]
401
+ struct Quorum {
402
+ quorum_id : i64 ,
403
+ participants : Vec < QuorumMember > ,
404
+ created : Timestamp ,
405
+ }
406
+
407
+ impl From < prost_types:: Timestamp > for Timestamp {
408
+ fn from ( ts : prost_types:: Timestamp ) -> Self {
409
+ Timestamp {
410
+ seconds : ts. seconds ,
411
+ nanos : ts. nanos ,
412
+ }
413
+ }
414
+ }
415
+
416
+ // Util functions to convert between python dict and rust string using json.
417
+ fn pydict_to_string < ' py > ( py : Python , data : Option < & Bound < ' _ , PyDict > > ) -> PyResult < String > {
418
+ match data {
419
+ Some ( d) => {
420
+ let json = py. import ( "json" ) ?;
421
+ let json_obj = json. call_method1 ( "dumps" , ( d, ) ) ?;
422
+ let py_str: & Bound < PyString > = json_obj. downcast ( ) ?;
423
+ Ok ( py_str. to_str ( ) ?. to_owned ( ) )
424
+ }
425
+ None => Ok ( String :: new ( ) ) ,
426
+ }
427
+ }
428
+
429
+ fn string_to_pydict ( py : Python , s : & str ) -> PyResult < Option < Py < PyDict > > > {
430
+ if s. is_empty ( ) {
431
+ return Ok ( None ) ; // Treat empty string as None
432
+ }
433
+
434
+ let json = py. import ( "json" ) ?;
435
+ let obj = json. call_method1 ( "loads" , ( s, ) ) ?;
436
+ let dict: & Bound < PyDict > = obj. downcast ( ) ?;
437
+ Ok ( Some ( dict. to_owned ( ) . into ( ) ) ) // convert Bound<PyDict> -> Py<PyDict>
438
+ }
439
+
440
+ fn convert_quorum_member ( py : Python , m : & torchftpb:: QuorumMember ) -> PyResult < QuorumMember > {
441
+ Ok ( QuorumMember {
442
+ replica_id : m. replica_id . clone ( ) ,
443
+ address : m. address . clone ( ) ,
444
+ store_address : m. store_address . clone ( ) ,
445
+ step : m. step . clone ( ) ,
446
+ world_size : m. world_size . clone ( ) ,
447
+ shrink_only : m. shrink_only . clone ( ) ,
448
+ data : string_to_pydict ( py, & m. data ) ?,
449
+ } )
450
+ }
451
+
452
+ fn convert_quorum ( py : Python , q : & torchftpb:: Quorum ) -> PyResult < Quorum > {
453
+ let participants: Vec < QuorumMember > = q
454
+ . participants
455
+ . iter ( )
456
+ . map ( |m| convert_quorum_member ( py, m) ) // this expects &m
457
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
458
+
459
+ Ok ( Quorum {
460
+ quorum_id : q. quorum_id ,
461
+ participants : participants,
462
+ created : Timestamp :: from ( q. created . unwrap ( ) ) ,
463
+ } )
464
+ }
465
+
466
+ /// LighthouseClient is a GRPC client to the lighthouse service.
467
+ ///
468
+ /// It is used to directly communicate with the lighthouse Server.
469
+ ///
470
+ /// Args:
471
+ /// addr (str): The HTTP address of the lighthouse server.
472
+ /// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
473
+ #[ pyclass]
474
+ struct LighthouseClient {
475
+ client : LighthouseServiceClient < Channel > ,
476
+ runtime : Runtime ,
477
+ }
478
+
479
+ #[ pymethods]
480
+ impl LighthouseClient {
481
+ #[ pyo3( signature = ( addr, connect_timeout) ) ]
482
+ #[ new]
483
+ fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
484
+ py. allow_threads ( move || {
485
+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
486
+ . worker_threads ( num_threads ( ) )
487
+ . thread_name ( "torchft-lhclnt" )
488
+ . enable_all ( )
489
+ . build ( ) ?;
490
+ let client = runtime
491
+ . block_on ( manager:: lighthouse_client_new ( addr, connect_timeout) )
492
+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
493
+ Ok ( Self {
494
+ client : client,
495
+ runtime : runtime,
496
+ } )
497
+ } )
498
+ }
499
+
500
+ /// quorum sends a request to the lighthouse server to form a quorum.
501
+ ///
502
+ /// Args:
503
+ /// replica_id (str): The string id of the replica calling quorum.
504
+ /// address (str): The address of the replica calling quorum.
505
+ /// store_address (str): The address of the store.
506
+ /// step (int): The step of the replica calling quorum.
507
+ /// world_size (int): The world size of the replica calling quorum.
508
+ /// shrink_only (bool): Whether the quorum is for shrinking only.
509
+ /// timeout (timedelta): The timeout for quorum.
510
+ /// data (Optional[dict]): The data to be passed with quorum.
511
+ ///
512
+ /// Returns:
513
+ /// Quorum: Current quorum if successful.
514
+ fn quorum < ' py > (
515
+ & self ,
516
+ py : Python < ' _ > ,
517
+ replica_id : String ,
518
+ address : String ,
519
+ store_address : String ,
520
+ step : i64 ,
521
+ world_size : u64 ,
522
+ shrink_only : bool ,
523
+ timeout : Duration ,
524
+ data : Option < & Bound < ' _ , PyDict > > ,
525
+ ) -> Result < Quorum , StatusError > {
526
+ let data_string = pydict_to_string ( py, data) ?;
527
+ let quorum: Result < torchftpb:: Quorum , StatusError > = py. allow_threads ( move || {
528
+ let mut request = tonic:: Request :: new ( LighthouseQuorumRequest {
529
+ requester : Some ( torchftpb:: QuorumMember {
530
+ replica_id : replica_id,
531
+ address : address,
532
+ store_address : store_address,
533
+ step : step,
534
+ world_size : world_size,
535
+ shrink_only : shrink_only,
536
+ data : data_string,
537
+ } ) ,
538
+ } ) ;
539
+
540
+ // This timeout is processed on the server side so we also enable
541
+ // keep alives to detect server health.
542
+ request. set_timeout ( timeout) ;
543
+
544
+ let response = self . runtime . block_on ( self . client . clone ( ) . quorum ( request) ) ?;
545
+ let resp = response. into_inner ( ) ;
546
+ let quorum = resp
547
+ . quorum
548
+ . ok_or_else ( || Status :: internal ( "missing quorum" ) ) ?;
549
+ Ok ( quorum)
550
+ } ) ;
551
+ Ok ( convert_quorum ( py, & quorum?) ?)
552
+ }
553
+ }
554
+
340
555
/// LighthouseServer is a GRPC server for the lighthouse service.
341
556
///
342
557
/// It is used to coordinate the ManagerServer for each replica group.
@@ -428,6 +643,12 @@ impl From<StatusError> for PyErr {
428
643
}
429
644
}
430
645
646
+ impl From < pyo3:: PyErr > for StatusError {
647
+ fn from ( err : pyo3:: PyErr ) -> Self {
648
+ StatusError ( Status :: internal ( err. to_string ( ) ) )
649
+ }
650
+ }
651
+
431
652
impl From < Status > for StatusError {
432
653
fn from ( other : Status ) -> Self {
433
654
Self ( other)
@@ -479,9 +700,13 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
479
700
// setup logging on import
480
701
setup_logging ( ) . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
481
702
703
+ m. add_class :: < Timestamp > ( ) ?;
704
+ m. add_class :: < QuorumMember > ( ) ?;
705
+ m. add_class :: < Quorum > ( ) ?;
482
706
m. add_class :: < ManagerServer > ( ) ?;
483
707
m. add_class :: < ManagerClient > ( ) ?;
484
708
m. add_class :: < LighthouseServer > ( ) ?;
709
+ m. add_class :: < LighthouseClient > ( ) ?;
485
710
m. add_class :: < QuorumResult > ( ) ?;
486
711
m. add_function ( wrap_pyfunction ! ( lighthouse_main, m) ?) ?;
487
712
0 commit comments