Skip to content

server: cleanup pending_session_id_rotations for disconnected clients #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lightway-core/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ impl<AppState: Send> Connection<AppState> {
/// [`Connection::tick_interval`] for usage.
pub fn tick(&mut self) -> ConnectionResult<()> {
self.is_tick_timer_running = false;
trace!(?self.session_id, "Processing connection tick");
trace!(session_id = ?self.session_id, "Processing connection tick");

match self.state {
State::Authenticating => {
Expand All @@ -677,6 +677,7 @@ impl<AppState: Send> Connection<AppState> {
}
_ if self.connection_type.is_datagram() => match self.session.dtls_has_timed_out() {
wolfssl::Poll::Ready(true) => {
warn!(session_id = ?self.session_id, "DTLS timed out, disconnecting client");
let _ = self.disconnect();
return Err(ConnectionError::TimedOut);
}
Expand Down
3 changes: 1 addition & 2 deletions lightway-server/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ impl Connection {
pub fn begin_session_id_rotation(self: &Arc<Self>) {
let mut conn = self.lw_conn.lock().unwrap();

// A rotation is already in flight, nothing to be done this
// time.
// A rotation is already in flight, nothing to be done this time.
if conn.pending_session_id().is_some() {
return;
}
Expand Down
85 changes: 52 additions & 33 deletions lightway-server/src/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pub(crate) const CONNECTION_AGE_EXPIRATION_INTERVAL: Duration = Duration::minute
/// How often to check for connections to expire connections where authentication has expired
const CONNECTION_AUTH_EXPIRATION_INTERVAL: Duration = Duration::hours(6);

/// How often to check for pending session ids to cleanup
const PENDING_SESSION_ID_EXPIRATION_INTERVAL: Duration = Duration::hours(6);

/// How long a connection can be idle for
const CONNECTION_MAX_IDLE_AGE: Duration = Duration::days(1);

Expand Down Expand Up @@ -79,38 +82,10 @@ pub(crate) enum ConnectionManagerError {
LwContextError(#[from] ContextError),
}

async fn evict_idle_connections(manager: Weak<ConnectionManager>) {
let mut ticker = tokio::time::interval(CONNECTION_AGE_EXPIRATION_INTERVAL.unsigned_abs());
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut ticker = tokio_stream::wrappers::IntervalStream::new(ticker);

while ticker.next().await.is_some() {
let Some(manager) = manager.upgrade() else {
info!("Connection Manager has gone away, stopping");
return;
};
manager.evict_idle_connections();
}
}

async fn evict_expired_connections(manager: Weak<ConnectionManager>) {
let mut ticker = tokio::time::interval(CONNECTION_AUTH_EXPIRATION_INTERVAL.unsigned_abs());
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut ticker = tokio_stream::wrappers::IntervalStream::new(ticker);

while ticker.next().await.is_some() {
let Some(manager) = manager.upgrade() else {
info!("Connection Manager has gone away, stopping");
return;
};
manager.evict_expired_connections();
}
}

pub(crate) struct ConnectionManager {
ctx: ServerContext<ConnectionState>,
connections: Mutex<ConnectionMap<Connection>>,
pending_session_id_rotations: Mutex<HashMap<SessionId, Arc<Connection>>>,
pending_session_id_rotations: Mutex<HashMap<SessionId, Weak<Connection>>>,
encoders: Arc<Mutex<InternalIPToEncoderMap>>,
/// Total number of sessions there have ever been
total_sessions: AtomicUsize,
Expand All @@ -127,7 +102,7 @@ async fn handle_state_change(
return;
};

info!(session = ?conn.session_id(), ?state, "State changed for {:?}", conn.peer_addr(),);
info!(session = ?conn.session_id(), ?state, "State changed for {:?}", conn.peer_addr());

match state {
State::Connecting => {}
Expand Down Expand Up @@ -270,12 +245,43 @@ impl ConnectionManager {
encoders,
});

tokio::spawn(evict_idle_connections(Arc::downgrade(&conn_manager)));
tokio::spawn(evict_expired_connections(Arc::downgrade(&conn_manager)));
conn_manager.spawn_periodic_task(
CONNECTION_AGE_EXPIRATION_INTERVAL,
Self::evict_idle_connections,
);
conn_manager.spawn_periodic_task(
CONNECTION_AUTH_EXPIRATION_INTERVAL,
Self::evict_expired_connections,
);
conn_manager.spawn_periodic_task(
PENDING_SESSION_ID_EXPIRATION_INTERVAL,
Self::cleanup_pending_session_ids,
);

conn_manager
}

pub(crate) fn spawn_periodic_task<T>(self: &Arc<Self>, interval: Duration, task: T)
where
T: Fn(&Self) + Send + Sync + 'static,
{
let weak_conn_manager = Arc::downgrade(self);

tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval.unsigned_abs());
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut ticker = tokio_stream::wrappers::IntervalStream::new(ticker);

while ticker.next().await.is_some() {
let Some(conn_manager) = weak_conn_manager.upgrade() else {
info!("Connection Manager has gone away");
return;
};
task(&conn_manager);
}
});
}

delegate! {
to self.ctx {
pub(crate) fn is_supported_version(&self, v: Version) -> bool;
Expand Down Expand Up @@ -367,6 +373,10 @@ impl ConnectionManager {
connection_map::Entry::Vacant(_e) => {
// Maybe this is a pending session rotation
if let Some(c) = self.pending_session_id_rotations.lock().get(&session_id) {
let Some(c) = c.upgrade() else {
self.pending_session_id_rotations.lock().remove(&session_id);
return Err(ConnectionManagerError::NoActiveSession);
};
let update_peer_address = addr != c.peer_addr();

return Ok((c.clone(), update_peer_address));
Expand Down Expand Up @@ -403,7 +413,7 @@ impl ConnectionManager {
) {
self.pending_session_id_rotations
.lock()
.insert(new_session_id, conn.clone());
.insert(new_session_id, Arc::downgrade(conn));

metrics::udp_session_rotation_begin();
}
Expand Down Expand Up @@ -474,6 +484,15 @@ impl ConnectionManager {
}
}

#[instrument(level = "trace", skip_all)]
fn cleanup_pending_session_ids(&self) {
tracing::trace!("Cleaning up pending_session_id_rotations");

self.pending_session_id_rotations
.lock()
.retain(|_session_id, conn| conn.upgrade().is_some());
}

pub(crate) fn close_all_connections(&self) {
let connections = self.connections.lock().remove_connections();
for conn in connections {
Expand Down