Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf/refactor: use tokio_util::sync::PollSender for ActiveSession -> SessionManager messages #4603

Merged
merged 12 commits into from
Sep 22, 2023
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/metrics/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ metrics.workspace = true
# async
tokio = { workspace = true, features = ["full"], optional = true }
futures = { workspace = true, optional = true }
tokio-util = { workspace = true, optional = true }

[features]
common = ["tokio", "futures"]
common = ["tokio", "futures", "tokio-util"]
63 changes: 63 additions & 0 deletions crates/metrics/src/common/mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use tokio::sync::mpsc::{
error::{SendError, TryRecvError, TrySendError},
OwnedPermit,
};
use tokio_util::sync::{PollSendError, PollSender};

/// Wrapper around [mpsc::unbounded_channel] that returns a new unbounded metered channel.
pub fn metered_unbounded_channel<T>(
Expand Down Expand Up @@ -265,3 +266,65 @@ struct MeteredReceiverMetrics {
/// Number of messages received
messages_received: Counter,
}

/// A wrapper type around [PollSender](PollSender) that updates metrics on send.
#[derive(Debug)]
pub struct MeteredPollSender<T> {
/// The [PollSender](PollSender) that this wraps around
sender: PollSender<T>,
/// Holds metrics for this type
metrics: MeteredPollSenderMetrics,
}

impl<T: Send + 'static> MeteredPollSender<T> {
/// Creates a new [`MeteredPollSender`] wrapping around the provided [PollSender](PollSender)
pub fn new(sender: PollSender<T>, scope: &'static str) -> Self {
Self { sender, metrics: MeteredPollSenderMetrics::new(scope) }
}

/// Returns the underlying [PollSender](PollSender).
pub fn inner(&self) -> &PollSender<T> {
&self.sender
}

/// Calls the underlying [PollSender](PollSender)'s `poll_reserve`, incrementing the appropriate
/// metrics depending on the result.
pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
match self.sender.poll_reserve(cx) {
Poll::Ready(Ok(permit)) => Poll::Ready(Ok(permit)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => {
self.metrics.back_pressure.increment(1);
Poll::Pending
}
}
}

/// Calls the underlying [PollSender](PollSender)'s `send_item`, incrementing the appropriate
/// metrics depending on the result.
pub fn send_item(&mut self, item: T) -> Result<(), PollSendError<T>> {
match self.sender.send_item(item) {
Ok(()) => {
self.metrics.messages_sent.increment(1);
Ok(())
}
Err(error) => Err(error),
}
}
}

impl<T> Clone for MeteredPollSender<T> {
fn clone(&self) -> Self {
Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
}
}

/// Throughput metrics for [MeteredPollSender]
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredPollSenderMetrics {
/// Number of messages sent
messages_sent: Counter,
/// Number of delayed message deliveries caused by a full channel
back_pressure: Counter,
}
62 changes: 30 additions & 32 deletions crates/net/network/src/session/active.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use reth_eth_wire::{
DisconnectReason, EthMessage, EthStream, P2PStream,
};
use reth_interfaces::p2p::error::RequestError;
use reth_metrics::common::mpsc::MeteredSender;
use reth_metrics::common::mpsc::MeteredPollSender;
use reth_net_common::bandwidth_meter::MeteredStream;
use reth_primitives::PeerId;
use std::{
Expand Down Expand Up @@ -77,7 +77,7 @@ pub(crate) struct ActiveSession {
/// Incoming commands from the manager
pub(crate) commands_rx: ReceiverStream<SessionCommand>,
/// Sink to send messages to the [`SessionManager`](super::SessionManager).
pub(crate) to_session_manager: MeteredSender<ActiveSessionMessage>,
pub(crate) to_session_manager: MeteredPollSender<ActiveSessionMessage>,
/// A message that needs to be delivered to the session manager
pub(crate) pending_message_to_session: Option<ActiveSessionMessage>,
/// Incoming internal requests which are delegated to the remote peer.
Expand Down Expand Up @@ -304,8 +304,9 @@ impl ActiveSession {
/// Returns the message if the bounded channel is currently unable to handle this message.
#[allow(clippy::result_large_err)]
fn try_emit_broadcast(&self, message: PeerMessage) -> Result<(), ActiveSessionMessage> {
match self
.to_session_manager
let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

match sender
.try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
{
Ok(_) => Ok(()),
Expand All @@ -329,8 +330,9 @@ impl ActiveSession {
/// Returns the message if the bounded channel is currently unable to handle this message.
#[allow(clippy::result_large_err)]
fn try_emit_request(&self, message: PeerMessage) -> Result<(), ActiveSessionMessage> {
match self
.to_session_manager
let Some(sender) = self.to_session_manager.inner().get_ref() else { return Ok(()) };

match sender
.try_send(ActiveSessionMessage::ValidMessage { peer_id: self.remote_peer_id, message })
{
Ok(_) => Ok(()),
Expand All @@ -354,9 +356,8 @@ impl ActiveSession {

/// Notify the manager that the peer sent a bad message
fn on_bad_message(&self) {
let _ = self
.to_session_manager
.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
let Some(sender) = self.to_session_manager.inner().get_ref() else { return };
let _ = sender.try_send(ActiveSessionMessage::BadMessage { peer_id: self.remote_peer_id });
}

/// Report back that this session has been closed.
Expand All @@ -367,8 +368,7 @@ impl ActiveSession {
remote_addr: self.remote_addr,
};

self.terminate_message =
Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg));
self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
self.poll_terminate_message(cx).expect("message is set")
}

Expand All @@ -379,8 +379,7 @@ impl ActiveSession {
remote_addr: self.remote_addr,
error,
};
self.terminate_message =
Some((PollSender::new(self.to_session_manager.inner().clone()).clone(), msg));
self.terminate_message = Some((self.to_session_manager.inner().clone(), msg));
self.poll_terminate_message(cx).expect("message is set")
}

Expand Down Expand Up @@ -575,22 +574,19 @@ impl Future for ActiveSession {
}

// try to resend the pending message that we could not send because the channel was
// full.
// full. [`PollSender`] will ensure that we're woken up again when the channel is
// ready to receive the message, and will only error if the channel is closed.
if let Some(msg) = this.pending_message_to_session.take() {
match this.to_session_manager.try_send(msg) {
Ok(_) => {}
Err(err) => {
match err {
TrySendError::Full(msg) => {
this.pending_message_to_session = Some(msg);
// ensure we're woken up again
cx.waker().wake_by_ref();
break 'receive
}
TrySendError::Closed(_) => {}
}
match this.to_session_manager.poll_reserve(cx) {
Poll::Ready(Ok(_)) => {
let _ = this.to_session_manager.send_item(msg);
}
}
Poll::Ready(Err(_)) => return Poll::Ready(()),
Poll::Pending => {
this.pending_message_to_session = Some(msg);
break 'receive
}
};
}

match this.conn.poll_next_unpin(cx) {
Expand Down Expand Up @@ -641,9 +637,10 @@ impl Future for ActiveSession {
while this.internal_request_timeout_interval.poll_tick(cx).is_ready() {
// check for timed out requests
if this.check_timed_out_requests(Instant::now()) {
let _ = this.to_session_manager.clone().try_send(
ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id },
);
if let Poll::Ready(Ok(_)) = this.to_session_manager.poll_reserve(cx) {
let msg = ActiveSessionMessage::ProtocolBreach { peer_id: this.remote_peer_id };
this.pending_message_to_session = Some(msg);
}
}
}

Expand Down Expand Up @@ -865,6 +862,7 @@ mod tests {
} => {
let (_to_session_tx, messages_rx) = mpsc::channel(10);
let (commands_to_session, commands_rx) = mpsc::channel(10);
let poll_sender = PollSender::new(self.active_session_tx.clone());

self.to_sessions.push(commands_to_session);

Expand All @@ -875,8 +873,8 @@ mod tests {
remote_capabilities: Arc::clone(&capabilities),
session_id,
commands_rx: ReceiverStream::new(commands_rx),
to_session_manager: MeteredSender::new(
self.active_session_tx.clone(),
to_session_manager: MeteredPollSender::new(
poll_sender,
"network_active_session",
),
pending_message_to_session: None,
Expand Down
8 changes: 5 additions & 3 deletions crates/net/network/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use reth_eth_wire::{
errors::EthStreamError,
DisconnectReason, EthVersion, HelloMessage, Status, UnauthedEthStream, UnauthedP2PStream,
};
use reth_metrics::common::mpsc::MeteredSender;
use reth_metrics::common::mpsc::MeteredPollSender;
use reth_net_common::{
bandwidth_meter::{BandwidthMeter, MeteredStream},
stream::HasRemoteAddr,
Expand All @@ -34,6 +34,7 @@ use tokio::{
sync::{mpsc, oneshot},
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::PollSender;
use tracing::{instrument, trace};

mod active;
Expand Down Expand Up @@ -95,7 +96,7 @@ pub struct SessionManager {
///
/// When active session state is reached, the corresponding [`ActiveSessionHandle`] will get a
/// clone of this sender half.
active_session_tx: MeteredSender<ActiveSessionMessage>,
active_session_tx: MeteredPollSender<ActiveSessionMessage>,
/// Receiver half that listens for [`ActiveSessionMessage`] produced by pending sessions.
active_session_rx: ReceiverStream<ActiveSessionMessage>,
/// Used to measure inbound & outbound bandwidth across all managed streams
Expand All @@ -119,6 +120,7 @@ impl SessionManager {
) -> Self {
let (pending_sessions_tx, pending_sessions_rx) = mpsc::channel(config.session_event_buffer);
let (active_session_tx, active_session_rx) = mpsc::channel(config.session_event_buffer);
let active_session_tx = PollSender::new(active_session_tx);

Self {
next_id: 0,
Expand All @@ -135,7 +137,7 @@ impl SessionManager {
active_sessions: Default::default(),
pending_sessions_tx,
pending_session_rx: ReceiverStream::new(pending_sessions_rx),
active_session_tx: MeteredSender::new(active_session_tx, "network_active_session"),
active_session_tx: MeteredPollSender::new(active_session_tx, "network_active_session"),
active_session_rx: ReceiverStream::new(active_session_rx),
bandwidth_meter,
metrics: Default::default(),
Expand Down
Loading