Skip to content

Commit 5ab4c0c

Browse files
author
Matt Kotzbauer
committed
Interceptor attached via LighthouseClient constructor rather than using add_room_header for each RPC call
1 parent eb482e5 commit 5ab4c0c

File tree

4 files changed

+64
-30
lines changed

4 files changed

+64
-30
lines changed

src/bin/lighthouse.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
use std::net::SocketAddr;
88
use structopt::StructOpt;
9+
use tonic::transport::Server;
910
use torchft::lighthouse::LighthouseOpt;
11+
use torchft::torchftpb::lighthouse_service_server::LighthouseServiceServer;
1012
use torchft::router::Router;
11-
use torchftpb::lighthouse_service_server::LighthouseServiceServer;
13+
1214

1315
#[tokio::main(flavor = "multi_thread", worker_threads = 4)]
1416
async fn main() {

src/interceptor.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use tonic::{service::Interceptor, metadata::MetadataValue, Request, Status};
2+
3+
/// Attaches user-assigned room-id header to every outbound RPC
4+
#[derive(Clone)]
5+
pub struct RoomIdInterceptor {
6+
room: String,
7+
}
8+
9+
impl RoomIdInterceptor {
10+
pub fn new(room: String) -> Self {
11+
Self { room }
12+
}
13+
}
14+
15+
impl Interceptor for RoomIdInterceptor {
16+
fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
17+
req.metadata_mut().insert(
18+
crate::router::ROOM_ID_HEADER,
19+
MetadataValue::try_from(self.room.as_str()).expect("ascii header"),
20+
);
21+
Ok(req)
22+
}
23+
}

src/interceptor.rs~

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use tonic::{Request, Status, service::Interceptor};
2+
use tonic::metadata::MetadataValue;
3+
4+
pub fn room_id_interceptor(room: String) -> impl Interceptor {
5+
move |mut req: Request<()>| {
6+
req.metadata_mut().insert(
7+
crate::router::ROOM_ID_HEADER,
8+
MetadataValue::try_from(room.as_str()).expect("ascii header"),
9+
);
10+
Ok(req) // returning Err(Status) would cancel the call :contentReference[oaicite:0]{index=0}
11+
}
12+
}

src/lib.rs

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
pub mod lighthouse;
88
pub mod manager;
9+
pub mod router;
910
mod net;
1011
mod retry;
11-
mod router;
1212
mod timeout;
13+
mod interceptor;
1314

1415
pub use crate::router::Router;
1516

@@ -25,8 +26,9 @@ use structopt::StructOpt;
2526
use tokio::runtime::Runtime;
2627
use tokio::task::JoinHandle;
2728
use tokio_stream::wrappers::TcpListenerStream;
28-
use tonic::transport::Channel;
29+
use tonic::transport::{Channel, Endpoint};
2930
use tonic::Status;
31+
use tonic::service::interceptor::InterceptedService;
3032

3133
use chrono::Local;
3234
use fern::colors::{Color, ColoredLevelConfig};
@@ -36,6 +38,7 @@ pub mod torchftpb {
3638
tonic::include_proto!("torchft");
3739
}
3840

41+
use crate::interceptor::RoomIdInterceptor;
3942
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
4043
use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer;
4144
use crate::torchftpb::manager_service_client::ManagerServiceClient;
@@ -486,9 +489,10 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
486489
/// connect_timeout (timedelta): The timeout for connecting to the lighthouse server.
487490
#[pyclass]
488491
struct LighthouseClient {
489-
client: LighthouseServiceClient<Channel>,
492+
client: LighthouseServiceClient<
493+
InterceptedService<Channel, RoomIdInterceptor>
494+
>,
490495
runtime: Runtime,
491-
room_id: Option<String>,
492496
}
493497

494498
#[pymethods]
@@ -507,14 +511,25 @@ impl LighthouseClient {
507511
.thread_name("torchft-lhclnt")
508512
.enable_all()
509513
.build()?;
510-
let client = runtime
511-
.block_on(manager::lighthouse_client_new(addr, connect_timeout))
514+
515+
let endpoint = Endpoint::from_shared(addr.clone())
512516
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
513-
Ok(Self {
514-
client: client,
515-
runtime: runtime,
516-
room_id: room_id,
517-
})
517+
let channel = runtime
518+
.block_on(
519+
endpoint
520+
.connect_timeout(connect_timeout)
521+
.connect(),
522+
)
523+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
524+
525+
let interceptor =
526+
RoomIdInterceptor::new(room_id.unwrap_or_else(|| "default".to_owned()));
527+
528+
let client =
529+
LighthouseServiceClient::with_interceptor(channel, interceptor);
530+
531+
Ok(Self { client, runtime })
532+
518533
})
519534
}
520535

@@ -569,8 +584,6 @@ impl LighthouseClient {
569584
}),
570585
});
571586

572-
let mut request = self.add_room_header(request);
573-
574587
// This timeout is processed on the server side so we also enable
575588
// keep alives to detect server health.
576589
request.set_timeout(timeout);
@@ -599,29 +612,13 @@ impl LighthouseClient {
599612
) -> Result<(), StatusError> {
600613
py.allow_threads(move || {
601614
let mut req = tonic::Request::new(LighthouseHeartbeatRequest { replica_id });
602-
let mut req = self.add_room_header(req);
603615
req.set_timeout(timeout);
604616
self.runtime.block_on(self.client.clone().heartbeat(req))?;
605617
Ok(())
606618
})
607619
}
608620
}
609621

610-
impl LighthouseClient {
611-
/// Attach `"room-id"` header if `self.room_id` is Some(_)
612-
fn add_room_header<T>(&self, mut req: tonic::Request<T>) -> tonic::Request<T> {
613-
if let Some(ref id) = self.room_id {
614-
use tonic::metadata::MetadataValue;
615-
req.metadata_mut().insert(
616-
crate::router::ROOM_ID_HEADER,
617-
MetadataValue::try_from(id.as_str()).expect("room-id ascii"),
618-
);
619-
}
620-
req
621-
}
622-
623-
}
624-
625622
/// LighthouseServer is a GRPC server for the lighthouse service.
626623
///
627624
/// It is used to coordinate the ManagerServer for each replica group.

0 commit comments

Comments
 (0)