Skip to content

Commit

Permalink
refactor(kad): don't use OutboundOpenInfo
Browse files Browse the repository at this point in the history
As part of pushing #3268 forward, remove the use of `OutboundOpenInfo` from `libp2p-kad`.

Related #3268.

Pull-Request: #3760.
  • Loading branch information
thomaseizinger authored Apr 28, 2023
1 parent 4ebb4d0 commit 99ad3b6
Showing 1 changed file with 32 additions and 44 deletions.
76 changes: 32 additions & 44 deletions protocols/kad/src/handler_priv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ pub struct KademliaHandler<TUserData> {

/// List of outbound substreams that are waiting to become active next.
/// Contains the request we want to send, and the user data if we expect an answer.
requested_streams:
VecDeque<SubstreamProtocol<KademliaProtocolConfig, (KadRequestMsg, Option<TUserData>)>>,
pending_messages: VecDeque<(KadRequestMsg, Option<TUserData>)>,

/// List of active inbound substreams with the state they are in.
inbound_substreams: SelectAll<InboundSubstreamState<TUserData>>,
Expand Down Expand Up @@ -499,27 +498,30 @@ where
inbound_substreams: Default::default(),
outbound_substreams: Default::default(),
num_requested_outbound_streams: 0,
requested_streams: Default::default(),
pending_messages: Default::default(),
keep_alive,
protocol_status: ProtocolStatus::Unconfirmed,
}
}

fn on_fully_negotiated_outbound(
&mut self,
FullyNegotiatedOutbound {
protocol,
info: (msg, user_data),
}: FullyNegotiatedOutbound<
FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound<
<Self as ConnectionHandler>::OutboundProtocol,
<Self as ConnectionHandler>::OutboundOpenInfo,
>,
) {
self.outbound_substreams
.push(OutboundSubstreamState::PendingSend(
protocol, msg, user_data,
));
if let Some((msg, user_data)) = self.pending_messages.pop_front() {
self.outbound_substreams
.push(OutboundSubstreamState::PendingSend(
protocol, msg, user_data,
));
} else {
debug_assert!(false, "Requested outbound stream without message")
}

self.num_requested_outbound_streams -= 1;

if let ProtocolStatus::Unconfirmed = self.protocol_status {
// Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want
Expand Down Expand Up @@ -587,20 +589,20 @@ where
fn on_dial_upgrade_error(
&mut self,
DialUpgradeError {
info: (_, user_data),
error,
..
info: (), error, ..
}: DialUpgradeError<
<Self as ConnectionHandler>::OutboundOpenInfo,
<Self as ConnectionHandler>::OutboundProtocol,
>,
) {
// TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't
// continue trying
if let Some(user_data) = user_data {

if let Some((_, Some(user_data))) = self.pending_messages.pop_front() {
self.outbound_substreams
.push(OutboundSubstreamState::ReportError(error.into(), user_data));
}

self.num_requested_outbound_streams -= 1;
}
}
Expand All @@ -614,8 +616,7 @@ where
type Error = io::Error; // TODO: better error type?
type InboundProtocol = Either<KademliaProtocolConfig, upgrade::DeniedUpgrade>;
type OutboundProtocol = KademliaProtocolConfig;
// Message of the request to send to the remote, and user data if we expect an answer.
type OutboundOpenInfo = (KadRequestMsg, Option<TUserData>);
type OutboundOpenInfo = ();
type InboundOpenInfo = ();

fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
Expand Down Expand Up @@ -645,21 +646,15 @@ where
}
KademliaHandlerIn::FindNodeReq { key, user_data } => {
let msg = KadRequestMsg::FindNode { key };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::FindNodeRes {
closer_peers,
request_id,
} => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
KademliaHandlerIn::GetProvidersReq { key, user_data } => {
let msg = KadRequestMsg::GetProviders { key };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::GetProvidersRes {
closer_peers,
Expand All @@ -674,24 +669,15 @@ where
),
KademliaHandlerIn::AddProvider { key, provider } => {
let msg = KadRequestMsg::AddProvider { key, provider };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, None),
));
self.pending_messages.push_back((msg, None));
}
KademliaHandlerIn::GetRecord { key, user_data } => {
let msg = KadRequestMsg::GetValue { key };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::PutRecord { record, user_data } => {
let msg = KadRequestMsg::PutValue { record };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::GetRecordRes {
record,
Expand Down Expand Up @@ -750,11 +736,13 @@ where

let num_in_progress_outbound_substreams =
self.outbound_substreams.len() + self.num_requested_outbound_streams;
if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS {
if let Some(protocol) = self.requested_streams.pop_front() {
self.num_requested_outbound_streams += 1;
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol });
}
if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS
&& self.num_requested_outbound_streams < self.pending_messages.len()
{
self.num_requested_outbound_streams += 1;
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(self.config.protocol_config.clone(), ()),
});
}

let no_streams = self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty();
Expand Down Expand Up @@ -828,7 +816,7 @@ where
{
type Item = ConnectionHandlerEvent<
KademliaProtocolConfig,
(KadRequestMsg, Option<TUserData>),
(),
KademliaHandlerEvent<TUserData>,
io::Error,
>;
Expand Down Expand Up @@ -964,7 +952,7 @@ where
{
type Item = ConnectionHandlerEvent<
KademliaProtocolConfig,
(KadRequestMsg, Option<TUserData>),
(),
KademliaHandlerEvent<TUserData>,
io::Error,
>;
Expand Down

0 comments on commit 99ad3b6

Please sign in to comment.