|
1 | | -use std::sync::Arc; |
| 1 | +use std::{ |
| 2 | + convert::Infallible, |
| 3 | + future::Future, |
| 4 | + pin::Pin, |
| 5 | + sync::Arc, |
| 6 | + task::{Context, Poll}, |
| 7 | +}; |
2 | 8 |
|
3 | 9 | use dashmap::{mapref::entry::Entry, DashMap}; |
4 | | -use tonic::{Request, Response, Status}; |
| 10 | +use futures::FutureExt; |
| 11 | +use tonic::{ |
| 12 | + body::BoxBody, |
| 13 | + codegen::http::{HeaderMap, Request, Response}, // http-0.2 types |
| 14 | + server::NamedService, |
| 15 | +}; |
| 16 | +use tower::Service; |
5 | 17 |
|
6 | 18 | use crate::{ |
7 | 19 | lighthouse::{Lighthouse, LighthouseOpt}, |
8 | | - torchftpb::{ |
9 | | - lighthouse_service_server::LighthouseService, LighthouseHeartbeatRequest, |
10 | | - LighthouseHeartbeatResponse, LighthouseQuorumRequest, LighthouseQuorumResponse, |
11 | | - }, |
| 20 | + torchftpb::lighthouse_service_server::LighthouseServiceServer, |
12 | 21 | }; |
13 | 22 |
|
14 | | -/// Metadata header for both client and router |
| 23 | +/// Metadata header recognised by both client interceptor and this router. |
15 | 24 | pub const ROOM_ID_HEADER: &str = "room-id"; |
16 | 25 |
|
17 | | -/// Top-level service registered with tonic’s `Server::builder()` |
| 26 | +/// gRPC server for a single room (inner state = `Arc<Lighthouse>`). |
| 27 | +type GrpcSvc = LighthouseServiceServer<Arc<Lighthouse>>; |
| 28 | + |
18 | 29 | #[derive(Clone)] |
19 | 30 | pub struct Router { |
20 | | - rooms: Arc<DashMap<String, Arc<Lighthouse>>>, |
21 | | - tmpl_opt: LighthouseOpt, // (cloned for each new room) |
| 31 | + rooms: Arc<DashMap<String, Arc<GrpcSvc>>>, |
| 32 | + tmpl_opt: LighthouseOpt, |
22 | 33 | } |
23 | 34 |
|
24 | | -/// Designates a single tonic gRPC server into many logical “rooms.” |
25 | | -/// Inspects the `room-id` metadata header on each request, then |
26 | | -/// lazily creates or reuses an Arc<Lighthouse> for that namespace |
27 | 35 | impl Router { |
28 | | - /// Create a new router given the CLI/config options that are |
29 | | - /// normally passed straight to `Lighthouse::new`. |
30 | 36 | pub fn new(tmpl_opt: LighthouseOpt) -> Self { |
31 | 37 | Self { |
32 | 38 | rooms: Arc::new(DashMap::new()), |
33 | 39 | tmpl_opt, |
34 | 40 | } |
35 | 41 | } |
36 | 42 |
|
37 | | - /// Room lookup: creation if it doesn't exist, access if it does |
38 | | - async fn room(&self, id: &str) -> Arc<Lighthouse> { |
39 | | - // 1. Quick optimistic read (no locking contention). |
40 | | - if let Some(handle) = self.rooms.get(id) { |
41 | | - return handle.clone(); |
| 43 | + fn room_id(hdrs: &HeaderMap) -> &str { |
| 44 | + hdrs.get(ROOM_ID_HEADER) |
| 45 | + .and_then(|v| v.to_str().ok()) |
| 46 | + .unwrap_or("default") |
| 47 | + } |
| 48 | + |
| 49 | + async fn room_service( |
| 50 | + rooms: Arc<DashMap<String, Arc<GrpcSvc>>>, |
| 51 | + tmpl: LighthouseOpt, |
| 52 | + id: &str, |
| 53 | + ) -> Arc<GrpcSvc> { |
| 54 | + if let Some(svc) = rooms.get(id) { |
| 55 | + return svc.clone(); |
42 | 56 | } |
43 | 57 |
|
44 | | - // 2. Build the Lighthouse instance *off the map* so |
45 | | - // we don't hold any guard across `.await`. |
46 | | - let new_room = Lighthouse::new(self.tmpl_opt.clone()) |
| 58 | + // Build room state once. |
| 59 | + let lh = Lighthouse::new(tmpl.clone()) |
47 | 60 | .await |
48 | 61 | .expect("failed to create Lighthouse"); |
49 | 62 |
|
50 | | - // 3. Second pass: insert if still vacant, otherwise reuse |
51 | | - // whatever another task inserted first. |
52 | | - match self.rooms.entry(id.to_owned()) { |
53 | | - Entry::Occupied(entry) => entry.get().clone(), |
54 | | - Entry::Vacant(entry) => { |
55 | | - entry.insert(new_room.clone()); |
56 | | - new_room |
| 63 | + let svc_new = Arc::new(LighthouseServiceServer::new(lh)); |
| 64 | + |
| 65 | + match rooms.entry(id.to_owned()) { |
| 66 | + Entry::Occupied(e) => e.get().clone(), |
| 67 | + Entry::Vacant(v) => { |
| 68 | + v.insert(svc_new.clone()); |
| 69 | + svc_new |
57 | 70 | } |
58 | 71 | } |
59 | 72 | } |
60 | | - |
61 | | - /// Extracts `"room-id"` from metadata, defaulting to `"default"`. |
62 | | - fn extract_room_id(meta: &tonic::metadata::MetadataMap) -> &str { |
63 | | - meta.get(ROOM_ID_HEADER) |
64 | | - .and_then(|v| v.to_str().ok()) |
65 | | - .unwrap_or("default") |
66 | | - } |
67 | 73 | } |
68 | 74 |
|
69 | | -#[tonic::async_trait] |
70 | | -impl LighthouseService for Router { |
71 | | - async fn quorum( |
72 | | - &self, |
73 | | - req: Request<LighthouseQuorumRequest>, |
74 | | - ) -> Result<Response<LighthouseQuorumResponse>, Status> { |
75 | | - let id = Self::extract_room_id(req.metadata()).to_owned(); |
76 | | - let room = self.room(&id).await; |
77 | | - <Arc<Lighthouse> as LighthouseService>::quorum(&room, req).await |
| 75 | +// Tower::Service implementation |
| 76 | +impl Service<Request<BoxBody>> for Router { |
| 77 | + type Response = Response<BoxBody>; |
| 78 | + type Error = Infallible; |
| 79 | + type Future = |
| 80 | + Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; |
| 81 | + |
| 82 | + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 83 | + Poll::Ready(Ok(())) |
78 | 84 | } |
79 | 85 |
|
80 | | - async fn heartbeat( |
81 | | - &self, |
82 | | - req: Request<LighthouseHeartbeatRequest>, |
83 | | - ) -> Result<Response<LighthouseHeartbeatResponse>, Status> { |
84 | | - let id = Self::extract_room_id(req.metadata()).to_owned(); |
85 | | - let room = self.room(&id).await; |
86 | | - <Arc<Lighthouse> as LighthouseService>::heartbeat(&room, req).await |
| 86 | + fn call(&mut self, req: Request<BoxBody>) -> Self::Future { |
| 87 | + let rooms = self.rooms.clone(); |
| 88 | + let tmpl = self.tmpl_opt.clone(); |
| 89 | + let room = Self::room_id(req.headers()).to_owned(); |
| 90 | + |
| 91 | + async move { |
| 92 | + let svc_arc = Self::room_service(rooms, tmpl, &room).await; |
| 93 | + |
| 94 | + // `Arc<GrpcSvc>` itself isn’t a Service; clone the inner value. |
| 95 | + let mut svc = (*svc_arc).clone(); |
| 96 | + let resp = svc |
| 97 | + .call(req) |
| 98 | + .await |
| 99 | + .map_err(|_e| -> Infallible { unreachable!() })?; |
| 100 | + |
| 101 | + Ok(resp) |
| 102 | + } |
| 103 | + .boxed() |
87 | 104 | } |
88 | 105 | } |
| 106 | + |
| 107 | +// Forward tonic’s NamedService marker |
| 108 | +impl NamedService for Router { |
| 109 | + const NAME: &'static str = <GrpcSvc as NamedService>::NAME; |
| 110 | +} |
0 commit comments