Skip to content

Commit

Permalink
Add connect info api (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasmerlin authored Aug 28, 2024
1 parent da082ea commit f9fc7cb
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 23 deletions.
20 changes: 15 additions & 5 deletions examples/helloworld/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use tonic_ws_transport::WsConnection;

use futures_util::StreamExt;
use hello_world::greeter_server::{Greeter, GreeterServer};
use hello_world::{HelloReply, HelloRequest};
use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::transport::server::TcpConnectInfo;
use tonic::{transport::Server, Request, Response, Status};

use hello_world::greeter_server::{Greeter, GreeterServer};
use hello_world::{HelloReply, HelloRequest};

pub mod hello_world {
tonic::include_proto!("helloworld");
}
Expand All @@ -21,7 +21,13 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request: {:?}", request);
let addr = request.extensions().get::<TcpConnectInfo>().unwrap();

println!(
"Got a request: {:?} from ip {}",
request,
addr.remote_addr.unwrap()
);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
Expand All @@ -39,14 +45,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let incoming = listener_stream.filter_map(|connection| async {
match connection {
Ok(tcp_stream) => {
let info = TcpConnectInfo {
local_addr: tcp_stream.local_addr().ok(),
remote_addr: tcp_stream.peer_addr().ok(),
};
let ws_stream = match tokio_tungstenite::accept_async(tcp_stream).await {
Ok(ws_stream) => ws_stream,
Err(e) => {
eprintln!("failed to accept connection: {e}");
return None;
}
};
Some(Ok(WsConnection::from_combined_channel(ws_stream)))
Some(Ok(WsConnection::from_combined_channel(ws_stream, info)))
}
Err(e) => Some(Err(e)),
}
Expand Down
2 changes: 1 addition & 1 deletion transport/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ futures-util = { version = "0.3.12", default-features = false, features = ["sink
http = "0.2.3"
pin-project = "1.0.5"
thiserror = "1.0.23"
tokio = { version = "=1.34.0", default-features = false, features = ["rt"] }
tokio = { version = "1.34.0", default-features = false, features = ["rt"] }
tokio-util = { version = "0.6.3", default-features = false, features = ["io"] }
tower = { version = "0.4.4", default-features = false, optional = true }
tungstenite = { version = "0.20.0", default-features = false }
Expand Down
18 changes: 11 additions & 7 deletions transport/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#[cfg(feature = "native")]
pub use native::WsConnectionInfo;

use futures_util::{ready, sink::Sink};
use pin_project::pin_project;
use thiserror::Error;
Expand Down Expand Up @@ -100,7 +97,10 @@ impl WsConnector {
}

let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
Ok(WsConnection::from_combined_channel(ws_stream))
Ok(WsConnection::from_combined_channel(
ws_stream,
EmptyConnectInfo,
))
}
}

Expand Down Expand Up @@ -140,18 +140,22 @@ impl Future for WsConnecting {
}
}

#[derive(Debug, Clone)]
pub struct EmptyConnectInfo;

#[pin_project]
pub struct WsConnection {
pub struct WsConnection<CI = EmptyConnectInfo> {
#[pin]
pub(crate) sink: WsConnectionSink,
#[pin]
pub(crate) reader: WsConnectionReader,
pub(crate) info: CI,
}

type WsConnectionSink = Box<dyn Sink<Message, Error = Error> + Unpin + Send>;
type WsConnectionReader = Box<dyn AsyncRead + Unpin + Send>;

impl AsyncWrite for WsConnection {
impl<T> AsyncWrite for WsConnection<T> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let mut self_ = self.project();
ready!(self_.sink.as_mut().poll_ready(cx)?);
Expand All @@ -175,7 +179,7 @@ impl AsyncWrite for WsConnection {
}

// forward AsyncRead impl to the `reader` field
impl AsyncRead for WsConnection {
impl<T> AsyncRead for WsConnection<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
Expand Down
17 changes: 7 additions & 10 deletions transport/src/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use tungstenite::{Error as TungsteniteError, Message};

use std::io;

impl WsConnection {
pub fn from_combined_channel<S>(ws_stream: S) -> Self
impl<T> WsConnection<T> {
pub fn from_combined_channel<S>(ws_stream: S, info: T) -> Self
where
S: Sink<Message, Error = TungsteniteError>
+ Stream<Item = Result<Message, TungsteniteError>>
Expand Down Expand Up @@ -37,23 +37,20 @@ impl WsConnection {
Self {
sink: Box::new(sink),
reader,
info,
}
}
}

#[derive(Clone)]
#[non_exhaustive]
pub struct WsConnectionInfo {}

impl Connected for WsConnection {
type ConnectInfo = WsConnectionInfo;
impl<T: Clone + Send + Sync + 'static> Connected for WsConnection<T> {
type ConnectInfo = T;

fn connect_info(&self) -> Self::ConnectInfo {
WsConnectionInfo {}
self.info.clone()
}
}

impl hyper::client::connect::Connection for WsConnection {
impl<T> hyper::client::connect::Connection for WsConnection<T> {
fn connected(&self) -> hyper::client::connect::Connected {
hyper::client::connect::Connected::new()
}
Expand Down
1 change: 1 addition & 0 deletions transport/src/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub async fn connect(dst: http::Uri) -> Result<super::WsConnection, Error> {
Ok(super::WsConnection {
sink: Box::new(messages_sink),
reader: Box::new(tokio_util::io::StreamReader::new(bytes_stream)),
info: crate::EmptyConnectInfo,
})
}

Expand Down

0 comments on commit f9fc7cb

Please sign in to comment.