From 9ed8a3b2cafbef60b2dd50a5360de11559236007 Mon Sep 17 00:00:00 2001 From: yongman Date: Wed, 10 Aug 2022 15:59:46 +0800 Subject: [PATCH] Fix client leak after connection disconnect (#55) * Remove client from hashmap after connection disconnect Signed-off-by: yongman * No need remove client manually in client kill Signed-off-by: yongman * Fix client management dead lock potentially Signed-off-by: yongman --- src/cmd/fake.rs | 111 ++++++++++++++++++++++++------------------------ src/cmd/mod.rs | 10 ++--- src/server.rs | 28 +++++++----- 3 files changed, 79 insertions(+), 70 deletions(-) diff --git a/src/cmd/fake.rs b/src/cmd/fake.rs index c24451d..9dc51b5 100644 --- a/src/cmd/fake.rs +++ b/src/cmd/fake.rs @@ -9,7 +9,7 @@ use crate::tikv::errors::{ REDIS_INVALID_CLIENT_ID_ERR, REDIS_NOT_SUPPORTED_ERR, REDIS_NO_SUCH_CLIENT_ERR, REDIS_VALUE_IS_NOT_INTEGER_ERR, }; -use crate::utils::{resp_int, resp_str}; +use crate::utils::resp_int; use crate::{ config::LOGGER, tikv::errors::REDIS_UNKNOWN_SUBCOMMAND, @@ -42,7 +42,7 @@ impl Fake { self, command: &str, dst: &mut Connection, - cur_client: Arc>, + cur_client: u64, clients: Arc>>>>, ) -> crate::Result<()> { let response = self.do_apply(command, cur_client, clients).await; @@ -63,7 +63,7 @@ impl Fake { async fn do_apply( self, command: &str, - cur_client: Arc>, + cur_client: u64, clients: Arc>>>>, ) -> Frame { if !self.valid { @@ -73,16 +73,14 @@ impl Fake { "READWRITE" => resp_ok(), "READONLY" => resp_ok(), "CLIENT" => { - // TODO client more management will be added later match self.args[0].clone().to_uppercase().as_str() { - "ID" => resp_int(cur_client.lock().await.id() as i64), + "ID" => resp_int(cur_client as i64), "LIST" => { if self.args.len() == 1 { + let clients_guard = clients.lock().await; return resp_bulk( - encode_clients_info( - clients.lock().await.clone().into_values().collect(), - ) - .await, + encode_clients_info(clients_guard.clone().into_values().collect()) + .await, ); } @@ -115,23 +113,19 @@ impl Fake { // three arguments format (old format) if self.args.len() == 2 { let mut target_client = None; - { - let lk_clients = clients.lock().await; - for client in lk_clients.values() { - let lk_client = client.lock().await; - if lk_client.peer_addr() == self.args[1] { - target_client = Some(client.clone()); - break; - } + let clients_guard = clients.lock().await; + for client in clients_guard.values() { + // make sure get the client guard with clients aguard obtained + let client = client.lock().await; + if client.peer_addr() == self.args[1] { + target_client = Some(client.clone()); + break; } } return match target_client { Some(client) => { - let lk_client = client.lock().await; - let mut lk_clients = clients.lock().await; - lk_client.kill().await; - lk_clients.remove(&lk_client.id()); + client.kill().await; resp_ok() } None => resp_err(REDIS_NO_SUCH_CLIENT_ERR), @@ -169,39 +163,34 @@ impl Fake { } // retrieve current client id in advance for preventing dead lock during clients traverse - let cur_client_id = cur_client.lock().await.id(); let mut eligible_clients: Vec>> = vec![]; - { - let lk_clients = clients.lock().await; - for client in lk_clients.values() { - let lk_client = client.lock().await; - if !filter_peer_addr.is_empty() - && lk_client.peer_addr() != filter_peer_addr - { - continue; - } - if !filter_local_addr.is_empty() - && lk_client.local_addr() != filter_local_addr - { - continue; - } - if filter_id != 0 && lk_client.id() != filter_id { - continue; - } - if cur_client_id == lk_client.id() && filter_skipme { - continue; - } - - eligible_clients.push(client.clone()); + let clients_guard = clients.lock().await; + for client in clients_guard.values() { + let client_guard = client.lock().await; + if !filter_peer_addr.is_empty() + && client_guard.peer_addr() != filter_peer_addr + { + continue; + } + if !filter_local_addr.is_empty() + && client_guard.local_addr() != filter_local_addr + { + continue; + } + if filter_id != 0 && client_guard.id() != filter_id { + continue; } + if cur_client == client_guard.id() && filter_skipme { + continue; + } + + eligible_clients.push(client.clone()); } let killed = eligible_clients.len() as i64; - let mut lk_clients = clients.lock().await; for eligible_client in eligible_clients { - let lk_eligible_client = eligible_client.lock().await; - lk_eligible_client.kill().await; - lk_clients.remove(&lk_eligible_client.id()); + // make sure get the client guard with clients aguard obtained + eligible_client.lock().await.kill().await; } resp_int(killed) @@ -211,18 +200,30 @@ impl Fake { return resp_invalid_arguments(); } - let mut w_cur_client = cur_client.lock().await; - w_cur_client.set_name(&self.args[1]); + let clients_guard = clients.lock().await; + clients_guard + .get(&cur_client) + .unwrap() + .lock() + .await + .set_name(&self.args[1]); + resp_ok() } "GETNAME" => { - let r_cur_client = cur_client.lock().await; - let name = r_cur_client.name(); + let clients_guard = clients.lock().await; + let name = clients_guard + .get(&cur_client) + .unwrap() + .lock() + .await + .name() + .to_owned(); if name.is_empty() { return resp_nil(); } - resp_str(name) + resp_bulk(name.into_bytes()) } _ => resp_err(REDIS_UNKNOWN_SUBCOMMAND), } @@ -248,8 +249,8 @@ impl Fake { async fn encode_clients_info(clients: Vec>>) -> Vec { let mut resp_list = String::new(); for client in clients { - let r_client = client.lock().await; - resp_list.push_str(&r_client.to_string()); + let client = client.lock().await; + resp_list.push_str(&client.to_string()); resp_list.push('\n'); } diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index a79e11b..856aee6 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -681,7 +681,7 @@ impl Command { db: &Db, topo: &Topo, dst: &mut Connection, - cur_client: Arc>, + cur_client_id: u64, clients: Arc>>>>, lua: &mut Option, shutdown: &mut Shutdown, @@ -766,10 +766,10 @@ impl Command { Debug(cmd) => cmd.apply(dst).await, Cluster(cmd) => cmd.apply(topo, dst).await, - ReadWrite(cmd) => cmd.apply("readwrite", dst, cur_client, clients).await, - ReadOnly(cmd) => cmd.apply("readonly", dst, cur_client, clients).await, - Client(cmd) => cmd.apply("client", dst, cur_client, clients).await, - Info(cmd) => cmd.apply("info", dst, cur_client, clients).await, + ReadWrite(cmd) => cmd.apply("readwrite", dst, cur_client_id, clients).await, + ReadOnly(cmd) => cmd.apply("readonly", dst, cur_client_id, clients).await, + Client(cmd) => cmd.apply("client", dst, cur_client_id, clients).await, + Info(cmd) => cmd.apply("info", dst, cur_client_id, clients).await, Scan(cmd) => cmd.apply(dst).await, Xscan(cmd) => cmd.apply(dst).await, diff --git a/src/server.rs b/src/server.rs index 8ce364e..c6e3dda 100644 --- a/src/server.rs +++ b/src/server.rs @@ -132,7 +132,7 @@ struct Handler { db: Db, topo: Cluster, - cur_client: Arc>, + cur_client_id: u64, clients: Arc>>>>, /// The TCP connection decorated with the redis protocol encoder / decoder @@ -449,7 +449,7 @@ impl Listener { db: self.db_holder.db(), topo: self.topo_holder.clone(), - cur_client: arc_client.clone(), + cur_client_id: client_id, clients: self.clients.clone(), // Initialize the connection state. This allocates read/write @@ -475,13 +475,15 @@ impl Listener { // dropped. _shutdown_complete: self.shutdown_complete_tx.clone(), }; - local_pool.spawn_pinned(|| async move { + + local_pool.spawn_pinned(move || async move { // Process the connection. If an error is encountered, log it. CURRENT_CONNECTION_COUNTER.inc(); TOTAL_CONNECTION_PROCESSED.inc(); if let Err(err) = handler.run().await { error!(LOGGER, "connection error {:?}", err); } + handler.clients.lock().await.remove(&client_id); CURRENT_CONNECTION_COUNTER.dec(); }); } @@ -564,7 +566,7 @@ impl TlsListener { let mut handler = Handler { db: self.db_holder.db(), topo: self.topo_holder.clone(), - cur_client: arc_client.clone(), + cur_client_id: client_id, clients: self.clients.clone(), connection: Connection::new_tls(&local_addr, &peer_addr, tls_stream), inner_txn: false, @@ -575,13 +577,14 @@ impl TlsListener { _shutdown_complete: self.tls_shutdown_complete_tx.clone(), }; - local_pool.spawn_pinned(|| async move { + local_pool.spawn_pinned(move || async move { // Process the connection. If an error is encountered, log it. CURRENT_TLS_CONNECTION_COUNTER.inc(); TOTAL_CONNECTION_PROCESSED.inc(); if let Err(err) = handler.run().await { error!(LOGGER, "tls connection error {:?}", err); } + handler.clients.lock().await.remove(&client_id); CURRENT_TLS_CONNECTION_COUNTER.dec(); }); } @@ -706,10 +709,15 @@ impl Handler { let cmd = Command::from_frame(frame)?; let cmd_name = cmd.get_name().to_owned(); - { - let mut w_client = self.cur_client.lock().await; - w_client.interact(&cmd_name); - } + let clients_guard = self.clients.lock().await; + // make sure get the client guard with clients aguard obtained + clients_guard + .get(&self.cur_client_id) + .unwrap() + .lock() + .await + .interact(&cmd_name); + drop(clients_guard); let start_at = Instant::now(); REQUEST_COUNTER.inc(); @@ -819,7 +827,7 @@ impl Handler { &self.db, &self.topo, &mut self.connection, - self.cur_client.clone(), + self.cur_client_id, self.clients.clone(), &mut self.lua, &mut self.shutdown,