diff --git a/Cargo.lock b/Cargo.lock index a7ba9e9d7f9a..563204d60453 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3018,6 +3018,7 @@ dependencies = [ "libp2p-core", "libp2p-identity", "libp2p-noise", + "libp2p-protocol-utils", "libp2p-swarm", "libp2p-swarm-test", "libp2p-tcp", diff --git a/protocols/request-response/Cargo.toml b/protocols/request-response/Cargo.toml index 1bfd03e1520f..cec43cc3ee62 100644 --- a/protocols/request-response/Cargo.toml +++ b/protocols/request-response/Cargo.toml @@ -18,6 +18,7 @@ instant = "0.1.12" libp2p-core = { workspace = true } libp2p-swarm = { workspace = true } libp2p-identity = { workspace = true } +libp2p-protocol-utils = { workspace = true } rand = "0.8" serde = { version = "1.0", optional = true} serde_json = { version = "1.0.108", optional = true } diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index 2d45e0d7dc30..a759022b1cb8 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -28,10 +28,8 @@ use crate::{InboundRequestId, OutboundRequestId, EMPTY_QUEUE_SHRINK_THRESHOLD}; use futures::channel::mpsc; use futures::{channel::oneshot, prelude::*}; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, - ListenUpgradeError, -}; +use libp2p_protocol_utils::InflightProtocolDataQueue; +use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound}; use libp2p_swarm::{ handler::{ConnectionHandler, ConnectionHandlerEvent, StreamUpgradeError}, SubstreamProtocol, @@ -47,6 +45,7 @@ use std::{ task::{Context, Poll}, time::Duration, }; +use void::Void; /// A connection handler for a request response [`Behaviour`](super::Behaviour) protocol. pub struct Handler @@ -59,10 +58,13 @@ where codec: TCodec, /// Queue of events to emit in `poll()`. pending_events: VecDeque>, - /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`. - pending_outbound: VecDeque>, - requested_outbound: VecDeque>, + pending_streams: InflightProtocolDataQueue< + (OutboundRequestId, TCodec::Request), + SmallVec<[TCodec::Protocol; 2]>, + Result<(libp2p_swarm::Stream, TCodec::Protocol), StreamUpgradeError>, + >, + /// A channel for receiving inbound requests. inbound_receiver: mpsc::Receiver<( InboundRequestId, @@ -102,8 +104,7 @@ where Self { inbound_protocols, codec, - pending_outbound: VecDeque::new(), - requested_outbound: Default::default(), + pending_streams: InflightProtocolDataQueue::default(), inbound_receiver, inbound_sender, pending_events: VecDeque::new(), @@ -167,92 +168,6 @@ where tracing::warn!("Dropping inbound stream because we are at capacity") } } - - fn on_fully_negotiated_outbound( - &mut self, - FullyNegotiatedOutbound { - protocol: (mut stream, protocol), - info: (), - }: FullyNegotiatedOutbound< - ::OutboundProtocol, - ::OutboundOpenInfo, - >, - ) { - let message = self - .requested_outbound - .pop_front() - .expect("negotiated a stream without a pending message"); - - let mut codec = self.codec.clone(); - let request_id = message.request_id; - - let send = async move { - let write = codec.write_request(&protocol, &mut stream, message.request); - write.await?; - stream.close().await?; - let read = codec.read_response(&protocol, &mut stream); - let response = read.await?; - - Ok(Event::Response { - request_id, - response, - }) - }; - - if self - .worker_streams - .try_push(RequestId::Outbound(request_id), send.boxed()) - .is_err() - { - tracing::warn!("Dropping outbound stream because we are at capacity") - } - } - - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { error, info: () }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - let message = self - .requested_outbound - .pop_front() - .expect("negotiated a stream without a pending message"); - - match error { - StreamUpgradeError::Timeout => { - self.pending_events - .push_back(Event::OutboundTimeout(message.request_id)); - } - StreamUpgradeError::NegotiationFailed => { - // The remote merely doesn't support the protocol(s) we requested. - // This is no reason to close the connection, which may - // successfully communicate with other protocols already. - // An event is reported to permit user code to react to the fact that - // the remote peer does not support the requested protocol(s). - self.pending_events - .push_back(Event::OutboundUnsupportedProtocols(message.request_id)); - } - StreamUpgradeError::Apply(e) => void::unreachable(e), - StreamUpgradeError::Io(e) => { - tracing::debug!( - "outbound stream for request {} failed: {e}, retrying", - message.request_id - ); - self.requested_outbound.push_back(message); - } - } - } - fn on_listen_upgrade_error( - &mut self, - ListenUpgradeError { error, .. }: ListenUpgradeError< - ::InboundOpenInfo, - ::InboundProtocol, - >, - ) { - void::unreachable(error) - } } /// The events emitted by the [`Handler`]. @@ -382,7 +297,14 @@ where } fn on_behaviour_event(&mut self, request: Self::FromBehaviour) { - self.pending_outbound.push_back(request); + let OutboundMessage { + request_id, + request, + protocols, + } = request; + + self.pending_streams + .enqueue_request(protocols, (request_id, request)); } #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))] @@ -390,74 +312,114 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll, (), Self::ToBehaviour>> { - match self.worker_streams.poll_unpin(cx) { - Poll::Ready((_, Ok(Ok(event)))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); - } - Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::InboundStreamFailed { - request_id: id, - error: e, - }, - )); - } - Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::OutboundStreamFailed { - request_id: id, - error: e, - }, - )); - } - Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::InboundTimeout(id), - )); - } - Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::OutboundTimeout(id), - )); + loop { + match self.worker_streams.poll_unpin(cx) { + Poll::Ready((_, Ok(Ok(event)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundStreamFailed { + request_id: id, + error: e, + }, + )); + } + Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundStreamFailed { + request_id: id, + error: e, + }, + )); + } + Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundTimeout(id), + )); + } + Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundTimeout(id), + )); + } + Poll::Pending => {} } - Poll::Pending => {} - } - - // Drain pending events that were produced by `worker_streams`. - if let Some(event) = self.pending_events.pop_front() { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); - } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { - self.pending_events.shrink_to_fit(); - } - - // Check for inbound requests. - if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) { - // We received an inbound request. - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request { - request_id: id, - request: rq, - sender: rs_sender, - })); - } + // Drain pending events that were produced by `worker_streams`. + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { + self.pending_events.shrink_to_fit(); + } - // Emit outbound requests. - if let Some(request) = self.pending_outbound.pop_front() { - let protocols = request.protocols.clone(); - self.requested_outbound.push_back(request); + // Check for inbound requests. + if let Poll::Ready(Some((id, rq, rs_sender))) = + self.inbound_receiver.poll_next_unpin(cx) + { + // We received an inbound request. + + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request { + request_id: id, + request: rq, + sender: rs_sender, + })); + } - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(Protocol { protocols }, ()), - }); - } + match self.pending_streams.next_completed() { + Some((Ok((mut stream, protocol)), (request_id, request))) => { + let mut codec = self.codec.clone(); + + let send = async move { + let write = codec.write_request(&protocol, &mut stream, request); + write.await?; + stream.close().await?; + let read = codec.read_response(&protocol, &mut stream); + let response = read.await?; + + Ok(Event::Response { + request_id, + response, + }) + }; + + if self + .worker_streams + .try_push(RequestId::Outbound(request_id), send.boxed()) + .is_err() + { + tracing::warn!("Dropping outbound stream because we are at capacity") + } + continue; + } + Some((Err(StreamUpgradeError::Timeout), (request_id, _))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundTimeout(request_id), + )); + } + Some((Err(StreamUpgradeError::NegotiationFailed), (request_id, _))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundUnsupportedProtocols(request_id), + )); + } + Some((Err(StreamUpgradeError::Io(error)), (request_id, _))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundStreamFailed { request_id, error }, + )); + } + Some((Err(StreamUpgradeError::Apply(void)), _)) => void::unreachable(void), + None => {} + } - debug_assert!(self.pending_outbound.is_empty()); + // Emit outbound requests. + if let Some(protocols) = self.pending_streams.next_request() { + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(Protocol { protocols }, ()), + }); + } - if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { - self.pending_outbound.shrink_to_fit(); + return Poll::Pending; } - - Poll::Pending } fn on_connection_event( @@ -473,15 +435,13 @@ where ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { self.on_fully_negotiated_inbound(fully_negotiated_inbound) } - ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => { - self.on_fully_negotiated_outbound(fully_negotiated_outbound) - } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + ConnectionEvent::FullyNegotiatedOutbound(ev) => { + self.pending_streams.submit_response(Ok(ev.protocol)); } - ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { - self.on_listen_upgrade_error(listen_upgrade_error) + ConnectionEvent::DialUpgradeError(ev) => { + self.pending_streams.submit_response(Err(ev.error)); } + ConnectionEvent::ListenUpgradeError(ev) => void::unreachable(ev.error), _ => {} } }