From fd7b5ac29c88253260293571e0b0424c05c1a023 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 27 Oct 2023 04:10:31 +1100 Subject: [PATCH] refactor(request-response): don't use upgrade infrastructure This patch refactors `libp2p-request-response` to not use the "upgrade infrastructure" provided by `libp2p-swarm`. Instead, we directly convert the negotiated streams into futures that read and write the messages. Related: #3268. Related: #2863. Pull-Request: #3914. Co-authored-by: Yiannis Marangos --- Cargo.lock | 3 + examples/file-sharing/src/network.rs | 4 +- protocols/autonat/CHANGELOG.md | 3 + protocols/autonat/src/behaviour.rs | 16 +- protocols/autonat/src/behaviour/as_client.rs | 10 +- protocols/autonat/src/behaviour/as_server.rs | 8 +- protocols/autonat/tests/test_client.rs | 8 +- protocols/autonat/tests/test_server.rs | 9 +- protocols/rendezvous/src/client.rs | 14 +- protocols/request-response/CHANGELOG.md | 8 +- protocols/request-response/Cargo.toml | 3 + protocols/request-response/src/handler.rs | 369 ++++++++---- .../request-response/src/handler/protocol.rs | 117 +--- protocols/request-response/src/lib.rs | 226 ++++--- .../request-response/tests/error_reporting.rs | 555 ++++++++++++++++++ protocols/request-response/tests/ping.rs | 7 +- 16 files changed, 1030 insertions(+), 330 deletions(-) create mode 100644 protocols/request-response/tests/error_reporting.rs diff --git a/Cargo.lock b/Cargo.lock index 72b8ffed191..647e20395bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3007,11 +3007,14 @@ dependencies = [ name = "libp2p-request-response" version = "0.26.0" dependencies = [ + "anyhow", "async-std", "async-trait", "cbor4ii", "env_logger 0.10.0", "futures", + "futures-bounded", + "futures-timer", "futures_ringbuf", "instant", "libp2p-core", diff --git a/examples/file-sharing/src/network.rs b/examples/file-sharing/src/network.rs index 9245c2f8217..2ea16ef180c 100644 --- a/examples/file-sharing/src/network.rs +++ b/examples/file-sharing/src/network.rs @@ -8,7 +8,7 @@ use libp2p::{ identity, kad, multiaddr::Protocol, noise, - request_response::{self, ProtocolSupport, RequestId, ResponseChannel}, + request_response::{self, OutboundRequestId, ProtocolSupport, ResponseChannel}, swarm::{NetworkBehaviour, Swarm, SwarmEvent}, tcp, yamux, PeerId, }; @@ -175,7 +175,7 @@ pub(crate) struct EventLoop { pending_start_providing: HashMap>, pending_get_providers: HashMap>>, pending_request_file: - HashMap, Box>>>, + HashMap, Box>>>, } impl EventLoop { diff --git a/protocols/autonat/CHANGELOG.md b/protocols/autonat/CHANGELOG.md index 852e5da7b89..2b14598bd3e 100644 --- a/protocols/autonat/CHANGELOG.md +++ b/protocols/autonat/CHANGELOG.md @@ -1,5 +1,8 @@ ## 0.12.0 - unreleased +- Remove `Clone`, `PartialEq` and `Eq` implementations on `Event` and its sub-structs. + The `Event` also contains errors which are not clonable or comparable. + See [PR 3914](https://github.com/libp2p/rust-libp2p/pull/3914). ## 0.11.0 diff --git a/protocols/autonat/src/behaviour.rs b/protocols/autonat/src/behaviour.rs index d43ee224fc9..e9a73fd3fcb 100644 --- a/protocols/autonat/src/behaviour.rs +++ b/protocols/autonat/src/behaviour.rs @@ -32,7 +32,7 @@ use instant::Instant; use libp2p_core::{multiaddr::Protocol, ConnectedPoint, Endpoint, Multiaddr}; use libp2p_identity::PeerId; use libp2p_request_response::{ - self as request_response, ProtocolSupport, RequestId, ResponseChannel, + self as request_response, InboundRequestId, OutboundRequestId, ProtocolSupport, ResponseChannel, }; use libp2p_swarm::{ behaviour::{ @@ -133,7 +133,7 @@ impl ProbeId { } /// Event produced by [`Behaviour`]. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub enum Event { /// Event on an inbound probe. InboundProbe(InboundProbeEvent), @@ -187,14 +187,14 @@ pub struct Behaviour { PeerId, ( ProbeId, - RequestId, + InboundRequestId, Vec, ResponseChannel, ), >, // Ongoing outbound probes and mapped to the inner request id. - ongoing_outbound: HashMap, + ongoing_outbound: HashMap, // Connected peers with the observed address of each connection. // If the endpoint of a connection is relayed or not global (in case of Config::only_global_ips), @@ -220,9 +220,11 @@ pub struct Behaviour { impl Behaviour { pub fn new(local_peer_id: PeerId, config: Config) -> Self { let protocols = iter::once((DEFAULT_PROTOCOL_NAME, ProtocolSupport::Full)); - let mut cfg = request_response::Config::default(); - cfg.set_request_timeout(config.timeout); - let inner = request_response::Behaviour::with_codec(AutoNatCodec, protocols, cfg); + let inner = request_response::Behaviour::with_codec( + AutoNatCodec, + protocols, + request_response::Config::default().with_request_timeout(config.timeout), + ); Self { local_peer_id, inner, diff --git a/protocols/autonat/src/behaviour/as_client.rs b/protocols/autonat/src/behaviour/as_client.rs index 45608ea98fd..6f37d32620b 100644 --- a/protocols/autonat/src/behaviour/as_client.rs +++ b/protocols/autonat/src/behaviour/as_client.rs @@ -29,7 +29,7 @@ use futures_timer::Delay; use instant::Instant; use libp2p_core::Multiaddr; use libp2p_identity::PeerId; -use libp2p_request_response::{self as request_response, OutboundFailure, RequestId}; +use libp2p_request_response::{self as request_response, OutboundFailure, OutboundRequestId}; use libp2p_swarm::{ConnectionId, ListenAddresses, ToSwarm}; use rand::{seq::SliceRandom, thread_rng}; use std::{ @@ -39,7 +39,7 @@ use std::{ }; /// Outbound probe failed or was aborted. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub enum OutboundProbeError { /// Probe was aborted because no server is known, or all servers /// are throttled through [`Config::throttle_server_period`]. @@ -53,7 +53,7 @@ pub enum OutboundProbeError { Response(ResponseError), } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub enum OutboundProbeEvent { /// A dial-back request was sent to a remote peer. Request { @@ -91,7 +91,7 @@ pub(crate) struct AsClient<'a> { pub(crate) throttled_servers: &'a mut Vec<(PeerId, Instant)>, pub(crate) nat_status: &'a mut NatStatus, pub(crate) confidence: &'a mut usize, - pub(crate) ongoing_outbound: &'a mut HashMap, + pub(crate) ongoing_outbound: &'a mut HashMap, pub(crate) last_probe: &'a mut Option, pub(crate) schedule_probe: &'a mut Delay, pub(crate) listen_addresses: &'a ListenAddresses, @@ -117,7 +117,7 @@ impl<'a> HandleInnerEvent for AsClient<'a> { let probe_id = self .ongoing_outbound .remove(&request_id) - .expect("RequestId exists."); + .expect("OutboundRequestId exists."); let event = match response.result.clone() { Ok(address) => OutboundProbeEvent::Response { diff --git a/protocols/autonat/src/behaviour/as_server.rs b/protocols/autonat/src/behaviour/as_server.rs index b4c67a6a350..65c9738647e 100644 --- a/protocols/autonat/src/behaviour/as_server.rs +++ b/protocols/autonat/src/behaviour/as_server.rs @@ -26,7 +26,7 @@ use instant::Instant; use libp2p_core::{multiaddr::Protocol, Multiaddr}; use libp2p_identity::PeerId; use libp2p_request_response::{ - self as request_response, InboundFailure, RequestId, ResponseChannel, + self as request_response, InboundFailure, InboundRequestId, ResponseChannel, }; use libp2p_swarm::{ dial_opts::{DialOpts, PeerCondition}, @@ -38,7 +38,7 @@ use std::{ }; /// Inbound probe failed. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub enum InboundProbeError { /// Receiving the dial-back request or sending a response failed. InboundRequest(InboundFailure), @@ -46,7 +46,7 @@ pub enum InboundProbeError { Response(ResponseError), } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub enum InboundProbeEvent { /// A dial-back request was received from a remote peer. Request { @@ -85,7 +85,7 @@ pub(crate) struct AsServer<'a> { PeerId, ( ProbeId, - RequestId, + InboundRequestId, Vec, ResponseChannel, ), diff --git a/protocols/autonat/tests/test_client.rs b/protocols/autonat/tests/test_client.rs index 1911d1a6b2d..743f4cc1b51 100644 --- a/protocols/autonat/tests/test_client.rs +++ b/protocols/autonat/tests/test_client.rs @@ -61,7 +61,7 @@ async fn test_auto_probe() { match client.next_behaviour_event().await { Event::OutboundProbe(OutboundProbeEvent::Error { peer, error, .. }) => { assert!(peer.is_none()); - assert_eq!(error, OutboundProbeError::NoAddresses); + assert!(matches!(error, OutboundProbeError::NoAddresses)); } other => panic!("Unexpected behaviour event: {other:?}."), } @@ -181,10 +181,10 @@ async fn test_confidence() { peer, error, } if !test_public => { - assert_eq!( + assert!(matches!( error, OutboundProbeError::Response(ResponseError::DialError) - ); + )); (peer.unwrap(), probe_id) } other => panic!("Unexpected Outbound Event: {other:?}"), @@ -261,7 +261,7 @@ async fn test_throttle_server_period() { match client.next_behaviour_event().await { Event::OutboundProbe(OutboundProbeEvent::Error { peer, error, .. }) => { assert!(peer.is_none()); - assert_eq!(error, OutboundProbeError::NoServer); + assert!(matches!(error, OutboundProbeError::NoServer)); } other => panic!("Unexpected behaviour event: {other:?}."), } diff --git a/protocols/autonat/tests/test_server.rs b/protocols/autonat/tests/test_server.rs index c952179d42c..b0610ef59a4 100644 --- a/protocols/autonat/tests/test_server.rs +++ b/protocols/autonat/tests/test_server.rs @@ -168,7 +168,10 @@ async fn test_dial_error() { }) => { assert_eq!(probe_id, request_probe_id); assert_eq!(peer, client_id); - assert_eq!(error, InboundProbeError::Response(ResponseError::DialError)); + assert!(matches!( + error, + InboundProbeError::Response(ResponseError::DialError) + )); } other => panic!("Unexpected behaviour event: {other:?}."), } @@ -252,10 +255,10 @@ async fn test_throttle_peer_max() { }) => { assert_eq!(client_id, peer); assert_ne!(first_probe_id, probe_id); - assert_eq!( + assert!(matches!( error, InboundProbeError::Response(ResponseError::DialRefused) - ) + )); } other => panic!("Unexpected behaviour event: {other:?}."), }; diff --git a/protocols/rendezvous/src/client.rs b/protocols/rendezvous/src/client.rs index ec573e5ae4d..e4aedd9da7a 100644 --- a/protocols/rendezvous/src/client.rs +++ b/protocols/rendezvous/src/client.rs @@ -26,7 +26,7 @@ use futures::stream::FuturesUnordered; use futures::stream::StreamExt; use libp2p_core::{Endpoint, Multiaddr, PeerRecord}; use libp2p_identity::{Keypair, PeerId, SigningError}; -use libp2p_request_response::{ProtocolSupport, RequestId}; +use libp2p_request_response::{OutboundRequestId, ProtocolSupport}; use libp2p_swarm::{ ConnectionDenied, ConnectionId, ExternalAddresses, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm, @@ -41,8 +41,8 @@ pub struct Behaviour { keypair: Keypair, - waiting_for_register: HashMap, - waiting_for_discovery: HashMap)>, + waiting_for_register: HashMap, + waiting_for_discovery: HashMap)>, /// Hold addresses of all peers that we have discovered so far. /// @@ -336,7 +336,7 @@ impl NetworkBehaviour for Behaviour { } impl Behaviour { - fn event_for_outbound_failure(&mut self, req_id: &RequestId) -> Option { + fn event_for_outbound_failure(&mut self, req_id: &OutboundRequestId) -> Option { if let Some((rendezvous_node, namespace)) = self.waiting_for_register.remove(req_id) { return Some(Event::RegisterFailed { rendezvous_node, @@ -356,7 +356,11 @@ impl Behaviour { None } - fn handle_response(&mut self, request_id: &RequestId, response: Message) -> Option { + fn handle_response( + &mut self, + request_id: &OutboundRequestId, + response: Message, + ) -> Option { match response { RegisterResponse(Ok(ttl)) => { if let Some((rendezvous_node, namespace)) = diff --git a/protocols/request-response/CHANGELOG.md b/protocols/request-response/CHANGELOG.md index 34dc0704198..138401c2f50 100644 --- a/protocols/request-response/CHANGELOG.md +++ b/protocols/request-response/CHANGELOG.md @@ -2,7 +2,13 @@ - Remove `request_response::Config::set_connection_keep_alive` in favor of `SwarmBuilder::idle_connection_timeout`. See [PR 4679](https://github.com/libp2p/rust-libp2p/pull/4679). - +- Allow at most 100 concurrent inbound + outbound streams per instance of `request_response::Behaviour`. + This limit is configurable via `Config::with_max_concurrent_streams`. + See [PR 3914](https://github.com/libp2p/rust-libp2p/pull/3914). +- Report IO failures on inbound and outbound streams. + See [PR 3914](https://github.com/libp2p/rust-libp2p/pull/3914). +- Introduce dedicated types for `InboundRequestId` and `OutboundRequestId`. + See [PR 3914](https://github.com/libp2p/rust-libp2p/pull/3914). - Keep peer addresses in `HashSet` instead of `SmallVec` to prevent adding duplicate addresses. See [PR 4700](https://github.com/libp2p/rust-libp2p/pull/4700). diff --git a/protocols/request-response/Cargo.toml b/protocols/request-response/Cargo.toml index 823c7f49656..5c894bcd60f 100644 --- a/protocols/request-response/Cargo.toml +++ b/protocols/request-response/Cargo.toml @@ -24,12 +24,15 @@ serde_json = { version = "1.0.107", optional = true } smallvec = "1.11.1" void = "1.0.2" log = "0.4.20" +futures-timer = "3.0.2" +futures-bounded = { workspace = true } [features] json = ["dep:serde", "dep:serde_json", "libp2p-swarm/macros"] cbor = ["dep:serde", "dep:cbor4ii", "libp2p-swarm/macros"] [dev-dependencies] +anyhow = "1.0.75" async-std = { version = "1.6.2", features = ["attributes"] } env_logger = "0.10.0" libp2p-noise = { workspace = true } diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index 3a5fa8b0e61..f4f5bf96c6c 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -23,10 +23,11 @@ pub(crate) mod protocol; pub use protocol::ProtocolSupport; use crate::codec::Codec; -use crate::handler::protocol::{RequestProtocol, ResponseProtocol}; -use crate::{RequestId, EMPTY_QUEUE_SHRINK_THRESHOLD}; +use crate::handler::protocol::Protocol; +use crate::{InboundRequestId, OutboundRequestId, EMPTY_QUEUE_SHRINK_THRESHOLD}; -use futures::{channel::oneshot, future::BoxFuture, prelude::*, stream::FuturesUnordered}; +use futures::channel::mpsc; +use futures::{channel::oneshot, prelude::*}; use libp2p_swarm::handler::{ ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, @@ -38,7 +39,7 @@ use libp2p_swarm::{ use smallvec::SmallVec; use std::{ collections::VecDeque, - fmt, + fmt, io, sync::{ atomic::{AtomicU64, Ordering}, Arc, @@ -56,27 +57,34 @@ where inbound_protocols: SmallVec<[TCodec::Protocol; 2]>, /// The request/response message codec. codec: TCodec, - /// The timeout for inbound and outbound substreams (i.e. request - /// and response processing). - substream_timeout: Duration, /// Queue of events to emit in `poll()`. pending_events: VecDeque>, /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`. - outbound: VecDeque>, - /// Inbound upgrades waiting for the incoming request. - inbound: FuturesUnordered< - BoxFuture< - 'static, - Result< - ( - (RequestId, TCodec::Request), - oneshot::Sender, - ), - oneshot::Canceled, - >, - >, - >, + pending_outbound: VecDeque>, + + requested_outbound: VecDeque>, + /// A channel for receiving inbound requests. + inbound_receiver: mpsc::Receiver<( + InboundRequestId, + TCodec::Request, + oneshot::Sender, + )>, + /// The [`mpsc::Sender`] for the above receiver. Cloned for each inbound request. + inbound_sender: mpsc::Sender<( + InboundRequestId, + TCodec::Request, + oneshot::Sender, + )>, + inbound_request_id: Arc, + + worker_streams: futures_bounded::FuturesMap, io::Error>>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RequestId { + Inbound(InboundRequestId), + Outbound(OutboundRequestId), } impl Handler @@ -88,47 +96,134 @@ where codec: TCodec, substream_timeout: Duration, inbound_request_id: Arc, + max_concurrent_streams: usize, ) -> Self { + let (inbound_sender, inbound_receiver) = mpsc::channel(0); Self { inbound_protocols, codec, - substream_timeout, - outbound: VecDeque::new(), - inbound: FuturesUnordered::new(), + pending_outbound: VecDeque::new(), + requested_outbound: Default::default(), + inbound_receiver, + inbound_sender, pending_events: VecDeque::new(), inbound_request_id, + worker_streams: futures_bounded::FuturesMap::new( + substream_timeout, + max_concurrent_streams, + ), } } + /// Returns the next inbound request ID. + fn next_inbound_request_id(&mut self) -> InboundRequestId { + InboundRequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed)) + } + fn on_fully_negotiated_inbound( &mut self, FullyNegotiatedInbound { - protocol: sent, - info: request_id, + protocol: (mut stream, protocol), + info: (), }: FullyNegotiatedInbound< ::InboundProtocol, ::InboundOpenInfo, >, ) { - if sent { - self.pending_events - .push_back(Event::ResponseSent(request_id)) - } else { - self.pending_events - .push_back(Event::ResponseOmission(request_id)) + let mut codec = self.codec.clone(); + let request_id = self.next_inbound_request_id(); + let mut sender = self.inbound_sender.clone(); + + let recv = async move { + // A channel for notifying the inbound upgrade when the + // response is sent. + let (rs_send, rs_recv) = oneshot::channel(); + + let read = codec.read_request(&protocol, &mut stream); + let request = read.await?; + sender + .send((request_id, request, rs_send)) + .await + .expect("`ConnectionHandler` owns both ends of the channel"); + drop(sender); + + if let Ok(response) = rs_recv.await { + let write = codec.write_response(&protocol, &mut stream, response); + write.await?; + + stream.close().await?; + Ok(Event::ResponseSent(request_id)) + } else { + stream.close().await?; + Ok(Event::ResponseOmission(request_id)) + } + }; + + if self + .worker_streams + .try_push(RequestId::Inbound(request_id), recv.boxed()) + .is_err() + { + log::warn!("Dropping inbound stream because we are at capacity") + } + } + + fn on_fully_negotiated_outbound( + &mut self, + FullyNegotiatedOutbound { + protocol: (mut stream, protocol), + info: (), + }: FullyNegotiatedOutbound< + ::OutboundProtocol, + ::OutboundOpenInfo, + >, + ) { + let message = self + .requested_outbound + .pop_front() + .expect("negotiated a stream without a pending message"); + + let mut codec = self.codec.clone(); + let request_id = message.request_id; + + let send = async move { + let write = codec.write_request(&protocol, &mut stream, message.request); + write.await?; + stream.close().await?; + let read = codec.read_response(&protocol, &mut stream); + let response = read.await?; + + Ok(Event::Response { + request_id, + response, + }) + }; + + if self + .worker_streams + .try_push(RequestId::Outbound(request_id), send.boxed()) + .is_err() + { + log::warn!("Dropping outbound stream because we are at capacity") } } fn on_dial_upgrade_error( &mut self, - DialUpgradeError { info, error }: DialUpgradeError< + DialUpgradeError { error, info: () }: DialUpgradeError< ::OutboundOpenInfo, ::OutboundProtocol, >, ) { + let message = self + .requested_outbound + .pop_front() + .expect("negotiated a stream without a pending message"); + match error { StreamUpgradeError::Timeout => { - self.pending_events.push_back(Event::OutboundTimeout(info)); + self.pending_events + .push_back(Event::OutboundTimeout(message.request_id)); } StreamUpgradeError::NegotiationFailed => { // The remote merely doesn't support the protocol(s) we requested. @@ -137,24 +232,26 @@ where // An event is reported to permit user code to react to the fact that // the remote peer does not support the requested protocol(s). self.pending_events - .push_back(Event::OutboundUnsupportedProtocols(info)); - } - StreamUpgradeError::Apply(e) => { - log::debug!("outbound stream {info} failed: {e}"); + .push_back(Event::OutboundUnsupportedProtocols(message.request_id)); } + StreamUpgradeError::Apply(e) => void::unreachable(e), StreamUpgradeError::Io(e) => { - log::debug!("outbound stream {info} failed: {e}"); + log::debug!( + "outbound stream for request {} failed: {e}, retrying", + message.request_id + ); + self.requested_outbound.push_back(message); } } } fn on_listen_upgrade_error( &mut self, - ListenUpgradeError { error, info }: ListenUpgradeError< + ListenUpgradeError { error, .. }: ListenUpgradeError< ::InboundOpenInfo, ::InboundProtocol, >, ) { - log::debug!("inbound stream {info} failed: {error}"); + void::unreachable(error) } } @@ -165,25 +262,36 @@ where { /// A request has been received. Request { - request_id: RequestId, + request_id: InboundRequestId, request: TCodec::Request, sender: oneshot::Sender, }, /// A response has been received. Response { - request_id: RequestId, + request_id: OutboundRequestId, response: TCodec::Response, }, /// A response to an inbound request has been sent. - ResponseSent(RequestId), + ResponseSent(InboundRequestId), /// A response to an inbound request was omitted as a result /// of dropping the response `sender` of an inbound `Request`. - ResponseOmission(RequestId), + ResponseOmission(InboundRequestId), /// An outbound request timed out while sending the request /// or waiting for the response. - OutboundTimeout(RequestId), + OutboundTimeout(OutboundRequestId), /// An outbound request failed to negotiate a mutually supported protocol. - OutboundUnsupportedProtocols(RequestId), + OutboundUnsupportedProtocols(OutboundRequestId), + OutboundStreamFailed { + request_id: OutboundRequestId, + error: io::Error, + }, + /// An inbound request timed out while waiting for the request + /// or sending the response. + InboundTimeout(InboundRequestId), + InboundStreamFailed { + request_id: InboundRequestId, + error: io::Error, + }, } impl fmt::Debug for Event { @@ -220,67 +328,103 @@ impl fmt::Debug for Event { .debug_tuple("Event::OutboundUnsupportedProtocols") .field(request_id) .finish(), + Event::OutboundStreamFailed { request_id, error } => f + .debug_struct("Event::OutboundStreamFailed") + .field("request_id", &request_id) + .field("error", &error) + .finish(), + Event::InboundTimeout(request_id) => f + .debug_tuple("Event::InboundTimeout") + .field(request_id) + .finish(), + Event::InboundStreamFailed { request_id, error } => f + .debug_struct("Event::InboundStreamFailed") + .field("request_id", &request_id) + .field("error", &error) + .finish(), } } } +pub struct OutboundMessage { + pub(crate) request_id: OutboundRequestId, + pub(crate) request: TCodec::Request, + pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, +} + +impl fmt::Debug for OutboundMessage +where + TCodec: Codec, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OutboundMessage").finish_non_exhaustive() + } +} + impl ConnectionHandler for Handler where TCodec: Codec + Send + Clone + 'static, { - type FromBehaviour = RequestProtocol; + type FromBehaviour = OutboundMessage; type ToBehaviour = Event; type Error = void::Void; - type InboundProtocol = ResponseProtocol; - type OutboundProtocol = RequestProtocol; - type OutboundOpenInfo = RequestId; - type InboundOpenInfo = RequestId; + type InboundProtocol = Protocol; + type OutboundProtocol = Protocol; + type OutboundOpenInfo = (); + type InboundOpenInfo = (); fn listen_protocol(&self) -> SubstreamProtocol { - // A channel for notifying the handler when the inbound - // upgrade received the request. - let (rq_send, rq_recv) = oneshot::channel(); - - // A channel for notifying the inbound upgrade when the - // response is sent. - let (rs_send, rs_recv) = oneshot::channel(); - - let request_id = RequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed)); - - // By keeping all I/O inside the `ResponseProtocol` and thus the - // inbound substream upgrade via above channels, we ensure that it - // is all subject to the configured timeout without extra bookkeeping - // for inbound substreams as well as their timeouts and also make the - // implementation of inbound and outbound upgrades symmetric in - // this sense. - let proto = ResponseProtocol { - protocols: self.inbound_protocols.clone(), - codec: self.codec.clone(), - request_sender: rq_send, - response_receiver: rs_recv, - request_id, - }; - - // The handler waits for the request to come in. It then emits - // `Event::Request` together with a - // `ResponseChannel`. - self.inbound - .push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed()); - - SubstreamProtocol::new(proto, request_id).with_timeout(self.substream_timeout) + SubstreamProtocol::new( + Protocol { + protocols: self.inbound_protocols.clone(), + }, + (), + ) } fn on_behaviour_event(&mut self, request: Self::FromBehaviour) { - self.outbound.push_back(request); + self.pending_outbound.push_back(request); } fn poll( &mut self, cx: &mut Context<'_>, - ) -> Poll< - ConnectionHandlerEvent, RequestId, Self::ToBehaviour, Self::Error>, - > { - // Drain pending events. + ) -> Poll, (), Self::ToBehaviour, Self::Error>> + { + match self.worker_streams.poll_unpin(cx) { + Poll::Ready((_, Ok(Ok(event)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundStreamFailed { + request_id: id, + error: e, + }, + )); + } + Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundStreamFailed { + request_id: id, + error: e, + }, + )); + } + Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundTimeout(id), + )); + } + Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundTimeout(id), + )); + } + Poll::Pending => {} + } + + // Drain pending events that were produced by `worker_streams`. if let Some(event) = self.pending_events.pop_front() { return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { @@ -288,37 +432,30 @@ where } // Check for inbound requests. - while let Poll::Ready(Some(result)) = self.inbound.poll_next_unpin(cx) { - match result { - Ok(((id, rq), rs_sender)) => { - // We received an inbound request. - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request { - request_id: id, - request: rq, - sender: rs_sender, - })); - } - Err(oneshot::Canceled) => { - // The inbound upgrade has errored or timed out reading - // or waiting for the request. The handler is informed - // via `on_connection_event` call with `ConnectionEvent::ListenUpgradeError`. - } - } + if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) { + // We received an inbound request. + + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request { + request_id: id, + request: rq, + sender: rs_sender, + })); } // Emit outbound requests. - if let Some(request) = self.outbound.pop_front() { - let info = request.request_id; + if let Some(request) = self.pending_outbound.pop_front() { + let protocols = request.protocols.clone(); + self.requested_outbound.push_back(request); + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(request, info) - .with_timeout(self.substream_timeout), + protocol: SubstreamProtocol::new(Protocol { protocols }, ()), }); } - debug_assert!(self.outbound.is_empty()); + debug_assert!(self.pending_outbound.is_empty()); - if self.outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { - self.outbound.shrink_to_fit(); + if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { + self.pending_outbound.shrink_to_fit(); } Poll::Pending @@ -337,14 +474,8 @@ where ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { self.on_fully_negotiated_inbound(fully_negotiated_inbound) } - ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { - protocol: response, - info: request_id, - }) => { - self.pending_events.push_back(Event::Response { - request_id, - response, - }); + ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => { + self.on_fully_negotiated_outbound(fully_negotiated_outbound) } ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) diff --git a/protocols/request-response/src/handler/protocol.rs b/protocols/request-response/src/handler/protocol.rs index 1368a3c1f98..833cacdd6ce 100644 --- a/protocols/request-response/src/handler/protocol.rs +++ b/protocols/request-response/src/handler/protocol.rs @@ -23,14 +23,10 @@ //! receives a request and sends a response, whereas the //! outbound upgrade send a request and receives a response. -use crate::codec::Codec; -use crate::RequestId; - -use futures::{channel::oneshot, future::BoxFuture, prelude::*}; +use futures::future::{ready, Ready}; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use libp2p_swarm::Stream; use smallvec::SmallVec; -use std::{fmt, io}; /// The level of support for a particular protocol. #[derive(Debug, Clone)] @@ -65,22 +61,15 @@ impl ProtocolSupport { /// /// Receives a request and sends a response. #[derive(Debug)] -pub struct ResponseProtocol -where - TCodec: Codec, -{ - pub(crate) codec: TCodec, - pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, - pub(crate) request_sender: oneshot::Sender<(RequestId, TCodec::Request)>, - pub(crate) response_receiver: oneshot::Receiver, - pub(crate) request_id: RequestId, +pub struct Protocol

{ + pub(crate) protocols: SmallVec<[P; 2]>, } -impl UpgradeInfo for ResponseProtocol +impl

UpgradeInfo for Protocol

where - TCodec: Codec, + P: AsRef + Clone, { - type Info = TCodec::Protocol; + type Info = P; type InfoIter = smallvec::IntoIter<[Self::Info; 2]>; fn protocol_info(&self) -> Self::InfoIter { @@ -88,94 +77,28 @@ where } } -impl InboundUpgrade for ResponseProtocol +impl

InboundUpgrade for Protocol

where - TCodec: Codec + Send + 'static, + P: AsRef + Clone, { - type Output = bool; - type Error = io::Error; - type Future = BoxFuture<'static, Result>; - - fn upgrade_inbound(mut self, mut io: Stream, protocol: Self::Info) -> Self::Future { - async move { - let read = self.codec.read_request(&protocol, &mut io); - let request = read.await?; - match self.request_sender.send((self.request_id, request)) { - Ok(()) => {}, - Err(_) => panic!( - "Expect request receiver to be alive i.e. protocol handler to be alive.", - ), - } + type Output = (Stream, P); + type Error = void::Void; + type Future = Ready>; - if let Ok(response) = self.response_receiver.await { - let write = self.codec.write_response(&protocol, &mut io, response); - write.await?; - - io.close().await?; - // Response was sent. Indicate to handler to emit a `ResponseSent` event. - Ok(true) - } else { - io.close().await?; - // No response was sent. Indicate to handler to emit a `ResponseOmission` event. - Ok(false) - } - }.boxed() + fn upgrade_inbound(self, io: Stream, protocol: Self::Info) -> Self::Future { + ready(Ok((io, protocol))) } } -/// Request substream upgrade protocol. -/// -/// Sends a request and receives a response. -pub struct RequestProtocol +impl

OutboundUpgrade for Protocol

where - TCodec: Codec, + P: AsRef + Clone, { - pub(crate) codec: TCodec, - pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, - pub(crate) request_id: RequestId, - pub(crate) request: TCodec::Request, -} + type Output = (Stream, P); + type Error = void::Void; + type Future = Ready>; -impl fmt::Debug for RequestProtocol -where - TCodec: Codec, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RequestProtocol") - .field("request_id", &self.request_id) - .finish() - } -} - -impl UpgradeInfo for RequestProtocol -where - TCodec: Codec, -{ - type Info = TCodec::Protocol; - type InfoIter = smallvec::IntoIter<[Self::Info; 2]>; - - fn protocol_info(&self) -> Self::InfoIter { - self.protocols.clone().into_iter() - } -} - -impl OutboundUpgrade for RequestProtocol -where - TCodec: Codec + Send + 'static, -{ - type Output = TCodec::Response; - type Error = io::Error; - type Future = BoxFuture<'static, Result>; - - fn upgrade_outbound(mut self, mut io: Stream, protocol: Self::Info) -> Self::Future { - async move { - let write = self.codec.write_request(&protocol, &mut io, self.request); - write.await?; - io.close().await?; - let read = self.codec.read_response(&protocol, &mut io); - let response = read.await?; - Ok(response) - } - .boxed() + fn upgrade_outbound(self, io: Stream, protocol: Self::Info) -> Self::Future { + ready(Ok((io, protocol))) } } diff --git a/protocols/request-response/src/lib.rs b/protocols/request-response/src/lib.rs index 42aa12774ea..f036fb85956 100644 --- a/protocols/request-response/src/lib.rs +++ b/protocols/request-response/src/lib.rs @@ -76,7 +76,7 @@ pub mod json; pub use codec::Codec; pub use handler::ProtocolSupport; -use crate::handler::protocol::RequestProtocol; +use crate::handler::OutboundMessage; use futures::channel::oneshot; use handler::Handler; use libp2p_core::{ConnectedPoint, Endpoint, Multiaddr}; @@ -90,7 +90,7 @@ use libp2p_swarm::{ use smallvec::SmallVec; use std::{ collections::{HashMap, HashSet, VecDeque}, - fmt, + fmt, io, sync::{atomic::AtomicU64, Arc}, task::{Context, Poll}, time::Duration, @@ -102,7 +102,7 @@ pub enum Message { /// A request message. Request { /// The ID of this request. - request_id: RequestId, + request_id: InboundRequestId, /// The request message. request: TRequest, /// The channel waiting for the response. @@ -117,7 +117,7 @@ pub enum Message { /// The ID of the request that produced this response. /// /// See [`Behaviour::send_request`]. - request_id: RequestId, + request_id: OutboundRequestId, /// The response message. response: TResponse, }, @@ -138,7 +138,7 @@ pub enum Event { /// The peer to whom the request was sent. peer: PeerId, /// The (local) ID of the failed request. - request_id: RequestId, + request_id: OutboundRequestId, /// The error that occurred. error: OutboundFailure, }, @@ -147,7 +147,7 @@ pub enum Event { /// The peer from whom the request was received. peer: PeerId, /// The ID of the failed inbound request. - request_id: RequestId, + request_id: InboundRequestId, /// The error that occurred. error: InboundFailure, }, @@ -159,13 +159,13 @@ pub enum Event { /// The peer to whom the response was sent. peer: PeerId, /// The ID of the inbound request whose response was sent. - request_id: RequestId, + request_id: InboundRequestId, }, } /// Possible failures occurring in the context of sending /// an outbound request and receiving the response. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub enum OutboundFailure { /// The request could not be sent because a dialing attempt failed. DialFailure, @@ -181,6 +181,8 @@ pub enum OutboundFailure { ConnectionClosed, /// The remote supports none of the requested protocols. UnsupportedProtocols, + /// An IO failure happened on an outbound stream. + Io(io::Error), } impl fmt::Display for OutboundFailure { @@ -194,6 +196,7 @@ impl fmt::Display for OutboundFailure { OutboundFailure::UnsupportedProtocols => { write!(f, "The remote supports none of the requested protocols") } + OutboundFailure::Io(e) => write!(f, "IO error on outbound stream: {e}"), } } } @@ -202,7 +205,7 @@ impl std::error::Error for OutboundFailure {} /// Possible failures occurring in the context of receiving an /// inbound request and sending a response. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug)] pub enum InboundFailure { /// The inbound request timed out, either while reading the /// incoming request or before a response is sent, e.g. if @@ -218,6 +221,8 @@ pub enum InboundFailure { /// due to the [`ResponseChannel`] being dropped instead of /// being passed to [`Behaviour::send_response`]. ResponseOmission, + /// An IO failure happened on an inbound stream. + Io(io::Error), } impl fmt::Display for InboundFailure { @@ -237,6 +242,7 @@ impl fmt::Display for InboundFailure { f, "The response channel was dropped without sending a response to the remote" ), + InboundFailure::Io(e) => write!(f, "IO error on inbound stream: {e}"), } } } @@ -264,17 +270,27 @@ impl ResponseChannel { } } -/// The ID of an inbound or outbound request. +/// The ID of an inbound request. +/// +/// Note: [`InboundRequestId`]'s uniqueness is only guaranteed between +/// inbound requests of the same originating [`Behaviour`]. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct InboundRequestId(u64); + +impl fmt::Display for InboundRequestId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// The ID of an outbound request. /// -/// Note: [`RequestId`]'s uniqueness is only guaranteed between two -/// inbound and likewise between two outbound requests. There is no -/// uniqueness guarantee in a set of both inbound and outbound -/// [`RequestId`]s nor in a set of inbound or outbound requests -/// originating from different [`Behaviour`]'s. +/// Note: [`OutboundRequestId`]'s uniqueness is only guaranteed between +/// outbound requests of the same originating [`Behaviour`]. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub struct RequestId(u64); +pub struct OutboundRequestId(u64); -impl fmt::Display for RequestId { +impl fmt::Display for OutboundRequestId { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0) } @@ -284,22 +300,37 @@ impl fmt::Display for RequestId { #[derive(Debug, Clone)] pub struct Config { request_timeout: Duration, + max_concurrent_streams: usize, } impl Default for Config { fn default() -> Self { Self { request_timeout: Duration::from_secs(10), + max_concurrent_streams: 100, } } } impl Config { /// Sets the timeout for inbound and outbound requests. + #[deprecated(note = "Use `Config::with_request_timeout` for one-liner constructions.")] pub fn set_request_timeout(&mut self, v: Duration) -> &mut Self { self.request_timeout = v; self } + + /// Sets the timeout for inbound and outbound requests. + pub fn with_request_timeout(mut self, v: Duration) -> Self { + self.request_timeout = v; + self + } + + /// Sets the upper bound for the number of concurrent inbound + outbound streams. + pub fn with_max_concurrent_streams(mut self, num_streams: usize) -> Self { + self.max_concurrent_streams = num_streams; + self + } } /// A request/response protocol for some message codec. @@ -312,16 +343,16 @@ where /// The supported outbound protocols. outbound_protocols: SmallVec<[TCodec::Protocol; 2]>, /// The next (local) request ID. - next_request_id: RequestId, + next_outbound_request_id: OutboundRequestId, /// The next (inbound) request ID. - next_inbound_id: Arc, + next_inbound_request_id: Arc, /// The protocol configuration. config: Config, /// The protocol codec for reading and writing requests and responses. codec: TCodec, /// Pending events to return from `poll`. pending_events: - VecDeque, RequestProtocol>>, + VecDeque, OutboundMessage>>, /// The currently connected peers, their pending outbound and inbound responses and their known, /// reachable addresses, if any. connected: HashMap>, @@ -329,7 +360,7 @@ where addresses: HashMap>, /// Requests that have not yet been sent and are waiting for a connection /// to be established. - pending_outbound_requests: HashMap; 10]>>, + pending_outbound_requests: HashMap; 10]>>, } impl Behaviour @@ -368,8 +399,8 @@ where Behaviour { inbound_protocols, outbound_protocols, - next_request_id: RequestId(1), - next_inbound_id: Arc::new(AtomicU64::new(1)), + next_outbound_request_id: OutboundRequestId(1), + next_inbound_request_id: Arc::new(AtomicU64::new(1)), config: cfg, codec, pending_events: VecDeque::new(), @@ -391,13 +422,12 @@ where /// > address discovery, or known addresses of peers must be /// > managed via [`Behaviour::add_address`] and /// > [`Behaviour::remove_address`]. - pub fn send_request(&mut self, peer: &PeerId, request: TCodec::Request) -> RequestId { - let request_id = self.next_request_id(); - let request = RequestProtocol { + pub fn send_request(&mut self, peer: &PeerId, request: TCodec::Request) -> OutboundRequestId { + let request_id = self.next_outbound_request_id(); + let request = OutboundMessage { request_id, - codec: self.codec.clone(), - protocols: self.outbound_protocols.clone(), request, + protocols: self.outbound_protocols.clone(), }; if let Some(request) = self.try_send_request(peer, request) { @@ -468,14 +498,14 @@ where /// Checks whether an outbound request to the peer with the provided /// [`PeerId`] initiated by [`Behaviour::send_request`] is still /// pending, i.e. waiting for a response. - pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &RequestId) -> bool { + pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &OutboundRequestId) -> bool { // Check if request is already sent on established connection. let est_conn = self .connected .get(peer) .map(|cs| { cs.iter() - .any(|c| c.pending_inbound_responses.contains(request_id)) + .any(|c| c.pending_outbound_responses.contains(request_id)) }) .unwrap_or(false); // Check if request is still pending to be sent. @@ -491,20 +521,20 @@ where /// Checks whether an inbound request from the peer with the provided /// [`PeerId`] is still pending, i.e. waiting for a response by the local /// node through [`Behaviour::send_response`]. - pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &RequestId) -> bool { + pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &InboundRequestId) -> bool { self.connected .get(peer) .map(|cs| { cs.iter() - .any(|c| c.pending_outbound_responses.contains(request_id)) + .any(|c| c.pending_inbound_responses.contains(request_id)) }) .unwrap_or(false) } - /// Returns the next request ID. - fn next_request_id(&mut self) -> RequestId { - let request_id = self.next_request_id; - self.next_request_id.0 += 1; + /// Returns the next outbound request ID. + fn next_outbound_request_id(&mut self) -> OutboundRequestId { + let request_id = self.next_outbound_request_id; + self.next_outbound_request_id.0 += 1; request_id } @@ -514,15 +544,15 @@ where fn try_send_request( &mut self, peer: &PeerId, - request: RequestProtocol, - ) -> Option> { + request: OutboundMessage, + ) -> Option> { if let Some(connections) = self.connected.get_mut(peer) { if connections.is_empty() { return Some(request); } let ix = (request.request_id.0 as usize) % connections.len(); let conn = &mut connections[ix]; - conn.pending_inbound_responses.insert(request.request_id); + conn.pending_outbound_responses.insert(request.request_id); self.pending_events.push_back(ToSwarm::NotifyHandler { peer_id: *peer, handler: NotifyHandler::One(conn.id), @@ -537,13 +567,13 @@ where /// Remove pending outbound response for the given peer and connection. /// /// Returns `true` if the provided connection to the given peer is still - /// alive and the [`RequestId`] was previously present and is now removed. + /// alive and the [`OutboundRequestId`] was previously present and is now removed. /// Returns `false` otherwise. fn remove_pending_outbound_response( &mut self, peer: &PeerId, connection: ConnectionId, - request: RequestId, + request: OutboundRequestId, ) -> bool { self.get_connection_mut(peer, connection) .map(|c| c.pending_outbound_responses.remove(&request)) @@ -553,16 +583,16 @@ where /// Remove pending inbound response for the given peer and connection. /// /// Returns `true` if the provided connection to the given peer is still - /// alive and the [`RequestId`] was previously present and is now removed. + /// alive and the [`InboundRequestId`] was previously present and is now removed. /// Returns `false` otherwise. fn remove_pending_inbound_response( &mut self, peer: &PeerId, connection: ConnectionId, - request: &RequestId, + request: InboundRequestId, ) -> bool { self.get_connection_mut(peer, connection) - .map(|c| c.pending_inbound_responses.remove(request)) + .map(|c| c.pending_inbound_responses.remove(&request)) .unwrap_or(false) } @@ -628,7 +658,7 @@ where self.connected.remove(&peer_id); } - for request_id in connection.pending_outbound_responses { + for request_id in connection.pending_inbound_responses { self.pending_events .push_back(ToSwarm::GenerateEvent(Event::InboundFailure { peer: peer_id, @@ -637,7 +667,7 @@ where })); } - for request_id in connection.pending_inbound_responses { + for request_id in connection.pending_outbound_responses { self.pending_events .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure { peer: peer_id, @@ -681,7 +711,7 @@ where if let Some(pending_requests) = self.pending_outbound_requests.remove(&peer) { for request in pending_requests { connection - .pending_inbound_responses + .pending_outbound_responses .insert(request.request_id); handler.on_behaviour_event(request); } @@ -709,7 +739,8 @@ where self.inbound_protocols.clone(), self.codec.clone(), self.config.request_timeout, - self.next_inbound_id.clone(), + self.next_inbound_request_id.clone(), + self.config.max_concurrent_streams, ); self.preload_new_handler(&mut handler, peer, connection_id, None); @@ -751,7 +782,8 @@ where self.inbound_protocols.clone(), self.codec.clone(), self.config.request_timeout, - self.next_inbound_id.clone(), + self.next_inbound_request_id.clone(), + self.config.max_concurrent_streams, ); self.preload_new_handler( @@ -795,7 +827,7 @@ where request_id, response, } => { - let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); + let removed = self.remove_pending_outbound_response(&peer, connection, request_id); debug_assert!( removed, "Expect request_id to be pending before receiving response.", @@ -812,35 +844,26 @@ where request_id, request, sender, - } => { - let channel = ResponseChannel { sender }; - let message = Message::Request { - request_id, - request, - channel, - }; - self.pending_events - .push_back(ToSwarm::GenerateEvent(Event::Message { peer, message })); + } => match self.get_connection_mut(&peer, connection) { + Some(connection) => { + let inserted = connection.pending_inbound_responses.insert(request_id); + debug_assert!(inserted, "Expect id of new request to be unknown."); - match self.get_connection_mut(&peer, connection) { - Some(connection) => { - let inserted = connection.pending_outbound_responses.insert(request_id); - debug_assert!(inserted, "Expect id of new request to be unknown."); - } - // Connection closed after `Event::Request` has been emitted. - None => { - self.pending_events.push_back(ToSwarm::GenerateEvent( - Event::InboundFailure { - peer, - request_id, - error: InboundFailure::ConnectionClosed, - }, - )); - } + let channel = ResponseChannel { sender }; + let message = Message::Request { + request_id, + request, + channel, + }; + self.pending_events + .push_back(ToSwarm::GenerateEvent(Event::Message { peer, message })); } - } + None => { + log::debug!("Connection ({connection}) closed after `Event::Request` ({request_id}) has been emitted."); + } + }, handler::Event::ResponseSent(request_id) => { - let removed = self.remove_pending_outbound_response(&peer, connection, request_id); + let removed = self.remove_pending_inbound_response(&peer, connection, request_id); debug_assert!( removed, "Expect request_id to be pending before response is sent." @@ -853,7 +876,7 @@ where })); } handler::Event::ResponseOmission(request_id) => { - let removed = self.remove_pending_outbound_response(&peer, connection, request_id); + let removed = self.remove_pending_inbound_response(&peer, connection, request_id); debug_assert!( removed, "Expect request_id to be pending before response is omitted.", @@ -867,7 +890,7 @@ where })); } handler::Event::OutboundTimeout(request_id) => { - let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); + let removed = self.remove_pending_outbound_response(&peer, connection, request_id); debug_assert!( removed, "Expect request_id to be pending before request times out." @@ -881,7 +904,7 @@ where })); } handler::Event::OutboundUnsupportedProtocols(request_id) => { - let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); + let removed = self.remove_pending_outbound_response(&peer, connection, request_id); debug_assert!( removed, "Expect request_id to be pending before failing to connect.", @@ -894,6 +917,47 @@ where error: OutboundFailure::UnsupportedProtocols, })); } + handler::Event::OutboundStreamFailed { request_id, error } => { + let removed = self.remove_pending_outbound_response(&peer, connection, request_id); + debug_assert!(removed, "Expect request_id to be pending upon failure"); + + self.pending_events + .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure { + peer, + request_id, + error: OutboundFailure::Io(error), + })) + } + handler::Event::InboundTimeout(request_id) => { + let removed = self.remove_pending_inbound_response(&peer, connection, request_id); + + if removed { + self.pending_events + .push_back(ToSwarm::GenerateEvent(Event::InboundFailure { + peer, + request_id, + error: InboundFailure::Timeout, + })); + } else { + // This happens when timeout is emitted before `read_request` finishes. + log::debug!("Inbound request timeout for an unknown request_id ({request_id})"); + } + } + handler::Event::InboundStreamFailed { request_id, error } => { + let removed = self.remove_pending_inbound_response(&peer, connection, request_id); + + if removed { + self.pending_events + .push_back(ToSwarm::GenerateEvent(Event::InboundFailure { + peer, + request_id, + error: InboundFailure::Io(error), + })); + } else { + // This happens when `read_request` fails. + log::debug!("Inbound failure is reported for an unknown request_id ({request_id}): {error}"); + } + } } } @@ -921,10 +985,10 @@ struct Connection { /// Pending outbound responses where corresponding inbound requests have /// been received on this connection and emitted via `poll` but have not yet /// been answered. - pending_outbound_responses: HashSet, + pending_outbound_responses: HashSet, /// Pending inbound responses for previously sent requests on this /// connection. - pending_inbound_responses: HashSet, + pending_inbound_responses: HashSet, } impl Connection { diff --git a/protocols/request-response/tests/error_reporting.rs b/protocols/request-response/tests/error_reporting.rs new file mode 100644 index 00000000000..cf651d395f5 --- /dev/null +++ b/protocols/request-response/tests/error_reporting.rs @@ -0,0 +1,555 @@ +use anyhow::{bail, Result}; +use async_std::task::sleep; +use async_trait::async_trait; +use futures::prelude::*; +use libp2p_identity::PeerId; +use libp2p_request_response as request_response; +use libp2p_request_response::ProtocolSupport; +use libp2p_swarm::{StreamProtocol, Swarm}; +use libp2p_swarm_test::SwarmExt; +use request_response::{ + Codec, InboundFailure, InboundRequestId, OutboundFailure, OutboundRequestId, ResponseChannel, +}; +use std::pin::pin; +use std::time::Duration; +use std::{io, iter}; + +#[async_std::test] +async fn report_outbound_failure_on_read_response() { + let _ = env_logger::try_init(); + + let (peer1_id, mut swarm1) = new_swarm(); + let (peer2_id, mut swarm2) = new_swarm(); + + swarm1.listen().await; + swarm2.connect(&mut swarm1).await; + + let server_task = async move { + let (peer, req_id, action, resp_channel) = wait_request(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(action, Action::FailOnReadResponse); + swarm1 + .behaviour_mut() + .send_response(resp_channel, Action::FailOnReadResponse) + .unwrap(); + + let (peer, req_id_done) = wait_response_sent(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(req_id_done, req_id); + + // Keep the connection alive, otherwise swarm2 may receive `ConnectionClosed` instead + wait_no_events(&mut swarm1).await; + }; + + // Expects OutboundFailure::Io failure with `FailOnReadResponse` error + let client_task = async move { + let req_id = swarm2 + .behaviour_mut() + .send_request(&peer1_id, Action::FailOnReadResponse); + + let (peer, req_id_done, error) = wait_outbound_failure(&mut swarm2).await.unwrap(); + assert_eq!(peer, peer1_id); + assert_eq!(req_id_done, req_id); + + let error = match error { + OutboundFailure::Io(e) => e, + e => panic!("Unexpected error: {e:?}"), + }; + + assert_eq!(error.kind(), io::ErrorKind::Other); + assert_eq!( + error.into_inner().unwrap().to_string(), + "FailOnReadResponse" + ); + }; + + let server_task = pin!(server_task); + let client_task = pin!(client_task); + futures::future::select(server_task, client_task).await; +} + +#[async_std::test] +async fn report_outbound_failure_on_write_request() { + let _ = env_logger::try_init(); + + let (peer1_id, mut swarm1) = new_swarm(); + let (_peer2_id, mut swarm2) = new_swarm(); + + swarm1.listen().await; + swarm2.connect(&mut swarm1).await; + + // Expects no events because `Event::Request` is produced after `read_request`. + // Keep the connection alive, otherwise swarm2 may receive `ConnectionClosed` instead. + let server_task = wait_no_events(&mut swarm1); + + // Expects OutboundFailure::Io failure with `FailOnWriteRequest` error. + let client_task = async move { + let req_id = swarm2 + .behaviour_mut() + .send_request(&peer1_id, Action::FailOnWriteRequest); + + let (peer, req_id_done, error) = wait_outbound_failure(&mut swarm2).await.unwrap(); + assert_eq!(peer, peer1_id); + assert_eq!(req_id_done, req_id); + + let error = match error { + OutboundFailure::Io(e) => e, + e => panic!("Unexpected error: {e:?}"), + }; + + assert_eq!(error.kind(), io::ErrorKind::Other); + assert_eq!( + error.into_inner().unwrap().to_string(), + "FailOnWriteRequest" + ); + }; + + let server_task = pin!(server_task); + let client_task = pin!(client_task); + futures::future::select(server_task, client_task).await; +} + +#[async_std::test] +async fn report_outbound_timeout_on_read_response() { + let _ = env_logger::try_init(); + + // `swarm1` needs to have a bigger timeout to avoid racing + let (peer1_id, mut swarm1) = new_swarm_with_timeout(Duration::from_millis(200)); + let (peer2_id, mut swarm2) = new_swarm_with_timeout(Duration::from_millis(100)); + + swarm1.listen().await; + swarm2.connect(&mut swarm1).await; + + let server_task = async move { + let (peer, req_id, action, resp_channel) = wait_request(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(action, Action::TimeoutOnReadResponse); + swarm1 + .behaviour_mut() + .send_response(resp_channel, Action::TimeoutOnReadResponse) + .unwrap(); + + let (peer, req_id_done) = wait_response_sent(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(req_id_done, req_id); + + // Keep the connection alive, otherwise swarm2 may receive `ConnectionClosed` instead + wait_no_events(&mut swarm1).await; + }; + + // Expects OutboundFailure::Timeout + let client_task = async move { + let req_id = swarm2 + .behaviour_mut() + .send_request(&peer1_id, Action::TimeoutOnReadResponse); + + let (peer, req_id_done, error) = wait_outbound_failure(&mut swarm2).await.unwrap(); + assert_eq!(peer, peer1_id); + assert_eq!(req_id_done, req_id); + assert!(matches!(error, OutboundFailure::Timeout)); + }; + + let server_task = pin!(server_task); + let client_task = pin!(client_task); + futures::future::select(server_task, client_task).await; +} + +#[async_std::test] +async fn report_inbound_failure_on_read_request() { + let _ = env_logger::try_init(); + + let (peer1_id, mut swarm1) = new_swarm(); + let (_peer2_id, mut swarm2) = new_swarm(); + + swarm1.listen().await; + swarm2.connect(&mut swarm1).await; + + // Expects no events because `Event::Request` is produced after `read_request`. + // Keep the connection alive, otherwise swarm2 may receive `ConnectionClosed` instead. + let server_task = wait_no_events(&mut swarm1); + + // Expects io::ErrorKind::UnexpectedEof + let client_task = async move { + let req_id = swarm2 + .behaviour_mut() + .send_request(&peer1_id, Action::FailOnReadRequest); + + let (peer, req_id_done, error) = wait_outbound_failure(&mut swarm2).await.unwrap(); + assert_eq!(peer, peer1_id); + assert_eq!(req_id_done, req_id); + + match error { + OutboundFailure::Io(e) if e.kind() == io::ErrorKind::UnexpectedEof => {} + e => panic!("Unexpected error: {e:?}"), + }; + }; + + let server_task = pin!(server_task); + let client_task = pin!(client_task); + futures::future::select(server_task, client_task).await; +} + +#[async_std::test] +async fn report_inbound_failure_on_write_response() { + let _ = env_logger::try_init(); + + let (peer1_id, mut swarm1) = new_swarm(); + let (peer2_id, mut swarm2) = new_swarm(); + + swarm1.listen().await; + swarm2.connect(&mut swarm1).await; + + // Expects OutboundFailure::Io failure with `FailOnWriteResponse` error + let server_task = async move { + let (peer, req_id, action, resp_channel) = wait_request(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(action, Action::FailOnWriteResponse); + swarm1 + .behaviour_mut() + .send_response(resp_channel, Action::FailOnWriteResponse) + .unwrap(); + + let (peer, req_id_done, error) = wait_inbound_failure(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(req_id_done, req_id); + + let error = match error { + InboundFailure::Io(e) => e, + e => panic!("Unexpected error: {e:?}"), + }; + + assert_eq!(error.kind(), io::ErrorKind::Other); + assert_eq!( + error.into_inner().unwrap().to_string(), + "FailOnWriteResponse" + ); + }; + + // Expects OutboundFailure::ConnectionClosed or io::ErrorKind::UnexpectedEof + let client_task = async move { + let req_id = swarm2 + .behaviour_mut() + .send_request(&peer1_id, Action::FailOnWriteResponse); + + let (peer, req_id_done, error) = wait_outbound_failure(&mut swarm2).await.unwrap(); + assert_eq!(peer, peer1_id); + assert_eq!(req_id_done, req_id); + + match error { + OutboundFailure::ConnectionClosed => { + // ConnectionClosed is allowed here because we mainly test the behavior + // of `server_task`. + } + OutboundFailure::Io(e) if e.kind() == io::ErrorKind::UnexpectedEof => {} + e => panic!("Unexpected error: {e:?}"), + }; + + // Keep alive the task, so only `server_task` can finish + wait_no_events(&mut swarm2).await; + }; + + let server_task = pin!(server_task); + let client_task = pin!(client_task); + futures::future::select(server_task, client_task).await; +} + +#[async_std::test] +async fn report_inbound_timeout_on_write_response() { + let _ = env_logger::try_init(); + + // `swarm2` needs to have a bigger timeout to avoid racing + let (peer1_id, mut swarm1) = new_swarm_with_timeout(Duration::from_millis(100)); + let (peer2_id, mut swarm2) = new_swarm_with_timeout(Duration::from_millis(200)); + + swarm1.listen().await; + swarm2.connect(&mut swarm1).await; + + // Expects InboundFailure::Timeout + let server_task = async move { + let (peer, req_id, action, resp_channel) = wait_request(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(action, Action::TimeoutOnWriteResponse); + swarm1 + .behaviour_mut() + .send_response(resp_channel, Action::TimeoutOnWriteResponse) + .unwrap(); + + let (peer, req_id_done, error) = wait_inbound_failure(&mut swarm1).await.unwrap(); + assert_eq!(peer, peer2_id); + assert_eq!(req_id_done, req_id); + assert!(matches!(error, InboundFailure::Timeout)); + }; + + // Expects OutboundFailure::ConnectionClosed or io::ErrorKind::UnexpectedEof + let client_task = async move { + let req_id = swarm2 + .behaviour_mut() + .send_request(&peer1_id, Action::TimeoutOnWriteResponse); + + let (peer, req_id_done, error) = wait_outbound_failure(&mut swarm2).await.unwrap(); + assert_eq!(peer, peer1_id); + assert_eq!(req_id_done, req_id); + + match error { + OutboundFailure::ConnectionClosed => { + // ConnectionClosed is allowed here because we mainly test the behavior + // of `server_task`. + } + OutboundFailure::Io(e) if e.kind() == io::ErrorKind::UnexpectedEof => {} + e => panic!("Unexpected error: {e:?}"), + } + + // Keep alive the task, so only `server_task` can finish + wait_no_events(&mut swarm2).await; + }; + + let server_task = pin!(server_task); + let client_task = pin!(client_task); + futures::future::select(server_task, client_task).await; +} + +#[derive(Clone, Default)] +struct TestCodec; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Action { + FailOnReadRequest, + FailOnReadResponse, + TimeoutOnReadResponse, + FailOnWriteRequest, + FailOnWriteResponse, + TimeoutOnWriteResponse, +} + +impl From for u8 { + fn from(value: Action) -> Self { + match value { + Action::FailOnReadRequest => 0, + Action::FailOnReadResponse => 1, + Action::TimeoutOnReadResponse => 2, + Action::FailOnWriteRequest => 3, + Action::FailOnWriteResponse => 4, + Action::TimeoutOnWriteResponse => 5, + } + } +} + +impl TryFrom for Action { + type Error = io::Error; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Action::FailOnReadRequest), + 1 => Ok(Action::FailOnReadResponse), + 2 => Ok(Action::TimeoutOnReadResponse), + 3 => Ok(Action::FailOnWriteRequest), + 4 => Ok(Action::FailOnWriteResponse), + 5 => Ok(Action::TimeoutOnWriteResponse), + _ => Err(io::Error::new(io::ErrorKind::Other, "invalid action")), + } + } +} + +#[async_trait] +impl Codec for TestCodec { + type Protocol = StreamProtocol; + type Request = Action; + type Response = Action; + + async fn read_request( + &mut self, + _protocol: &Self::Protocol, + io: &mut T, + ) -> io::Result + where + T: AsyncRead + Unpin + Send, + { + let mut buf = Vec::new(); + io.read_to_end(&mut buf).await?; + + if buf.is_empty() { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + + assert_eq!(buf.len(), 1); + + match buf[0].try_into()? { + Action::FailOnReadRequest => { + Err(io::Error::new(io::ErrorKind::Other, "FailOnReadRequest")) + } + action => Ok(action), + } + } + + async fn read_response( + &mut self, + _protocol: &Self::Protocol, + io: &mut T, + ) -> io::Result + where + T: AsyncRead + Unpin + Send, + { + let mut buf = Vec::new(); + io.read_to_end(&mut buf).await?; + + if buf.is_empty() { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + + assert_eq!(buf.len(), 1); + + match buf[0].try_into()? { + Action::FailOnReadResponse => { + Err(io::Error::new(io::ErrorKind::Other, "FailOnReadResponse")) + } + Action::TimeoutOnReadResponse => loop { + sleep(Duration::MAX).await; + }, + action => Ok(action), + } + } + + async fn write_request( + &mut self, + _protocol: &Self::Protocol, + io: &mut T, + req: Self::Request, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + match req { + Action::FailOnWriteRequest => { + Err(io::Error::new(io::ErrorKind::Other, "FailOnWriteRequest")) + } + action => { + let bytes = [action.into()]; + io.write_all(&bytes).await?; + Ok(()) + } + } + } + + async fn write_response( + &mut self, + _protocol: &Self::Protocol, + io: &mut T, + res: Self::Response, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + match res { + Action::FailOnWriteResponse => { + Err(io::Error::new(io::ErrorKind::Other, "FailOnWriteResponse")) + } + Action::TimeoutOnWriteResponse => loop { + sleep(Duration::MAX).await; + }, + action => { + let bytes = [action.into()]; + io.write_all(&bytes).await?; + Ok(()) + } + } + } +} + +fn new_swarm_with_timeout( + timeout: Duration, +) -> (PeerId, Swarm>) { + let protocols = iter::once((StreamProtocol::new("/test/1"), ProtocolSupport::Full)); + let cfg = request_response::Config::default().with_request_timeout(timeout); + + let swarm = + Swarm::new_ephemeral(|_| request_response::Behaviour::::new(protocols, cfg)); + let peed_id = *swarm.local_peer_id(); + + (peed_id, swarm) +} + +fn new_swarm() -> (PeerId, Swarm>) { + new_swarm_with_timeout(Duration::from_millis(100)) +} + +async fn wait_no_events(swarm: &mut Swarm>) { + loop { + if let Ok(ev) = swarm.select_next_some().await.try_into_behaviour_event() { + panic!("Unexpected event: {ev:?}") + } + } +} + +async fn wait_request( + swarm: &mut Swarm>, +) -> Result<(PeerId, InboundRequestId, Action, ResponseChannel)> { + loop { + match swarm.select_next_some().await.try_into_behaviour_event() { + Ok(request_response::Event::Message { + peer, + message: + request_response::Message::Request { + request_id, + request, + channel, + }, + }) => { + return Ok((peer, request_id, request, channel)); + } + Ok(ev) => bail!("Unexpected event: {ev:?}"), + Err(..) => {} + } + } +} + +async fn wait_response_sent( + swarm: &mut Swarm>, +) -> Result<(PeerId, InboundRequestId)> { + loop { + match swarm.select_next_some().await.try_into_behaviour_event() { + Ok(request_response::Event::ResponseSent { + peer, request_id, .. + }) => { + return Ok((peer, request_id)); + } + Ok(ev) => bail!("Unexpected event: {ev:?}"), + Err(..) => {} + } + } +} + +async fn wait_inbound_failure( + swarm: &mut Swarm>, +) -> Result<(PeerId, InboundRequestId, InboundFailure)> { + loop { + match swarm.select_next_some().await.try_into_behaviour_event() { + Ok(request_response::Event::InboundFailure { + peer, + request_id, + error, + }) => { + return Ok((peer, request_id, error)); + } + Ok(ev) => bail!("Unexpected event: {ev:?}"), + Err(..) => {} + } + } +} + +async fn wait_outbound_failure( + swarm: &mut Swarm>, +) -> Result<(PeerId, OutboundRequestId, OutboundFailure)> { + loop { + match swarm.select_next_some().await.try_into_behaviour_event() { + Ok(request_response::Event::OutboundFailure { + peer, + request_id, + error, + }) => { + return Ok((peer, request_id, error)); + } + Ok(ev) => bail!("Unexpected event: {ev:?}"), + Err(..) => {} + } + } +} diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index e0424488f48..37f21264d49 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -28,7 +28,7 @@ use libp2p_swarm::{StreamProtocol, Swarm, SwarmEvent}; use libp2p_swarm_test::SwarmExt; use rand::{self, Rng}; use serde::{Deserialize, Serialize}; -use std::iter; +use std::{io, iter}; #[async_std::test] #[cfg(feature = "cbor")] @@ -288,7 +288,10 @@ async fn emits_inbound_connection_closed_if_channel_is_dropped() { e => panic!("unexpected event from peer 2: {e:?}"), }; - assert_eq!(error, request_response::OutboundFailure::ConnectionClosed); + assert!(matches!( + error, + request_response::OutboundFailure::Io(e) if e.kind() == io::ErrorKind::UnexpectedEof, + )); } // Simple Ping-Pong Protocol