From f9fc7cb603f6426e18267d580a96182f2882a59f Mon Sep 17 00:00:00 2001 From: lucasmerlin Date: Wed, 28 Aug 2024 11:21:21 +0200 Subject: [PATCH] Add connect info api (#12) --- examples/helloworld/src/server.rs | 20 +++++++++++++++----- transport/Cargo.toml | 2 +- transport/src/lib.rs | 18 +++++++++++------- transport/src/native.rs | 17 +++++++---------- transport/src/web.rs | 1 + 5 files changed, 35 insertions(+), 23 deletions(-) diff --git a/examples/helloworld/src/server.rs b/examples/helloworld/src/server.rs index c5ecc3d..3cf589a 100644 --- a/examples/helloworld/src/server.rs +++ b/examples/helloworld/src/server.rs @@ -1,13 +1,13 @@ use tonic_ws_transport::WsConnection; use futures_util::StreamExt; +use hello_world::greeter_server::{Greeter, GreeterServer}; +use hello_world::{HelloReply, HelloRequest}; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::server::TcpConnectInfo; use tonic::{transport::Server, Request, Response, Status}; -use hello_world::greeter_server::{Greeter, GreeterServer}; -use hello_world::{HelloReply, HelloRequest}; - pub mod hello_world { tonic::include_proto!("helloworld"); } @@ -21,7 +21,13 @@ impl Greeter for MyGreeter { &self, request: Request, ) -> Result, Status> { - println!("Got a request: {:?}", request); + let addr = request.extensions().get::().unwrap(); + + println!( + "Got a request: {:?} from ip {}", + request, + addr.remote_addr.unwrap() + ); let reply = hello_world::HelloReply { message: format!("Hello {}!", request.into_inner().name), @@ -39,6 +45,10 @@ async fn main() -> Result<(), Box> { let incoming = listener_stream.filter_map(|connection| async { match connection { Ok(tcp_stream) => { + let info = TcpConnectInfo { + local_addr: tcp_stream.local_addr().ok(), + remote_addr: tcp_stream.peer_addr().ok(), + }; let ws_stream = match tokio_tungstenite::accept_async(tcp_stream).await { Ok(ws_stream) => ws_stream, Err(e) => { @@ -46,7 +56,7 @@ async fn main() -> Result<(), Box> { return None; } }; - Some(Ok(WsConnection::from_combined_channel(ws_stream))) + Some(Ok(WsConnection::from_combined_channel(ws_stream, info))) } Err(e) => Some(Err(e)), } diff --git a/transport/Cargo.toml b/transport/Cargo.toml index f2ece5d..59b00a2 100644 --- a/transport/Cargo.toml +++ b/transport/Cargo.toml @@ -31,7 +31,7 @@ futures-util = { version = "0.3.12", default-features = false, features = ["sink http = "0.2.3" pin-project = "1.0.5" thiserror = "1.0.23" -tokio = { version = "=1.34.0", default-features = false, features = ["rt"] } +tokio = { version = "1.34.0", default-features = false, features = ["rt"] } tokio-util = { version = "0.6.3", default-features = false, features = ["io"] } tower = { version = "0.4.4", default-features = false, optional = true } tungstenite = { version = "0.20.0", default-features = false } diff --git a/transport/src/lib.rs b/transport/src/lib.rs index 3c19b90..138734c 100644 --- a/transport/src/lib.rs +++ b/transport/src/lib.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "native")] -pub use native::WsConnectionInfo; - use futures_util::{ready, sink::Sink}; use pin_project::pin_project; use thiserror::Error; @@ -100,7 +97,10 @@ impl WsConnector { } let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?; - Ok(WsConnection::from_combined_channel(ws_stream)) + Ok(WsConnection::from_combined_channel( + ws_stream, + EmptyConnectInfo, + )) } } @@ -140,18 +140,22 @@ impl Future for WsConnecting { } } +#[derive(Debug, Clone)] +pub struct EmptyConnectInfo; + #[pin_project] -pub struct WsConnection { +pub struct WsConnection { #[pin] pub(crate) sink: WsConnectionSink, #[pin] pub(crate) reader: WsConnectionReader, + pub(crate) info: CI, } type WsConnectionSink = Box + Unpin + Send>; type WsConnectionReader = Box; -impl AsyncWrite for WsConnection { +impl AsyncWrite for WsConnection { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let mut self_ = self.project(); ready!(self_.sink.as_mut().poll_ready(cx)?); @@ -175,7 +179,7 @@ impl AsyncWrite for WsConnection { } // forward AsyncRead impl to the `reader` field -impl AsyncRead for WsConnection { +impl AsyncRead for WsConnection { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, diff --git a/transport/src/native.rs b/transport/src/native.rs index 34fbedb..a909734 100644 --- a/transport/src/native.rs +++ b/transport/src/native.rs @@ -7,8 +7,8 @@ use tungstenite::{Error as TungsteniteError, Message}; use std::io; -impl WsConnection { - pub fn from_combined_channel(ws_stream: S) -> Self +impl WsConnection { + pub fn from_combined_channel(ws_stream: S, info: T) -> Self where S: Sink + Stream> @@ -37,23 +37,20 @@ impl WsConnection { Self { sink: Box::new(sink), reader, + info, } } } -#[derive(Clone)] -#[non_exhaustive] -pub struct WsConnectionInfo {} - -impl Connected for WsConnection { - type ConnectInfo = WsConnectionInfo; +impl Connected for WsConnection { + type ConnectInfo = T; fn connect_info(&self) -> Self::ConnectInfo { - WsConnectionInfo {} + self.info.clone() } } -impl hyper::client::connect::Connection for WsConnection { +impl hyper::client::connect::Connection for WsConnection { fn connected(&self) -> hyper::client::connect::Connected { hyper::client::connect::Connected::new() } diff --git a/transport/src/web.rs b/transport/src/web.rs index 355e15f..be225f9 100644 --- a/transport/src/web.rs +++ b/transport/src/web.rs @@ -43,6 +43,7 @@ pub async fn connect(dst: http::Uri) -> Result { Ok(super::WsConnection { sink: Box::new(messages_sink), reader: Box::new(tokio_util::io::StreamReader::new(bytes_stream)), + info: crate::EmptyConnectInfo, }) }