
Commit 273d3ee

Matt Kotzbauer authored and committed
Edits to tower-based routing: the room lookup in src/router.rs now returns Arc<Lighthouse>; Lighthouse::new now takes an id prefix; the multi-room test was relocated to lighthouse_test.py and now uses the coordination API; LighthouseServer now resolves host/port from the bound socket to give a routable http://host:port address.
1 parent 0a9ce34 commit 273d3ee

File tree: 6 files changed (+53, −74 lines)

src/interceptor.rs~

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/lib.rs

Lines changed: 9 additions & 2 deletions
```diff
@@ -17,6 +17,7 @@ pub use crate::router::Router;
 use anyhow::Result;
 use atty::Stream;
 use core::time::Duration;
+use gethostname::gethostname;
 use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use std::cmp;
 use std::env;
@@ -664,7 +665,6 @@ impl LighthouseServer {
 
         let listener = rt.block_on(tokio::net::TcpListener::bind(&bind))?;
         let bound_sock = listener.local_addr()?;
-        let bound = format!("http://{}", bound_sock);
         let incoming = TcpListenerStream::new(listener);
         let router = Router::new(opt.clone());
 
@@ -676,8 +676,15 @@ impl LighthouseServer {
                 .map_err(|e: tonic::transport::Error| anyhow::anyhow!(e))
         });
 
+        let host = if bind.starts_with("0.0.0.0") || bind.starts_with("[::]") {
+            gethostname().to_string_lossy().into_owned()
+        } else {
+            bind.rsplit_once(':').map(|(h, _)| h.to_string()).unwrap()
+        };
+        let public_addr = format!("http://{}:{}", host, bound_sock.port());
+
         Ok(Self {
-            bind: bound,
+            bind: public_addr,
             handle,
             _runtime: rt,
         })
```
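
Stated on its own, the new rule is: keep the caller's host when it is concrete, substitute the machine hostname when the bind address is a wildcard, and always take the port from the bound socket so that binding to port 0 reports the kernel-assigned port. A standalone sketch of that rule follows; `resolve_public_addr` is a name invented here for illustration (the commit inlines this logic in `LighthouseServer`):

```rust
use std::net::SocketAddr;

use gethostname::gethostname;

/// Illustrative extraction of the address-resolution logic above.
fn resolve_public_addr(bind: &str, bound_sock: SocketAddr) -> String {
    // Wildcard binds are not routable from other hosts, so advertise the
    // machine's hostname instead.
    let host = if bind.starts_with("0.0.0.0") || bind.starts_with("[::]") {
        gethostname().to_string_lossy().into_owned()
    } else {
        // Keep the concrete host the caller asked for; rsplit_once splits
        // off the trailing ":port" even for bracketed IPv6 literals.
        bind.rsplit_once(':').map(|(h, _)| h.to_string()).unwrap()
    };
    // The port comes from the *bound* socket, so "bind to :0" reports the
    // kernel-assigned port rather than 0.
    format!("http://{}:{}", host, bound_sock.port())
}
```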

src/lighthouse.rs

Lines changed: 3 additions & 1 deletion
```diff
@@ -58,6 +58,7 @@ struct State {
 }
 
 pub struct Lighthouse {
+    id: String,
     state: Mutex<State>,
     opt: LighthouseOpt,
     listener: Mutex<Option<tokio::net::TcpListener>>,
@@ -261,12 +262,13 @@ fn quorum_compute(
 }
 
 impl Lighthouse {
-    pub async fn new(opt: LighthouseOpt) -> Result<Arc<Self>> {
+    pub async fn new(id: String, opt: LighthouseOpt) -> Result<Arc<Self>> {
        let listener = tokio::net::TcpListener::bind(&opt.bind).await?;
 
        let (tx, _) = broadcast::channel(16);
 
        Ok(Arc::new(Self {
+            id: id,
            state: Mutex::new(State {
                participants: HashMap::new(),
                channel: tx,
```
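
Call sites must now pass the room id first. A minimal usage sketch, assuming an already-built `LighthouseOpt` named `opt` and an illustrative room id (neither appears in this diff):

```rust
// Hedged sketch: one Lighthouse per logical room, keyed by its id.
// `opt` is an existing LighthouseOpt; "jobA" is an illustrative room id.
let lighthouse = Lighthouse::new("jobA".to_owned(), opt.clone()).await?;
```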

src/router.rs

Lines changed: 11 additions & 15 deletions
```diff
@@ -10,7 +10,7 @@ use dashmap::{mapref::entry::Entry, DashMap};
 use futures::FutureExt;
 use tonic::{
     body::BoxBody,
-    codegen::http::{HeaderMap, Request, Response}, // http-0.2 types
+    codegen::http::{HeaderMap, Request, Response},
     server::NamedService,
 };
 use tower::Service;
@@ -28,7 +28,7 @@ type GrpcSvc = LighthouseServiceServer<Arc<Lighthouse>>;
 
 #[derive(Clone)]
 pub struct Router {
-    rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
+    rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
     tmpl_opt: LighthouseOpt,
 }
 
@@ -47,26 +47,23 @@ impl Router {
     }
 
     async fn room_service(
-        rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
+        rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
         tmpl: LighthouseOpt,
         id: &str,
-    ) -> Arc<GrpcSvc> {
-        if let Some(svc) = rooms.get(id) {
-            return svc.clone();
+    ) -> Arc<Lighthouse> {
+        if let Some(lh) = rooms.get(id) {
+            return lh.clone();
         }
 
-        // Build room state once.
-        let lh = Lighthouse::new(tmpl.clone())
+        let lh = Lighthouse::new(id.to_owned(), tmpl.clone())
             .await
             .expect("failed to create Lighthouse");
 
-        let svc_new = Arc::new(LighthouseServiceServer::new(lh));
-
         match rooms.entry(id.to_owned()) {
             Entry::Occupied(e) => e.get().clone(),
             Entry::Vacant(v) => {
-                v.insert(svc_new.clone());
-                svc_new
+                v.insert(lh.clone());
+                lh
             }
         }
     }
@@ -89,10 +86,9 @@ impl Service<Request<BoxBody>> for Router {
         let room = Self::room_id(req.headers()).to_owned();
 
         async move {
-            let svc_arc = Self::room_service(rooms, tmpl, &room).await;
+            let lh = Self::room_service(rooms, tmpl, &room).await;
 
-            // `Arc<GrpcSvc>` itself isn't a Service; clone the inner value.
-            let mut svc = (*svc_arc).clone();
+            let mut svc = LighthouseServiceServer::new(lh);
             let resp = svc
                 .call(req)
                 .await
```
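
Two details of the new shape are worth calling out. First, the map now caches room state (`Arc<Lighthouse>`) rather than a gRPC service, and `LighthouseServiceServer::new(lh)` is rebuilt per request; that is cheap because the tonic-generated server is a thin wrapper around the inner value. Second, `room_service` is a get-or-create that never holds a map lock across an await: the `Lighthouse` is built before the entry is taken, and a losing racer simply drops its freshly built instance. A generic sketch of that pattern, under those assumptions (the function name and generic signature are invented for illustration):

```rust
use std::future::Future;
use std::sync::Arc;

use dashmap::{mapref::entry::Entry, DashMap};

/// Generic sketch of the get-or-create pattern used by `room_service`.
async fn get_or_create<T, F>(rooms: &DashMap<String, Arc<T>>, id: &str, build: F) -> Arc<T>
where
    F: Future<Output = Arc<T>>,
{
    // Fast path: the room already exists.
    if let Some(existing) = rooms.get(id) {
        return existing.value().clone();
    }
    // Slow path: build *outside* any map lock, so no DashMap shard lock is
    // held across the .await, then insert-or-adopt atomically.
    let fresh = build.await;
    match rooms.entry(id.to_owned()) {
        Entry::Occupied(e) => e.get().clone(), // another caller won the race
        Entry::Vacant(v) => {
            v.insert(fresh.clone());
            fresh
        }
    }
}
```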

torchft/lighthouse_test.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 
 import torch.distributed as dist
 
+import torchft.coordination as cd
 from torchft import Manager, ProcessGroupGloo
 from torchft._torchft import LighthouseClient, LighthouseServer, Quorum, QuorumMember
 
@@ -155,3 +156,32 @@ def test_heartbeat_round_trip(self) -> None:
 
         finally:
             lighthouse.shutdown()
+
+    def test_multi_room_quorums(self) -> None:
+        """One server, two logical rooms should yield two isolated quorums."""
+        server = cd.LighthouseServer(bind="[::]:0", min_replicas=1)
+        addr = server.address()
+
+        try:
+            # Two clients in two independent rooms
+            cli_a = cd.LighthouseClient(addr, timedelta(seconds=1), room_id="jobA")
+            cli_b = cd.LighthouseClient(addr, timedelta(seconds=1), room_id="jobB")
+
+            # Explicit heartbeat so each room has one participant
+            cli_a.heartbeat("a0")
+            cli_b.heartbeat("b0")
+
+            q_a = cli_a.quorum("a0", timedelta(seconds=1))
+            q_b = cli_b.quorum("b0", timedelta(seconds=1))
+
+            # Both rooms got a quorum id of 1, but with disjoint members
+            self.assertEqual(q_a.quorum_id, 1)
+            self.assertEqual(q_b.quorum_id, 1)
+
+            self.assertEqual(len(q_a.participants), 1)
+            self.assertEqual(len(q_b.participants), 1)
+            self.assertEqual(q_a.participants[0].replica_id, "a0")
+            self.assertEqual(q_b.participants[0].replica_id, "b0")
+
+        finally:
+            server.shutdown()
```

torchft/multi_quorum_test.py

Lines changed: 0 additions & 44 deletions
This file was deleted.
