Skip to content

Commit d7f6d1b

Browse files
MattKotzbauer, Matt Kotzbauer, and d4l3k
authored
Expose heartbeat method (src/lib.rs LighthouseClient) on pyo3 bindings (#176)
* Expose heartbeat method (src/lib.rs LighthouseClient) on pyo3 bindings
* fix lint

---------

Co-authored-by: Matt Kotzbauer <matt@dhcp-10-250-15-170.harvard.edu>
Co-authored-by: Tristan Rice <rice@fn.lc>
1 parent 8c1d175 commit d7f6d1b

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

src/lib.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ pub mod torchftpb {
3535
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
3636
use crate::torchftpb::manager_service_client::ManagerServiceClient;
3737
use crate::torchftpb::{
38-
CheckpointMetadataRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest,
38+
CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
39+
ManagerQuorumRequest, ShouldCommitRequest,
3940
};
4041
use pyo3::prelude::*;
4142
use pyo3::types::{PyDict, PyString};
@@ -562,6 +563,26 @@ impl LighthouseClient {
562563
});
563564
Ok(convert_quorum(py, &quorum?)?)
564565
}
566+
567+
/// Send a single heartbeat to the lighthouse.
568+
///
/// Builds a `LighthouseHeartbeatRequest` for `replica_id`, applies `timeout`
/// to the request, and blocks on the client's runtime until the RPC
/// completes. The blocking call runs inside `py.allow_threads`, so the
/// Python GIL is released while waiting on the RPC.
///
569+
/// Args:
570+
/// replica_id (str): The replica_id you registered with.
571+
/// timeout (timedelta, optional): Per-RPC deadline. Default = 5 s.
///
/// Returns:
/// Nothing on success; the heartbeat response itself is discarded.
///
/// Errors:
/// Any error returned by the heartbeat RPC is propagated with `?` as a
/// `StatusError`.
572+
#[pyo3(signature = (replica_id, timeout = Duration::from_secs(5)))]
573+
fn heartbeat(
574+
&self,
575+
py: Python<'_>,
576+
replica_id: String,
577+
timeout: Duration,
578+
) -> Result<(), StatusError> {
579+
// Release the GIL for the duration of the blocking RPC.
py.allow_threads(move || {
580+
let mut req = tonic::Request::new(LighthouseHeartbeatRequest { replica_id });
581+
// Attach the caller-supplied deadline to this single request.
req.set_timeout(timeout);
582+
// The client is cloned per call — presumably because the generated RPC
// method needs `&mut self` while only `&self` is available here; confirm.
// The response is not needed, only whether the call succeeded.
self.runtime.block_on(self.client.clone().heartbeat(req))?;
583+
Ok(())
584+
})
585+
}
565586
}
566587

567588
/// LighthouseServer is a GRPC server for the lighthouse service.

torchft/_torchft.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,8 @@ class LighthouseClient:
9999
shrink_only: Optional[bool] = None,
100100
data: Optional[dict[Hashable, object]] = None,
101101
) -> Quorum: ...
102+
# Send a single heartbeat for `replica_id` to the lighthouse.
# `timeout` bounds the call and defaults to 5 seconds; nothing is returned.
def heartbeat(
103+
self,
104+
replica_id: str,
105+
timeout: timedelta = timedelta(seconds=5),
106+
) -> None: ...

torchft/lighthouse_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,37 @@ def test_lighthouse_client_behavior(self) -> None:
121121
finally:
122122
# Cleanup
123123
lighthouse.shutdown()
124+
125+
# Round-trip check for LighthouseClient.heartbeat: a replica that has
# heartbeated recently appears in a quorum, and no longer appears once more
# than heartbeat_timeout_ms has passed without a new heartbeat.
# NOTE(review): the test relies on wall-clock sleeps with ~50 ms margins
# around the 200 ms heartbeat timeout; this may be flaky on a heavily loaded
# machine — confirm the margins are acceptable for CI.
def test_heartbeat_round_trip(self) -> None:
126+
# Short heartbeat timeout (200 ms) so expiry can be observed quickly.
lighthouse = LighthouseServer(
127+
bind="[::]:0",
128+
min_replicas=1,
129+
heartbeat_timeout_ms=200,
130+
)
131+
try:
132+
client = LighthouseClient(
133+
addr=lighthouse.address(),
134+
connect_timeout=timedelta(seconds=1),
135+
)
136+
137+
# Register liveness for "rep0" using the default heartbeat timeout argument.
client.heartbeat("rep0")
138+
139+
# (Should still be alive, as sleep time is less than timeout)
140+
time.sleep(0.15)
141+
q = client.quorum(
142+
replica_id="rep0",
143+
timeout=timedelta(milliseconds=500),
144+
)
145+
assert any(m.replica_id == "rep0" for m in q.participants)
146+
147+
# (Wait long enough for timeout to trigger)
148+
time.sleep(0.25)
149+
# "Probe" with different replica so we don't revive rep0
150+
probe = client.quorum(
151+
replica_id="probe",
152+
timeout=timedelta(milliseconds=500),
153+
)
154+
# rep0 sent nothing for 250 ms (> 200 ms timeout), so it must be absent.
assert all(m.replica_id != "rep0" for m in probe.participants)
155+
156+
finally:
157+
# Always stop the server, even if an assertion above failed.
lighthouse.shutdown()

0 commit comments

Comments
 (0)