Skip to content

Commit

Permalink
check all pending pings on pong/data
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 committed Aug 9, 2024
1 parent 69e30c6 commit 9d409b6
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 46 deletions.
55 changes: 10 additions & 45 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::future::{IntervalStream, SessionClose};
use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
use crate::server::{handle_rpc_call, ConnectionState, ServerConfig};
use crate::utils::PendingPings;
use crate::{HttpBody, HttpRequest, HttpResponse, PingConfig, LOG_TARGET};

use futures_util::future::{self, Either, Fuse};
Expand Down Expand Up @@ -354,72 +354,37 @@ where
}

#[derive(Debug, Copy, Clone)]
enum KeepAlive {
pub(crate) enum KeepAlive {
Ping(Instant),
Data(Instant),
Pong(Instant),
}

async fn ping_pong_task(
mut rx: mpsc::Receiver<KeepAlive>,
max_inactive_limit: Duration,
max_inactivity_dur: Duration,
max_missed_pings: usize,
conn_id: u32,
) {
let mut polling_interval = IntervalStream::new(interval(max_inactive_limit));
let mut pending_pings: VecDeque<Instant> = VecDeque::new();
let mut missed_pings = 0;
let mut polling_interval = IntervalStream::new(interval(max_inactivity_dur));
let mut pending_pings = PendingPings::new(max_missed_pings, max_inactivity_dur, conn_id);

loop {
tokio::select! {
// If the ping is never answered, we use this timer as a fallback.
_ = polling_interval.next() => {
let mut remove = false;

if let Some(ping_start) = pending_pings.front() {
let elapsed = ping_start.elapsed();

if elapsed >= max_inactive_limit {
missed_pings += 1;
remove = true;
tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactive_limit.as_millis());
}

if missed_pings >= max_missed_pings {
tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection");
break;
}
}

if remove {
pending_pings.pop_front();
if !pending_pings.check_pending(Instant::now()) {
break;
}
}
msg = rx.recv() => {
match msg {
Some(KeepAlive::Ping(start)) => {
pending_pings.push_back(start);
pending_pings.push(start);
}
Some(KeepAlive::Pong(end)) | Some(KeepAlive::Data(end)) => {
// Both pong and data are considered as a response to the ping.
// So we might get more responses than pings that's why it's possible
// that the pending_pings may be empty.
if let Some(start) = pending_pings.pop_front() {
// Calculate the round-trip time (RTT) of the ping/pong.
// We adjust for the time it took to send to this task.
let elapsed = start.elapsed().saturating_sub(end.elapsed());

if elapsed >= max_inactive_limit {
missed_pings += 1;
tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactive_limit.as_millis());
}

tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, conn_id={conn_id}", elapsed.as_millis());

if missed_pings >= max_missed_pings {
tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection");
break;
}
if !pending_pings.check_pending(end) {
break;
}
}
None => break,
Expand Down
91 changes: 90 additions & 1 deletion server/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use crate::{HttpBody, HttpRequest};
use crate::{HttpBody, HttpRequest, LOG_TARGET};

use futures_util::future::{self, Either};
use hyper_util::rt::{TokioExecutor, TokioIo};
Expand Down Expand Up @@ -83,6 +85,50 @@ where
}
}

#[derive(Debug, Clone)]
pub(crate) struct PendingPings {
list: VecDeque<Instant>,
max_missed_pings: usize,
missed_pings: usize,
max_inactivity_dur: Duration,
conn_id: u32,
}

impl PendingPings {
pub(crate) fn new(max_missed_pings: usize, max_inactivity_dur: Duration, conn_id: u32) -> Self {
Self { list: VecDeque::new(), max_missed_pings, max_inactivity_dur, missed_pings: 0, conn_id }
}

pub(crate) fn push(&mut self, instant: Instant) {
self.list.push_back(instant);
}

/// Returns `true` if the pong was answered in time, `false` otherwise.
pub(crate) fn check_pending(&mut self, end: Instant) -> bool {
for ping_start in self.list.drain(..) {
// Calculate the round-trip time (RTT) of the ping/pong.
// We adjust for the time when the pong was received.
let elapsed = ping_start.elapsed().saturating_sub(end.elapsed());

tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, conn_id={}", elapsed.as_millis(), self.conn_id);

if elapsed >= self.max_inactivity_dur {
self.missed_pings += 1;
tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={}, elapsed={}ms/max={}ms", self.conn_id, elapsed.as_millis(), self.max_inactivity_dur.as_millis());
} else {
self.missed_pings = 0;
}

if self.missed_pings >= self.max_missed_pings {
tracing::debug!(target: LOG_TARGET, "Missed {} ping/pongs for conn_id={}; closing connection", self.missed_pings, self.conn_id);
return false;
}
}

true
}
}

/// Serve a service over a TCP connection without graceful shutdown.
/// This means that pending requests will be dropped when the server is stopped.
///
Expand Down Expand Up @@ -163,3 +209,46 @@ pub(crate) mod deserialize {
Ok(req)
}
}

#[cfg(test)]
mod tests {
use super::PendingPings;
use std::time::Duration;

#[test]
fn pending_ping_works() {
let mut pending_pings = PendingPings::new(1, std::time::Duration::from_secs(1), 0);

pending_pings.push(std::time::Instant::now());
assert!(pending_pings.check_pending(std::time::Instant::now()));
assert!(pending_pings.list.is_empty());
}

#[test]
fn inactive_too_long() {
let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0);

pending_pings.push(std::time::Instant::now());
pending_pings.push(std::time::Instant::now());

std::thread::sleep(Duration::from_millis(200));

assert!(!pending_pings.check_pending(std::time::Instant::now()));
}

#[test]
fn active_reset_counter() {
let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0);

pending_pings.push(std::time::Instant::now());

std::thread::sleep(Duration::from_millis(200));

assert!(pending_pings.check_pending(std::time::Instant::now()));
assert_eq!(pending_pings.missed_pings, 1);

pending_pings.push(std::time::Instant::now());
assert!(pending_pings.check_pending(std::time::Instant::now()));
assert_eq!(pending_pings.missed_pings, 0);
}
}

0 comments on commit 9d409b6

Please sign in to comment.