Skip to content

Commit aa00ef5

Browse files
authored
Support generic quorum api on LighthouseClient (#150)
* support generic quorum api on LighthouseClient with document updated.
1 parent 2b3cd8d commit aa00ef5

File tree

9 files changed

+365
-7
lines changed

9 files changed

+365
-7
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ gethostname = "0.5.0"
1414
log = "0.4.22"
1515
prost = "0.13.3"
1616
prost-types = "0.13.3"
17-
pyo3 = {version = "0.22.3", features = ["extension-module"]}
17+
pyo3 = {version = "0.24", features = ["extension-module"]}
1818
rand = "0.8.5"
1919
slog = "2.7.0"
2020
slog-stdlog = "4.1.1"

proto/torchft.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ message QuorumMember {
4242
int64 step = 4;
4343
uint64 world_size = 5;
4444
bool shrink_only = 6;
45+
// User passing in data stored as JSON string.
46+
string data = 7;
4547
}
4648

4749
message Quorum {

src/lib.rs

Lines changed: 227 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@ pub mod torchftpb {
3232
tonic::include_proto!("torchft");
3333
}
3434

35+
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
3536
use crate::torchftpb::manager_service_client::ManagerServiceClient;
36-
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
37+
use crate::torchftpb::{
38+
CheckpointMetadataRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest,
39+
};
3740
use pyo3::prelude::*;
41+
use pyo3::types::{PyDict, PyString};
3842

3943
// Get the number of threads to use for the tokio runtime
4044
fn num_threads() -> usize {
@@ -304,7 +308,7 @@ impl QuorumResult {
304308
fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
305309
// clear python signal handlers
306310
// signal.signal(signal.SIGINT, signal.SIG_DFL)
307-
let signal = py.import_bound("signal")?;
311+
let signal = py.import("signal")?;
308312
let set_signal = signal.getattr("signal")?;
309313
let args = (signal.getattr("SIGINT")?, signal.getattr("SIG_DFL")?);
310314
set_signal.call1(args)?;
@@ -337,6 +341,217 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
337341
Ok(())
338342
}
339343

344+
/// quorum member of one quorum.
345+
///
346+
/// Args:
347+
/// replica_id (str): The string id of the replica calling quorum.
348+
/// address (str): The address of the replica calling quorum.
349+
/// store_address (str): The address of the store.
350+
/// step (int): The step of the replica calling quorum.
351+
/// world_size (int): The world size of the replica calling quorum.
352+
/// shrink_only (bool): Whether the quorum is for shrinking only.
353+
/// timeout (timedelta): The timeout for quorum.
354+
/// data (dict or None): The data to be passed with quorum.
355+
#[pyclass(get_all, set_all)]
356+
pub struct QuorumMember {
357+
replica_id: String,
358+
address: String,
359+
store_address: String,
360+
step: i64,
361+
world_size: u64,
362+
shrink_only: bool,
363+
data: Option<Py<PyDict>>,
364+
}
365+
366+
impl QuorumMember {
367+
// PyDict has not implemeted Clone, so we need to implement it manually
368+
pub fn clone_with_py(&self, py: Python) -> Self {
369+
QuorumMember {
370+
replica_id: self.replica_id.clone(),
371+
address: self.address.clone(),
372+
store_address: self.store_address.clone(),
373+
step: self.step,
374+
world_size: self.world_size,
375+
shrink_only: self.shrink_only,
376+
data: self.data.as_ref().map(|d| d.clone_ref(py)),
377+
}
378+
}
379+
}
380+
381+
impl Clone for QuorumMember {
382+
fn clone(&self) -> Self {
383+
Python::with_gil(|py| self.clone_with_py(py))
384+
}
385+
}
386+
387+
#[pyclass(get_all, set_all)]
388+
#[derive(Clone)]
389+
pub struct Timestamp {
390+
pub seconds: i64,
391+
pub nanos: i32,
392+
}
393+
394+
/// quorum result.
395+
///
396+
/// Args:
397+
/// quorum_id (int): The id of current quorum.
398+
/// participants (list[QuorumMember]): All members within the quorum.
399+
/// created (timedelta): Time of quorum created in server.
400+
#[pyclass(get_all, set_all)]
401+
struct Quorum {
402+
quorum_id: i64,
403+
participants: Vec<QuorumMember>,
404+
created: Timestamp,
405+
}
406+
407+
impl From<prost_types::Timestamp> for Timestamp {
408+
fn from(ts: prost_types::Timestamp) -> Self {
409+
Timestamp {
410+
seconds: ts.seconds,
411+
nanos: ts.nanos,
412+
}
413+
}
414+
}
415+
416+
// Util functions to convert between python dict and rust string using json.
417+
fn pydict_to_string<'py>(py: Python, data: Option<&Bound<'_, PyDict>>) -> PyResult<String> {
418+
match data {
419+
Some(d) => {
420+
let json = py.import("json")?;
421+
let json_obj = json.call_method1("dumps", (d,))?;
422+
let py_str: &Bound<PyString> = json_obj.downcast()?;
423+
Ok(py_str.to_str()?.to_owned())
424+
}
425+
None => Ok(String::new()),
426+
}
427+
}
428+
429+
fn string_to_pydict(py: Python, s: &str) -> PyResult<Option<Py<PyDict>>> {
430+
if s.is_empty() {
431+
return Ok(None); // Treat empty string as None
432+
}
433+
434+
let json = py.import("json")?;
435+
let obj = json.call_method1("loads", (s,))?;
436+
let dict: &Bound<PyDict> = obj.downcast()?;
437+
Ok(Some(dict.to_owned().into())) // convert Bound<PyDict> -> Py<PyDict>
438+
}
439+
440+
fn convert_quorum_member(py: Python, m: &torchftpb::QuorumMember) -> PyResult<QuorumMember> {
441+
Ok(QuorumMember {
442+
replica_id: m.replica_id.clone(),
443+
address: m.address.clone(),
444+
store_address: m.store_address.clone(),
445+
step: m.step.clone(),
446+
world_size: m.world_size.clone(),
447+
shrink_only: m.shrink_only.clone(),
448+
data: string_to_pydict(py, &m.data)?,
449+
})
450+
}
451+
452+
fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
453+
let participants: Vec<QuorumMember> = q
454+
.participants
455+
.iter()
456+
.map(|m| convert_quorum_member(py, m)) // this expects &m
457+
.collect::<Result<Vec<_>, _>>()?;
458+
459+
Ok(Quorum {
460+
quorum_id: q.quorum_id,
461+
participants: participants,
462+
created: Timestamp::from(q.created.unwrap()),
463+
})
464+
}
465+
466+
/// LighthouseClient is a GRPC client to the lighthouse service.
467+
///
468+
/// It is used to directly communicate with the lighthouse Server.
469+
///
470+
/// Args:
471+
/// addr (str): The HTTP address of the lighthouse server.
472+
/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
473+
#[pyclass]
474+
struct LighthouseClient {
475+
client: LighthouseServiceClient<Channel>,
476+
runtime: Runtime,
477+
}
478+
479+
#[pymethods]
480+
impl LighthouseClient {
481+
#[pyo3(signature = (addr, connect_timeout))]
482+
#[new]
483+
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
484+
py.allow_threads(move || {
485+
let runtime = tokio::runtime::Builder::new_multi_thread()
486+
.worker_threads(num_threads())
487+
.thread_name("torchft-lhclnt")
488+
.enable_all()
489+
.build()?;
490+
let client = runtime
491+
.block_on(manager::lighthouse_client_new(addr, connect_timeout))
492+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
493+
Ok(Self {
494+
client: client,
495+
runtime: runtime,
496+
})
497+
})
498+
}
499+
500+
/// quorum sends a request to the lighthouse server to form a quorum.
501+
///
502+
/// Args:
503+
/// replica_id (str): The string id of the replica calling quorum.
504+
/// address (str): The address of the replica calling quorum.
505+
/// store_address (str): The address of the store.
506+
/// step (int): The step of the replica calling quorum.
507+
/// world_size (int): The world size of the replica calling quorum.
508+
/// shrink_only (bool): Whether the quorum is for shrinking only.
509+
/// timeout (timedelta): The timeout for quorum.
510+
/// data (Optional[dict]): The data to be passed with quorum.
511+
///
512+
/// Returns:
513+
/// Quorum: Current quorum if successful.
514+
fn quorum<'py>(
515+
&self,
516+
py: Python<'_>,
517+
replica_id: String,
518+
address: String,
519+
store_address: String,
520+
step: i64,
521+
world_size: u64,
522+
shrink_only: bool,
523+
timeout: Duration,
524+
data: Option<&Bound<'_, PyDict>>,
525+
) -> Result<Quorum, StatusError> {
526+
let data_string = pydict_to_string(py, data)?;
527+
let quorum: Result<torchftpb::Quorum, StatusError> = py.allow_threads(move || {
528+
let mut request = tonic::Request::new(LighthouseQuorumRequest {
529+
requester: Some(torchftpb::QuorumMember {
530+
replica_id: replica_id,
531+
address: address,
532+
store_address: store_address,
533+
step: step,
534+
world_size: world_size,
535+
shrink_only: shrink_only,
536+
data: data_string,
537+
}),
538+
});
539+
540+
// This timeout is processed on the server side so we also enable
541+
// keep alives to detect server health.
542+
request.set_timeout(timeout);
543+
544+
let response = self.runtime.block_on(self.client.clone().quorum(request))?;
545+
let resp = response.into_inner();
546+
let quorum = resp
547+
.quorum
548+
.ok_or_else(|| Status::internal("missing quorum"))?;
549+
Ok(quorum)
550+
});
551+
Ok(convert_quorum(py, &quorum?)?)
552+
}
553+
}
554+
340555
/// LighthouseServer is a GRPC server for the lighthouse service.
341556
///
342557
/// It is used to coordinate the ManagerServer for each replica group.
@@ -428,6 +643,12 @@ impl From<StatusError> for PyErr {
428643
}
429644
}
430645

646+
impl From<pyo3::PyErr> for StatusError {
647+
fn from(err: pyo3::PyErr) -> Self {
648+
StatusError(Status::internal(err.to_string()))
649+
}
650+
}
651+
431652
impl From<Status> for StatusError {
432653
fn from(other: Status) -> Self {
433654
Self(other)
@@ -479,9 +700,13 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
479700
// setup logging on import
480701
setup_logging().map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
481702

703+
m.add_class::<Timestamp>()?;
704+
m.add_class::<QuorumMember>()?;
705+
m.add_class::<Quorum>()?;
482706
m.add_class::<ManagerServer>()?;
483707
m.add_class::<ManagerClient>()?;
484708
m.add_class::<LighthouseServer>()?;
709+
m.add_class::<LighthouseClient>()?;
485710
m.add_class::<QuorumResult>()?;
486711
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;
487712

0 commit comments

Comments
 (0)