Skip to content

Commit

Permalink
Use InflightProtocolDataQueue in libp2p-kad
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaseizinger committed Nov 14, 2023
1 parent ca6f8c4 commit 3f72224
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 115 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions misc/protocol-utils/src/ipd_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ impl<D, Req, Res> InflightProtocolDataQueue<D, Req, Res> {
pub fn enqueue_request(&mut self, request: Req, data: D) {
self.pending_requests.push_back(request);
self.data_of_inflight_requests.push_back(data);

debug_assert_eq!(
self.pending_requests.len(),
self.data_of_inflight_requests.len()
);
}

/// Submits a response to the queue.
Expand All @@ -46,6 +41,11 @@ impl<D, Req, Res> InflightProtocolDataQueue<D, Req, Res> {
self.received_responses.push_back(res);
}

/// How many protocols are currently in-flight.
pub fn num_inflight(&self) -> usize {
self.data_of_inflight_requests.len() - self.received_responses.len()
}

pub fn next_completed(&mut self) -> Option<(Res, D)> {
let res = self.received_responses.pop_front()?;
let data = self.data_of_inflight_requests.pop_front()?;
Expand Down
1 change: 1 addition & 0 deletions protocols/kad/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ asynchronous-codec = "0.6"
futures = "0.3.29"
libp2p-core = { workspace = true }
libp2p-swarm = { workspace = true }
libp2p-protocol-utils = { workspace = true }
quick-protobuf = "0.8"
quick-protobuf-codec = { workspace = true }
libp2p-identity = { workspace = true, features = ["rand"] }
Expand Down
207 changes: 101 additions & 106 deletions protocols/kad/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,15 @@ use futures::prelude::*;
use futures::stream::SelectAll;
use libp2p_core::{upgrade, ConnectedPoint};
use libp2p_identity::PeerId;
use libp2p_swarm::handler::{
ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
};
use libp2p_protocol_utils::InflightProtocolDataQueue;
use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound};
use libp2p_swarm::{
ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol,
SupportedProtocols,
};
use std::collections::VecDeque;
use std::task::Waker;
use std::{error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll};
use void::Void;

const MAX_NUM_SUBSTREAMS: usize = 32;

Expand All @@ -62,12 +61,13 @@ pub struct Handler {
/// List of active outbound substreams with the state they are in.
outbound_substreams: SelectAll<OutboundSubstreamState>,

/// Number of outbound streams being upgraded right now.
num_requested_outbound_streams: usize,

/// List of outbound substreams that are waiting to become active next.
/// Contains the request we want to send, and the user data if we expect an answer.
pending_messages: VecDeque<(KadRequestMsg, Option<QueryId>)>,
pending_streams: InflightProtocolDataQueue<
(KadRequestMsg, Option<QueryId>),
ProtocolConfig,
Result<KadOutStreamSink<Stream>, StreamUpgradeError<Void>>,
>,

/// List of active inbound substreams with the state they are in.
inbound_substreams: SelectAll<InboundSubstreamState>,
Expand Down Expand Up @@ -293,7 +293,7 @@ pub enum HandlerEvent {
#[derive(Debug)]
pub enum HandlerQueryErr {
/// Error while trying to perform the query.
Upgrade(StreamUpgradeError<io::Error>),
Upgrade(StreamUpgradeError<Void>),
/// Received an answer that doesn't correspond to the request.
UnexpectedMessage,
/// I/O error in the substream.
Expand Down Expand Up @@ -329,8 +329,8 @@ impl error::Error for HandlerQueryErr {
}
}

impl From<StreamUpgradeError<io::Error>> for HandlerQueryErr {
fn from(err: StreamUpgradeError<io::Error>) -> Self {
impl From<StreamUpgradeError<Void>> for HandlerQueryErr {
fn from(err: StreamUpgradeError<Void>) -> Self {
HandlerQueryErr::Upgrade(err)
}
}
Expand Down Expand Up @@ -481,40 +481,12 @@ impl Handler {
next_connec_unique_id: UniqueConnecId(0),
inbound_substreams: Default::default(),
outbound_substreams: Default::default(),
num_requested_outbound_streams: 0,
pending_messages: Default::default(),
pending_streams: Default::default(),
protocol_status: None,
remote_supported_protocols: Default::default(),
}
}

fn on_fully_negotiated_outbound(
&mut self,
FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound<
<Self as ConnectionHandler>::OutboundProtocol,
<Self as ConnectionHandler>::OutboundOpenInfo,
>,
) {
if let Some((msg, query_id)) = self.pending_messages.pop_front() {
self.outbound_substreams
.push(OutboundSubstreamState::PendingSend(protocol, msg, query_id));
} else {
debug_assert!(false, "Requested outbound stream without message")
}

self.num_requested_outbound_streams -= 1;

if self.protocol_status.is_none() {
// Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want
// the behaviour to add this peer to the routing table, if possible.
self.protocol_status = Some(ProtocolStatus {
supported: true,
reported: false,
});
}
}

fn on_fully_negotiated_inbound(
&mut self,
FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
Expand Down Expand Up @@ -572,26 +544,6 @@ impl Handler {
substream: protocol,
});
}

fn on_dial_upgrade_error(
&mut self,
DialUpgradeError {
info: (), error, ..
}: DialUpgradeError<
<Self as ConnectionHandler>::OutboundOpenInfo,
<Self as ConnectionHandler>::OutboundProtocol,
>,
) {
// TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't
// continue trying

if let Some((_, Some(query_id))) = self.pending_messages.pop_front() {
self.outbound_substreams
.push(OutboundSubstreamState::ReportError(error.into(), query_id));
}

self.num_requested_outbound_streams -= 1;
}
}

impl ConnectionHandler for Handler {
Expand Down Expand Up @@ -626,16 +578,20 @@ impl ConnectionHandler for Handler {
}
}
HandlerIn::FindNodeReq { key, query_id } => {
let msg = KadRequestMsg::FindNode { key };
self.pending_messages.push_back((msg, Some(query_id)));
self.pending_streams.enqueue_request(
self.protocol_config.clone(),
(KadRequestMsg::FindNode { key }, Some(query_id)),
);
}
HandlerIn::FindNodeRes {
closer_peers,
request_id,
} => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
HandlerIn::GetProvidersReq { key, query_id } => {
let msg = KadRequestMsg::GetProviders { key };
self.pending_messages.push_back((msg, Some(query_id)));
self.pending_streams.enqueue_request(
self.protocol_config.clone(),
(KadRequestMsg::GetProviders { key }, Some(query_id)),
);
}
HandlerIn::GetProvidersRes {
closer_peers,
Expand All @@ -649,16 +605,22 @@ impl ConnectionHandler for Handler {
},
),
HandlerIn::AddProvider { key, provider } => {
let msg = KadRequestMsg::AddProvider { key, provider };
self.pending_messages.push_back((msg, None));
self.pending_streams.enqueue_request(
self.protocol_config.clone(),
(KadRequestMsg::AddProvider { key, provider }, None),
);
}
HandlerIn::GetRecord { key, query_id } => {
let msg = KadRequestMsg::GetValue { key };
self.pending_messages.push_back((msg, Some(query_id)));
self.pending_streams.enqueue_request(
self.protocol_config.clone(),
(KadRequestMsg::GetValue { key }, Some(query_id)),
);
}
HandlerIn::PutRecord { record, query_id } => {
let msg = KadRequestMsg::PutValue { record };
self.pending_messages.push_back((msg, Some(query_id)));
self.pending_streams.enqueue_request(
self.protocol_config.clone(),
(KadRequestMsg::PutValue { record }, Some(query_id)),
);
}
HandlerIn::GetRecordRes {
record,
Expand Down Expand Up @@ -712,44 +674,67 @@ impl ConnectionHandler for Handler {
) -> Poll<
ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
> {
match &mut self.protocol_status {
Some(status) if !status.reported => {
status.reported = true;
let event = if status.supported {
HandlerEvent::ProtocolConfirmed {
endpoint: self.endpoint.clone(),
}
} else {
HandlerEvent::ProtocolNotSupported {
endpoint: self.endpoint.clone(),
}
};
loop {
match &mut self.protocol_status {
Some(status) if !status.reported => {
status.reported = true;
let event = if status.supported {
HandlerEvent::ProtocolConfirmed {
endpoint: self.endpoint.clone(),
}
} else {
HandlerEvent::ProtocolNotSupported {
endpoint: self.endpoint.clone(),
}
};

return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
}
_ => {}
}
_ => {}
}

if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) {
return Poll::Ready(event);
}
match self.pending_streams.next_completed() {
Some((Ok(stream), (message, query_id))) => {
self.outbound_substreams
.push(OutboundSubstreamState::PendingSend(
stream, message, query_id,
));
continue;
}
// TODO: Check if the remote doesn't support kademlia and stop trying if it doesn't
Some((Err(error), (_, Some(query_id)))) => {
self.outbound_substreams
.push(OutboundSubstreamState::ReportError(error.into(), query_id));
continue;
}
Some((Err(error), (message, None))) => {
tracing::debug!(?message, "Failed to establish stream: {error}");
continue;
}
None => {}
}

if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
return Poll::Ready(event);
}
if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) {
return Poll::Ready(event);
}

let num_in_progress_outbound_substreams =
self.outbound_substreams.len() + self.num_requested_outbound_streams;
if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS
&& self.num_requested_outbound_streams < self.pending_messages.len()
{
self.num_requested_outbound_streams += 1;
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
});
}
if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
return Poll::Ready(event);
}

let num_in_progress_outbound_substreams =
self.outbound_substreams.len() + self.pending_streams.num_inflight();

Poll::Pending
if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS {
if let Some(next) = self.pending_streams.next_request() {
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(next, ()),
});
}
}

return Poll::Pending;
}
}

fn on_connection_event(
Expand All @@ -762,14 +747,24 @@ impl ConnectionHandler for Handler {
>,
) {
match event {
ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
self.on_fully_negotiated_outbound(fully_negotiated_outbound)
ConnectionEvent::FullyNegotiatedOutbound(ev) => {
self.pending_streams.submit_response(Ok(ev.protocol));

if self.protocol_status.is_none() {
// Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want
// the behaviour to add this peer to the routing table, if possible.
self.protocol_status = Some(ProtocolStatus {
supported: true,
reported: false,
});
}
}
ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
self.on_fully_negotiated_inbound(fully_negotiated_inbound)
}
ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
self.on_dial_upgrade_error(dial_upgrade_error)
ConnectionEvent::DialUpgradeError(ev) => {
self.pending_streams.submit_response(Err(ev.error));
}
ConnectionEvent::RemoteProtocolsChange(change) => {
let dirty = self.remote_supported_protocols.on_protocols_change(change);
Expand Down
9 changes: 5 additions & 4 deletions protocols/kad/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use std::marker::PhantomData;
use std::{convert::TryFrom, time::Duration};
use std::{io, iter};
use tracing::debug;
use void::Void;

/// The protocol name used for negotiating with multistream-select.
pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0");
Expand Down Expand Up @@ -220,8 +221,8 @@ where
C: AsyncRead + AsyncWrite + Unpin,
{
type Output = KadInStreamSink<C>;
type Future = future::Ready<Result<Self::Output, io::Error>>;
type Error = io::Error;
type Future = future::Ready<Result<Self::Output, Self::Error>>;
type Error = Void;

fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
let codec = Codec::new(self.max_packet_size);
Expand All @@ -235,8 +236,8 @@ where
C: AsyncRead + AsyncWrite + Unpin,
{
type Output = KadOutStreamSink<C>;
type Future = future::Ready<Result<Self::Output, io::Error>>;
type Error = io::Error;
type Future = future::Ready<Result<Self::Output, Self::Error>>;
type Error = Void;

fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future {
let codec = Codec::new(self.max_packet_size);
Expand Down

0 comments on commit 3f72224

Please sign in to comment.