Skip to content

feat: Accept handler #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ quinn = { package = "iroh-quinn", version = "0.12", optional = true }
serde = { version = "1.0.183", features = ["derive"] }
tokio = { version = "1", default-features = false, features = ["macros", "sync"] }
tokio-serde = { version = "0.8", features = ["bincode"], optional = true }
tokio-util = { version = "0.7", features = ["codec"], optional = true }
tokio-util = { version = "0.7", features = ["rt"] }
tracing = "0.1"
hex = "0.4.3"
futures = { version = "0.3.30", optional = true }
Expand All @@ -52,12 +52,13 @@ proc-macro2 = "1.0.66"
futures-buffered = "0.2.4"
testresult = "0.4.1"
nested_enum_utils = "0.1.0"
tokio-util = { version = "0.7", features = ["rt"] }

[features]
hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"]
quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"]
hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "tokio-util/codec"]
quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"]
flume-transport = ["dep:flume"]
iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"]
iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"]
macros = []
default = ["flume-transport"]

Expand Down
21 changes: 4 additions & 17 deletions examples/modularize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use app::AppService;
use futures_lite::StreamExt;
use futures_util::SinkExt;
use quic_rpc::{client::BoxedConnector, transport::flume, Listener, RpcClient, RpcServer};
use tracing::warn;

#[tokio::main]
async fn main() -> Result<()> {
Expand All @@ -32,23 +31,11 @@ async fn main() -> Result<()> {

async fn run_server<C: Listener<AppService>>(server_conn: C, handler: app::Handler) {
let server = RpcServer::<AppService, _>::new(server_conn);
loop {
let Ok(accepting) = server.accept().await else {
continue;
};
match accepting.read_first().await {
Err(err) => warn!(?err, "server accept failed"),
Ok((req, chan)) => {
let handler = handler.clone();
tokio::task::spawn(async move {
if let Err(err) = handler.handle_rpc_request(req, chan).await {
warn!(?err, "internal rpc error");
}
});
}
}
}
server
.accept_loop(move |req, chan| handler.clone().handle_rpc_request(req, chan))
.await
}

pub async fn client_demo(conn: BoxedConnector<AppService>) -> Result<()> {
let rpc_client = RpcClient::<AppService>::new(conn);
let client = app::Client::new(rpc_client.clone());
Expand Down
67 changes: 66 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ use std::{
marker::PhantomData,
pin::Pin,
result,
sync::Arc,
task::{self, Poll},
};

use futures_lite::{Future, Stream, StreamExt};
use futures_util::{SinkExt, TryStreamExt};
use pin_project::pin_project;
use tokio::sync::oneshot;
use tokio::{sync::oneshot, task::JoinSet};
use tokio_util::task::AbortOnDropHandle;
use tracing::{error, warn};

use crate::{
transport::{
Expand Down Expand Up @@ -211,6 +214,68 @@ impl<S: Service, C: Listener<S>> RpcServer<S, C> {
pub fn into_inner(self) -> C {
self.source
}

/// Run an accept loop for this server.
///
/// Each request will be handled in a separate task.
///
/// It is the caller's responsibility to poll the returned future to drive the server.
pub async fn accept_loop<Fun, Fut, E>(self, handler: Fun)
where
S: Service,
C: Listener<S>,
Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), E>> + Send + 'static,
E: Into<anyhow::Error> + 'static,
{
let handler = Arc::new(handler);
let mut tasks = JoinSet::new();
loop {
tokio::select! {
Some(res) = tasks.join_next(), if !tasks.is_empty() => {
if let Err(e) = res {
if e.is_panic() {
error!("Panic handling RPC request: {e}");
}
}
}
req = self.accept() => {
let req = match req {
Ok(req) => req,
Err(e) => {
warn!("Error accepting RPC request: {e}");
continue;
}
};
let handler = handler.clone();
tasks.spawn(async move {
let (req, chan) = match req.read_first().await {
Ok((req, chan)) => (req, chan),
Err(e) => {
warn!("Error reading first message: {e}");
return;
}
};
if let Err(cause) = handler(req, chan).await {
warn!("Error handling RPC request: {}", cause.into());
}
});
}
}
}
}

/// Spawn an accept loop and return a handle to the task.
pub fn spawn_accept_loop<Fun, Fut, E>(self, handler: Fun) -> AbortOnDropHandle<()>
where
S: Service,
C: Listener<S>,
Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), E>> + Send + 'static,
E: Into<anyhow::Error> + 'static,
{
AbortOnDropHandle::new(tokio::spawn(self.accept_loop(handler)))
}
}

impl<S: Service, C: Listener<S>> AsRef<C> for RpcServer<S, C> {
Expand Down
16 changes: 3 additions & 13 deletions tests/flume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@ use quic_rpc::{
transport::flume,
RpcClient, RpcServer, Service,
};
use tokio_util::task::AbortOnDropHandle;

#[tokio::test]
async fn flume_channel_bench() -> anyhow::Result<()> {
tracing_subscriber::fmt::try_init().ok();
let (server, client) = flume::channel(1);

let server = RpcServer::<ComputeService, _>::new(server);
let server_handle = tokio::task::spawn(ComputeService::server(server));
let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server)));
let client = RpcClient::<ComputeService, _>::new(client);
bench(client, 1000000).await?;
// dropping the client will cause the server to terminate
match server_handle.await? {
Err(RpcServerError::Accept(_)) => {}
e => panic!("unexpected termination result {e:?}"),
}
Ok(())
}

Expand Down Expand Up @@ -101,13 +97,7 @@ async fn flume_channel_smoke() -> anyhow::Result<()> {
let (server, client) = flume::channel(1);

let server = RpcServer::<ComputeService, _>::new(server);
let server_handle = tokio::task::spawn(ComputeService::server(server));
let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server)));
smoke_test(client).await?;

// dropping the client will cause the server to terminate
match server_handle.await? {
Err(RpcServerError::Accept(_)) => {}
e => panic!("unexpected termination result {e:?}"),
}
Ok(())
}
21 changes: 5 additions & 16 deletions tests/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,13 @@ use tokio::task::JoinHandle;

mod math;
use math::*;
use tokio_util::task::AbortOnDropHandle;
mod util;

fn run_server(addr: &SocketAddr) -> JoinHandle<anyhow::Result<()>> {
fn run_server(addr: &SocketAddr) -> AbortOnDropHandle<()> {
let channel = HyperListener::serve(addr).unwrap();
let server = RpcServer::new(channel);
tokio::spawn(async move {
loop {
let server = server.clone();
ComputeService::server(server).await?;
}
#[allow(unreachable_code)]
anyhow::Ok(())
})
ComputeService::server(server)
}

#[derive(Debug, Serialize, Deserialize, From, TryInto)]
Expand Down Expand Up @@ -133,25 +127,21 @@ impl TestService {
async fn hyper_channel_bench() -> anyhow::Result<()> {
let addr: SocketAddr = "127.0.0.1:3000".parse()?;
let uri: Uri = "http://127.0.0.1:3000".parse()?;
let server_handle = run_server(&addr);
let _server_handle = run_server(&addr);
let client = HyperConnector::new(uri);
let client = RpcClient::new(client);
bench(client, 50000).await?;
println!("terminating server");
server_handle.abort();
let _ = server_handle.await;
Ok(())
}

#[tokio::test]
async fn hyper_channel_smoke() -> anyhow::Result<()> {
let addr: SocketAddr = "127.0.0.1:3001".parse()?;
let uri: Uri = "http://127.0.0.1:3001".parse()?;
let server_handle = run_server(&addr);
let _server_handle = run_server(&addr);
let client = HyperConnector::new(uri);
smoke_test(client).await?;
server_handle.abort();
let _ = server_handle.await;
Ok(())
}

Expand Down Expand Up @@ -302,6 +292,5 @@ async fn hyper_channel_errors() -> anyhow::Result<()> {

println!("terminating server");
server_handle.abort();
let _ = server_handle.await;
Ok(())
}
38 changes: 15 additions & 23 deletions tests/iroh-net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

use iroh_net::{key::SecretKey, NodeAddr};
use quic_rpc::{transport, RpcClient, RpcServer};
use tokio::task::JoinHandle;
use testresult::TestResult;

use crate::transport::iroh_net::{IrohNetConnector, IrohNetListener};

mod math;
use math::*;
use tokio_util::task::AbortOnDropHandle;
mod util;

const ALPN: &[u8] = b"quic-rpc/iroh-net/test";
Expand Down Expand Up @@ -44,13 +47,10 @@ impl Endpoints {
}
}

fn run_server(server: iroh_net::Endpoint) -> JoinHandle<anyhow::Result<()>> {
tokio::task::spawn(async move {
let connection = transport::iroh_net::IrohNetListener::new(server)?;
let server = RpcServer::new(connection);
ComputeService::server(server).await?;
anyhow::Ok(())
})
fn run_server(server: iroh_net::Endpoint) -> AbortOnDropHandle<()> {
let connection = IrohNetListener::new(server).unwrap();
let server = RpcServer::new(connection);
ComputeService::server(server)
}

// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand All @@ -64,17 +64,12 @@ async fn iroh_net_channel_bench() -> anyhow::Result<()> {
server_node_addr,
} = Endpoints::new().await?;
tracing::debug!("Starting server");
let server_handle = run_server(server);
let _server_handle = run_server(server);
tracing::debug!("Starting client");

let client = RpcClient::new(transport::iroh_net::IrohNetConnector::new(
client,
server_node_addr,
ALPN.into(),
));
let client = RpcClient::new(IrohNetConnector::new(client, server_node_addr, ALPN.into()));
tracing::debug!("Starting benchmark");
bench(client, 50000).await?;
server_handle.abort();
Ok(())
}

Expand All @@ -86,11 +81,9 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> {
server,
server_node_addr,
} = Endpoints::new().await?;
let server_handle = run_server(server);
let client_connection =
transport::iroh_net::IrohNetConnector::new(client, server_node_addr, ALPN.into());
let _server_handle = run_server(server);
let client_connection = IrohNetConnector::new(client, server_node_addr, ALPN.into());
smoke_test(client_connection).await?;
server_handle.abort();
Ok(())
}

Expand All @@ -99,7 +92,7 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> {
///
/// This is a regression test.
#[tokio::test]
async fn server_away_and_back() -> anyhow::Result<()> {
async fn server_away_and_back() -> TestResult<()> {
tracing_subscriber::fmt::try_init().ok();
tracing::info!("Creating endpoints");

Expand Down Expand Up @@ -128,7 +121,7 @@ async fn server_away_and_back() -> anyhow::Result<()> {
// create the RPC Server
let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?;
let server = RpcServer::new(connection);
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1));
let server_handle = tokio::spawn(ComputeService::server_bounded(server, 1));

// wait a bit for connection due to Windows test failing on CI
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
Expand All @@ -151,7 +144,7 @@ async fn server_away_and_back() -> anyhow::Result<()> {
// make the server run again
let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?;
let server = RpcServer::new(connection);
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5));
let server_handle = tokio::spawn(ComputeService::server_bounded(server, 5));

// wait a bit for connection due to Windows test failing on CI
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
Expand All @@ -163,7 +156,6 @@ async fn server_away_and_back() -> anyhow::Result<()> {
// server is running, this should work
let SqrResponse(response) = client.rpc(Sqr(3)).await?;
assert_eq!(response, 9);

server_handle.abort();
Ok(())
}
23 changes: 9 additions & 14 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use quic_rpc::{
};
use serde::{Deserialize, Serialize};
use thousands::Separable;
use tokio_util::task::AbortOnDropHandle;

/// compute the square of a number
#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -163,20 +164,14 @@ impl ComputeService {
}
}

pub async fn server<C: Listener<ComputeService>>(
pub fn server<C: Listener<ComputeService>>(
server: RpcServer<ComputeService, C>,
) -> result::Result<(), RpcServerError<C>> {
let s = server;
let service = ComputeService;
loop {
let (req, chan) = s.accept().await?.read_first().await?;
let service = service.clone();
tokio::spawn(async move { Self::handle_rpc_request(service, req, chan).await });
}
) -> AbortOnDropHandle<()> {
server.spawn_accept_loop(|req, chan| Self::handle_rpc_request(ComputeService, req, chan))
}

pub async fn handle_rpc_request<E>(
service: ComputeService,
self,
req: ComputeRequest,
chan: RpcChannel<ComputeService, E>,
) -> Result<(), RpcServerError<E>>
Expand All @@ -186,10 +181,10 @@ impl ComputeService {
use ComputeRequest::*;
#[rustfmt::skip]
match req {
Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await,
Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await,
Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await,
Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await,
Sqr(msg) => chan.rpc(msg, self, Self::sqr).await,
Sum(msg) => chan.client_streaming(msg, self, Self::sum).await,
Fibonacci(msg) => chan.server_streaming(msg, self, Self::fibonacci).await,
Multiply(msg) => chan.bidi_streaming(msg, self, Self::multiply).await,
MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
}?;
Expand Down
Loading
Loading