Skip to content

Commit

Permalink
Merge pull request #6 from ackintosh/experimental-fix-replay
Browse files Browse the repository at this point in the history
Fix replay_active_requests
  • Loading branch information
njgheorghita authored Oct 16, 2023
2 parents 5e05a7d + 57226fe commit afe7466
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 109 deletions.
64 changes: 36 additions & 28 deletions src/handler/active_requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,47 @@ impl ActiveRequests {
let nonce = *request_call.packet().message_nonce();
self.active_requests_mapping
.entry(node_address.clone())
.or_insert_with(Vec::new)
.or_default()
.push(request_call);
self.active_requests_nonce_mapping
.insert(nonce, node_address);
}

/// Update the underlying packet for the request via message nonce.
pub fn update_packet(&mut self, old_nonce: MessageNonce, new_packet: Packet) {
let node_address =
if let Some(node_address) = self.active_requests_nonce_mapping.remove(&old_nonce) {
node_address
} else {
debug_unreachable!("expected to find nonce in active_requests_nonce_mapping");
error!("expected to find nonce in active_requests_nonce_mapping");
return;
};

self.active_requests_nonce_mapping
.insert(new_packet.header.message_nonce, node_address.clone());

match self.active_requests_mapping.entry(node_address) {
Entry::Occupied(mut requests) => {
let maybe_request_call = requests
.get_mut()
.iter_mut()
.find(|req| req.packet().message_nonce() == &old_nonce);

if let Some(request_call) = maybe_request_call {
request_call.update_packet(new_packet);
} else {
debug_unreachable!("expected to find request call in active_requests_mapping");
error!("expected to find request call in active_requests_mapping");
}
}
Entry::Vacant(_) => {
debug_unreachable!("expected to find node address in active_requests_mapping");
error!("expected to find node address in active_requests_mapping");
}
}
}

pub fn get(&self, node_address: &NodeAddress) -> Option<&Vec<RequestCall>> {
self.active_requests_mapping.get(node_address)
}
Expand Down Expand Up @@ -80,33 +115,6 @@ impl ActiveRequests {
Some(requests)
}

/// Remove requests associated with a node, except for the request that has the given message nonce.
pub fn remove_requests_except(
&mut self,
node_address: &NodeAddress,
except: &MessageNonce,
) -> Option<Vec<RequestCall>> {
let request_ids = self
.active_requests_mapping
.get(node_address)?
.iter()
.filter(|req| req.packet().message_nonce() != except)
.map(|req| req.id().into())
.collect::<Vec<_>>();

let mut requests = vec![];
for id in request_ids.iter() {
match self.remove_request(node_address, id) {
Some(request_call) => requests.push(request_call),
None => {
debug_unreachable!("expected to find request with id");
error!("expected to find request with id");
}
}
}
Some(requests)
}

/// Remove a single request identified by its id.
pub fn remove_request(
&mut self,
Expand Down
73 changes: 36 additions & 37 deletions src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use crate::{
use delay_map::HashMapDelay;
use enr::{CombinedKey, NodeId};
use futures::prelude::*;
use more_asserts::debug_unreachable;
use parking_lot::RwLock;
use smallvec::SmallVec;
use std::{
Expand Down Expand Up @@ -481,7 +482,7 @@ impl Handler {
trace!("Request queued for node: {}", node_address);
self.pending_requests
.entry(node_address)
.or_insert_with(Vec::new)
.or_default()
.push(PendingRequest {
contact,
request_id,
Expand Down Expand Up @@ -965,44 +966,42 @@ impl Handler {
node_address,
message_nonce
);
let active_requests = if let Some(nonce) = message_nonce {
// Except the active request that was used to establish the new session, as it has
// already been handled and shouldn't be replayed.
self.active_requests
.remove_requests_except(node_address, &nonce)
} else {
self.active_requests.remove_requests(node_address)
}
.unwrap_or_default();
for req in active_requests {
let (req_id, contact, body) = req.into_request_parts();
trace!(
"Active request to be replayed. {}, {contact}, {body}",
RequestId::from(&req_id),
);
// Remove the request from the packet filter here since the request is added in
// `self.send_request()` again.
self.remove_expected_response(contact.socket_addr());
if let Err(request_error) = self.send_request::<P>(contact, req_id.clone(), body).await
{
warn!("Failed to send next awaiting request {request_error}");
// Inform the service that the request failed
match req_id {
HandlerReqId::Internal(_) => {
// An internal request could not be sent. For now we do nothing about
// this.
}
HandlerReqId::External(id) => {
if let Err(e) = self
.service_send
.send(HandlerOut::RequestFailed(id, request_error))
.await
{
warn!("Failed to inform that request failed {e}");
}

let packets = if let Some(session) = self.sessions.get_mut(node_address) {
let mut packets = vec![];
for request_call in self
.active_requests
.get(node_address)
.unwrap_or(&vec![])
.iter()
.filter(|req| {
// Except the active request that was used to establish the new session, as it has
// already been handled and shouldn't be replayed.
if let Some(nonce) = message_nonce.as_ref() {
req.packet().message_nonce() != nonce
} else {
true
}
}
})
{
let new_packet = session
.encrypt_message::<P>(self.node_id, &request_call.encode())
.unwrap();

packets.push((*request_call.packet().message_nonce(), new_packet));
}

packets
} else {
debug_unreachable!("Attempted to replay active requests but session doesn't exist.");
error!("Attempted to replay active requests but session doesn't exist.");
return;
};

for (old_nonce, new_packet) in packets {
self.active_requests
.update_packet(old_nonce, new_packet.clone());
self.send(node_address.clone(), new_packet).await;
}
}

Expand Down
5 changes: 0 additions & 5 deletions src/handler/request_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,4 @@ impl RequestCall {
pub fn remaining_responses_mut(&mut self) -> &mut Option<u64> {
&mut self.remaining_responses
}

/// Returns the id, contact, and request body for this call.
pub fn into_request_parts(self) -> (HandlerReqId, NodeContact, RequestBody) {
(self.request_id, self.contact, self.request)
}
}
57 changes: 29 additions & 28 deletions src/handler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,34 +400,6 @@ async fn test_active_requests_remove_requests() {
assert!(active_requests.remove_requests(&req_3_addr).is_none());
}

#[tokio::test]
async fn test_active_requests_remove_requests_except() {
const EXPIRY: Duration = Duration::from_secs(5);
let mut active_requests = ActiveRequests::new(EXPIRY);

let node_1 = create_node();
let node_2 = create_node();
let (req_1, req_1_addr) = create_req_call(&node_1);
let (req_2, req_2_addr) = create_req_call(&node_2);
let (req_3, req_3_addr) = create_req_call(&node_2);

let req_2_nonce = req_2.packet().header.message_nonce;
let req_3_id: RequestId = req_3.id().into();

active_requests.insert(req_1_addr, req_1);
active_requests.insert(req_2_addr.clone(), req_2);
active_requests.insert(req_3_addr, req_3);

let removed_requests = active_requests
.remove_requests_except(&req_2_addr, &req_2_nonce)
.unwrap();
active_requests.check_invariant();

assert_eq!(1, removed_requests.len());
let removed_request_id: RequestId = removed_requests.first().unwrap().id().into();
assert_eq!(removed_request_id, req_3_id);
}

#[tokio::test]
async fn test_active_requests_remove_request() {
const EXPIRY: Duration = Duration::from_secs(5);
Expand Down Expand Up @@ -504,6 +476,35 @@ async fn test_active_requests_remove_by_nonce() {
assert!(active_requests.remove_by_nonce(&random_nonce).is_none());
}

#[tokio::test]
async fn test_active_requests_update_packet() {
const EXPIRY: Duration = Duration::from_secs(5);
let mut active_requests = ActiveRequests::new(EXPIRY);

let node_1 = create_node();
let node_2 = create_node();
let (req_1, req_1_addr) = create_req_call(&node_1);
let (req_2, req_2_addr) = create_req_call(&node_2);
let (req_3, req_3_addr) = create_req_call(&node_2);

let old_nonce = *req_2.packet().message_nonce();
active_requests.insert(req_1_addr, req_1);
active_requests.insert(req_2_addr.clone(), req_2);
active_requests.insert(req_3_addr, req_3);
active_requests.check_invariant();

let new_packet = Packet::new_random(&node_2.node_id()).unwrap();
let new_nonce = new_packet.message_nonce();
active_requests.update_packet(old_nonce, new_packet.clone());
active_requests.check_invariant();

assert_eq!(2, active_requests.get(&req_2_addr).unwrap().len());
assert!(active_requests.remove_by_nonce(&old_nonce).is_none());
let (addr, req) = active_requests.remove_by_nonce(new_nonce).unwrap();
assert_eq!(addr, req_2_addr);
assert_eq!(req.packet(), &new_packet);
}

#[tokio::test]
async fn test_self_request_ipv4() {
init();
Expand Down
22 changes: 11 additions & 11 deletions src/ipmode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ impl IpMode {
}
}

/// Copied from the standard library. See <https://github.com/rust-lang/rust/issues/27709>
/// The current code is behind the `ip` feature.
pub const fn to_ipv4_mapped(ip: &std::net::Ipv6Addr) -> Option<std::net::Ipv4Addr> {
match ip.octets() {
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
Some(std::net::Ipv4Addr::new(a, b, c, d))
}
_ => None,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -230,14 +241,3 @@ mod tests {
.test();
}
}

/// Copied from the standard library. See <https://github.com/rust-lang/rust/issues/27709>
/// The current code is behind the `ip` feature.
pub const fn to_ipv4_mapped(ip: &std::net::Ipv6Addr) -> Option<std::net::Ipv4Addr> {
match ip.octets() {
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, a, b, c, d] => {
Some(std::net::Ipv4Addr::new(a, b, c, d))
}
_ => None,
}
}

0 comments on commit afe7466

Please sign in to comment.