Skip to content

Commit 0a9ce34

Browse files
author
Matt Kotzbauer
committed
Tonic-level routing changed to tower-level in src/router.rs - Server::builder calls (in src/bin/lighthouser.rs, src/lib.rs) and torchft/multi_quorum_test.py modified to reflect change.
1 parent 5ab4c0c commit 0a9ce34

File tree

5 files changed

+119
-102
lines changed

5 files changed

+119
-102
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ axum = "0.7.7"
1111
chrono = "0.4.40"
1212
dashmap = "6.1"
1313
fern = {version = "0.7.1", features = ["colored"]}
14+
futures = "0.3"
1415
gethostname = "0.5.0"
16+
hyper = "0.14"
17+
http = "0.2"
1518
log = "0.4.22"
1619
prost = "0.13.3"
1720
prost-types = "0.13.3"
@@ -24,6 +27,7 @@ structopt = "0.3.26"
2427
tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
2528
tokio-stream = "0.1"
2629
tonic = "0.12.2"
30+
tower = "0.4"
2731

2832
[build-dependencies]
2933
tonic-build = "0.12.2"

src/bin/lighthouse.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@ use std::net::SocketAddr;
88
use structopt::StructOpt;
99
use tonic::transport::Server;
1010
use torchft::lighthouse::LighthouseOpt;
11-
use torchft::torchftpb::lighthouse_service_server::LighthouseServiceServer;
1211
use torchft::router::Router;
1312

14-
1513
#[tokio::main(flavor = "multi_thread", worker_threads = 4)]
1614
async fn main() {
1715
stderrlog::new()
@@ -23,9 +21,11 @@ async fn main() {
2321

2422
let opt = LighthouseOpt::from_args();
2523
let router = Router::new(opt.clone());
24+
let addr: SocketAddr = opt.bind.parse().expect("invalid --bind address");
25+
2626
Server::builder()
27-
.add_service(LighthouseServiceServer::new(router))
28-
.serve(opt.bind.parse::<SocketAddr>().unwrap())
27+
.add_service(router)
28+
.serve(addr)
2929
.await
3030
.unwrap();
3131
}

src/lib.rs

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
// This source code is licensed under the BSD-style license found in the
55
// LICENSE file in the root directory of this source tree.
66

7+
mod interceptor;
78
pub mod lighthouse;
89
pub mod manager;
9-
pub mod router;
1010
mod net;
1111
mod retry;
12+
pub mod router;
1213
mod timeout;
13-
mod interceptor;
1414

1515
pub use crate::router::Router;
1616

@@ -20,15 +20,16 @@ use core::time::Duration;
2020
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
2121
use std::cmp;
2222
use std::env;
23+
use std::net::SocketAddr;
2324
use std::sync::Arc;
2425
use std::thread::available_parallelism;
2526
use structopt::StructOpt;
2627
use tokio::runtime::Runtime;
2728
use tokio::task::JoinHandle;
2829
use tokio_stream::wrappers::TcpListenerStream;
30+
use tonic::service::interceptor::InterceptedService;
2931
use tonic::transport::{Channel, Endpoint};
3032
use tonic::Status;
31-
use tonic::service::interceptor::InterceptedService;
3233

3334
use chrono::Local;
3435
use fern::colors::{Color, ColoredLevelConfig};
@@ -40,9 +41,7 @@ pub mod torchftpb {
4041

4142
use crate::interceptor::RoomIdInterceptor;
4243
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
43-
use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer;
4444
use crate::torchftpb::manager_service_client::ManagerServiceClient;
45-
use crate::torchftpb::LighthouseHeartbeatRequest;
4645
use crate::torchftpb::{
4746
CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
4847
ManagerQuorumRequest, ShouldCommitRequest,
@@ -349,10 +348,11 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
349348

350349
async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
351350
let router = Router::new(opt.clone());
351+
let addr: SocketAddr = opt.bind.parse()?;
352352

353353
tonic::transport::Server::builder()
354-
.add_service(LighthouseServiceServer::new(router))
355-
.serve(opt.bind.parse::<std::net::SocketAddr>()?)
354+
.add_service(router)
355+
.serve(addr)
356356
.await?;
357357

358358
Ok(())
@@ -489,9 +489,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
489489
/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
490490
#[pyclass]
491491
struct LighthouseClient {
492-
client: LighthouseServiceClient<
493-
InterceptedService<Channel, RoomIdInterceptor>
494-
>,
492+
client: LighthouseServiceClient<InterceptedService<Channel, RoomIdInterceptor>>,
495493
runtime: Runtime,
496494
}
497495

@@ -515,21 +513,15 @@ impl LighthouseClient {
515513
let endpoint = Endpoint::from_shared(addr.clone())
516514
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
517515
let channel = runtime
518-
.block_on(
519-
endpoint
520-
.connect_timeout(connect_timeout)
521-
.connect(),
522-
)
516+
.block_on(endpoint.connect_timeout(connect_timeout).connect())
523517
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
524518

525519
let interceptor =
526520
RoomIdInterceptor::new(room_id.unwrap_or_else(|| "default".to_owned()));
527521

528-
let client =
529-
LighthouseServiceClient::with_interceptor(channel, interceptor);
522+
let client = LighthouseServiceClient::with_interceptor(channel, interceptor);
530523

531-
Ok(Self { client, runtime })
532-
524+
Ok(Self { client, runtime })
533525
})
534526
}
535527

@@ -674,10 +666,11 @@ impl LighthouseServer {
674666
let bound_sock = listener.local_addr()?;
675667
let bound = format!("http://{}", bound_sock);
676668
let incoming = TcpListenerStream::new(listener);
669+
let router = Router::new(opt.clone());
677670

678671
let handle = rt.spawn(async move {
679672
tonic::transport::Server::builder()
680-
.add_service(LighthouseServiceServer::new(Router::new(opt.clone())))
673+
.add_service(router)
681674
.serve_with_incoming(incoming)
682675
.await
683676
.map_err(|e: tonic::transport::Error| anyhow::anyhow!(e))

src/router.rs

Lines changed: 75 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,110 @@
1-
use std::sync::Arc;
1+
use std::{
2+
convert::Infallible,
3+
future::Future,
4+
pin::Pin,
5+
sync::Arc,
6+
task::{Context, Poll},
7+
};
28

39
use dashmap::{mapref::entry::Entry, DashMap};
4-
use tonic::{Request, Response, Status};
10+
use futures::FutureExt;
11+
use tonic::{
12+
body::BoxBody,
13+
codegen::http::{HeaderMap, Request, Response}, // http-0.2 types
14+
server::NamedService,
15+
};
16+
use tower::Service;
517

618
use crate::{
719
lighthouse::{Lighthouse, LighthouseOpt},
8-
torchftpb::{
9-
lighthouse_service_server::LighthouseService, LighthouseHeartbeatRequest,
10-
LighthouseHeartbeatResponse, LighthouseQuorumRequest, LighthouseQuorumResponse,
11-
},
20+
torchftpb::lighthouse_service_server::LighthouseServiceServer,
1221
};
1322

14-
/// Metadata header for both client and router
23+
/// Metadata header recognised by both client interceptor and this router.
1524
pub const ROOM_ID_HEADER: &str = "room-id";
1625

17-
/// Top-level service registered with tonic’s `Server::builder()`
26+
/// gRPC server for a single room (inner state = `Arc<Lighthouse>`).
27+
type GrpcSvc = LighthouseServiceServer<Arc<Lighthouse>>;
28+
1829
#[derive(Clone)]
1930
pub struct Router {
20-
rooms: Arc<DashMap<String, Arc<Lighthouse>>>,
21-
tmpl_opt: LighthouseOpt, // (cloned for each new room)
31+
rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
32+
tmpl_opt: LighthouseOpt,
2233
}
2334

24-
/// Designates a single tonic gRPC server into many logical “rooms.”
25-
/// Inspects the `room-id` metadata header on each request, then
26-
/// lazily creates or reuses an Arc<Lighthouse> for that namespace
2735
impl Router {
28-
/// Create a new router given the CLI/config options that are
29-
/// normally passed straight to `Lighthouse::new`.
3036
pub fn new(tmpl_opt: LighthouseOpt) -> Self {
3137
Self {
3238
rooms: Arc::new(DashMap::new()),
3339
tmpl_opt,
3440
}
3541
}
3642

37-
/// Room lookup: creation if it doesn't exist, access if it does
38-
async fn room(&self, id: &str) -> Arc<Lighthouse> {
39-
// 1. Quick optimistic read (no locking contention).
40-
if let Some(handle) = self.rooms.get(id) {
41-
return handle.clone();
43+
fn room_id(hdrs: &HeaderMap) -> &str {
44+
hdrs.get(ROOM_ID_HEADER)
45+
.and_then(|v| v.to_str().ok())
46+
.unwrap_or("default")
47+
}
48+
49+
async fn room_service(
50+
rooms: Arc<DashMap<String, Arc<GrpcSvc>>>,
51+
tmpl: LighthouseOpt,
52+
id: &str,
53+
) -> Arc<GrpcSvc> {
54+
if let Some(svc) = rooms.get(id) {
55+
return svc.clone();
4256
}
4357

44-
// 2. Build the Lighthouse instance *off the map* so
45-
// we don't hold any guard across `.await`.
46-
let new_room = Lighthouse::new(self.tmpl_opt.clone())
58+
// Build room state once.
59+
let lh = Lighthouse::new(tmpl.clone())
4760
.await
4861
.expect("failed to create Lighthouse");
4962

50-
// 3. Second pass: insert if still vacant, otherwise reuse
51-
// whatever another task inserted first.
52-
match self.rooms.entry(id.to_owned()) {
53-
Entry::Occupied(entry) => entry.get().clone(),
54-
Entry::Vacant(entry) => {
55-
entry.insert(new_room.clone());
56-
new_room
63+
let svc_new = Arc::new(LighthouseServiceServer::new(lh));
64+
65+
match rooms.entry(id.to_owned()) {
66+
Entry::Occupied(e) => e.get().clone(),
67+
Entry::Vacant(v) => {
68+
v.insert(svc_new.clone());
69+
svc_new
5770
}
5871
}
5972
}
60-
61-
/// Extracts `"room-id"` from metadata, defaulting to `"default"`.
62-
fn extract_room_id(meta: &tonic::metadata::MetadataMap) -> &str {
63-
meta.get(ROOM_ID_HEADER)
64-
.and_then(|v| v.to_str().ok())
65-
.unwrap_or("default")
66-
}
6773
}
6874

69-
#[tonic::async_trait]
70-
impl LighthouseService for Router {
71-
async fn quorum(
72-
&self,
73-
req: Request<LighthouseQuorumRequest>,
74-
) -> Result<Response<LighthouseQuorumResponse>, Status> {
75-
let id = Self::extract_room_id(req.metadata()).to_owned();
76-
let room = self.room(&id).await;
77-
<Arc<Lighthouse> as LighthouseService>::quorum(&room, req).await
75+
// Tower::Service implementation
76+
impl Service<Request<BoxBody>> for Router {
77+
type Response = Response<BoxBody>;
78+
type Error = Infallible;
79+
type Future =
80+
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
81+
82+
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83+
Poll::Ready(Ok(()))
7884
}
7985

80-
async fn heartbeat(
81-
&self,
82-
req: Request<LighthouseHeartbeatRequest>,
83-
) -> Result<Response<LighthouseHeartbeatResponse>, Status> {
84-
let id = Self::extract_room_id(req.metadata()).to_owned();
85-
let room = self.room(&id).await;
86-
<Arc<Lighthouse> as LighthouseService>::heartbeat(&room, req).await
86+
fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
87+
let rooms = self.rooms.clone();
88+
let tmpl = self.tmpl_opt.clone();
89+
let room = Self::room_id(req.headers()).to_owned();
90+
91+
async move {
92+
let svc_arc = Self::room_service(rooms, tmpl, &room).await;
93+
94+
// `Arc<GrpcSvc>` itself isn’t a Service; clone the inner value.
95+
let mut svc = (*svc_arc).clone();
96+
let resp = svc
97+
.call(req)
98+
.await
99+
.map_err(|_e| -> Infallible { unreachable!() })?;
100+
101+
Ok(resp)
102+
}
103+
.boxed()
87104
}
88105
}
106+
107+
// Forward tonic’s NamedService marker
108+
impl NamedService for Router {
109+
const NAME: &'static str = <GrpcSvc as NamedService>::NAME;
110+
}

torchft/multi_quorum_test.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,44 @@
1-
"""
2-
Validate that one Lighthouse server can host isolated quorums
3-
for multiple logical rooms (job IDs) via `room-id` metadata header.
4-
"""
5-
61
from __future__ import annotations
72

83
import datetime as _dt
4+
import time
95

106
import pytest
117

128
import torchft._torchft as ext
139

14-
_TIMEOUT = _dt.timedelta(seconds=3) # connect + RPC timeout
15-
10+
_TIMEOUT = _dt.timedelta(seconds=3)
1611

1712
def _client(addr: str, room: str) -> ext.LighthouseClient:
18-
"""Utility: create a client with a logical room-id."""
13+
"""Helper: create a LighthouseClient bound to a logical room."""
1914
return ext.LighthouseClient(addr, _TIMEOUT, room)
2015

2116

2217
@pytest.mark.asyncio
2318
async def test_multi_room_quorums() -> None:
24-
# 1) one server, any free port
25-
server = ext.LighthouseServer("[::]:0", 1)
26-
addr = server.address()
19+
# 1) Launch one Lighthouse server on any free port
20+
server = ext.LighthouseServer("[::]:0", min_replicas=1)
21+
addr: str = server.address()
22+
23+
# (give the Tokio runtime a tick to bind the listener)
24+
time.sleep(0.1)
2725

28-
# 2) two clients in two separate rooms
29-
a = _client(addr, "jobA")
30-
b = _client(addr, "jobB")
26+
# 2) Two clients, each in its own room
27+
cli_a = _client(addr, "jobA")
28+
cli_b = _client(addr, "jobB")
3129

32-
# 3) explicit heartbeats (exercises RPC path)
33-
a.heartbeat("a0")
34-
b.heartbeat("b0")
30+
# 3) Explicit heart-beats (exercise the RPC path)
31+
cli_a.heartbeat("a0")
32+
cli_b.heartbeat("b0")
3533

36-
# 4) ask for a quorum from each room
37-
qa = a.quorum("a0", _TIMEOUT)
38-
qb = b.quorum("b0", _TIMEOUT)
34+
# 4) Ask each room for a quorum
35+
q_a = cli_a.quorum("a0", _TIMEOUT)
36+
q_b = cli_b.quorum("b0", _TIMEOUT)
3937

40-
# 5) verify the rooms are independent
41-
assert qa.quorum_id == qb.quorum_id == 1
42-
assert len(qa.participants) == 1 and qa.participants[0].replica_id == "a0"
43-
assert len(qb.participants) == 1 and qb.participants[0].replica_id == "b0"
38+
# 5) Assert the rooms are isolated
39+
assert q_a.quorum_id == q_b.quorum_id == 1
40+
assert len(q_a.participants) == 1 and q_a.participants[0].replica_id == "a0"
41+
assert len(q_b.participants) == 1 and q_b.participants[0].replica_id == "b0"
4442

45-
# 6) shutdown
43+
# 6) Clean shutdown
4644
server.shutdown()

0 commit comments

Comments
 (0)