Fix client leak after connection disconnect (#55)
* Remove client from hashmap after connection disconnect

Signed-off-by: yongman <yming0221@gmail.com>

* No need to remove the client manually in the CLIENT KILL path

Signed-off-by: yongman <yming0221@gmail.com>

* Fix potential deadlock in client management

Signed-off-by: yongman <yming0221@gmail.com>
yongman authored Aug 10, 2022
1 parent 02a9495 commit 9ed8a3b
Showing 3 changed files with 79 additions and 70 deletions.
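
The deadlock mentioned in the last commit bullet is a lock-ordering cycle between the shared clients map and an individual client's mutex: the old CLIENT KILL path locked a client first and then the map, while the traversal paths locked the map first and then each client. Below is a minimal sketch of the hazard and of the ordering discipline the diff adopts. It assumes `tokio` with the `sync`, `macros`, and `rt-multi-thread` features; `Client` and the function names here are illustrative, not the project's API.

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

struct Client {
    id: u64,
}

type Clients = Arc<Mutex<HashMap<u64, Arc<Mutex<Client>>>>>;

// Deadlock-prone: client lock first, map lock second. If another task
// already holds the map lock and is waiting for this client's lock
// (the CLIENT LIST / CLIENT KILL traversal), neither task can proceed.
async fn remove_self_bad(client: Arc<Mutex<Client>>, clients: Clients) {
    let c = client.lock().await;
    clients.lock().await.remove(&c.id);
}

// Safe: every path takes the map lock first, then the client lock,
// which is the ordering the diff's comments insist on.
async fn touch_client_good(id: u64, clients: Clients) {
    let map = clients.lock().await;
    if let Some(client) = map.get(&id) {
        let _c = client.lock().await; // inner lock taken under the outer one
    }
}

#[tokio::main]
async fn main() {
    let clients: Clients = Arc::new(Mutex::new(HashMap::new()));
    let client = Arc::new(Mutex::new(Client { id: 1 }));
    clients.lock().await.insert(1, client.clone());
    touch_client_good(1, clients.clone()).await;
    remove_self_bad(client, clients).await; // fine alone; deadlocks under contention
}
```
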
111 changes: 56 additions & 55 deletions src/cmd/fake.rs
@@ -9,7 +9,7 @@ use crate::tikv::errors::{
REDIS_INVALID_CLIENT_ID_ERR, REDIS_NOT_SUPPORTED_ERR, REDIS_NO_SUCH_CLIENT_ERR,
REDIS_VALUE_IS_NOT_INTEGER_ERR,
};
use crate::utils::{resp_int, resp_str};
use crate::utils::resp_int;
use crate::{
config::LOGGER,
tikv::errors::REDIS_UNKNOWN_SUBCOMMAND,
@@ -42,7 +42,7 @@ impl Fake {
self,
command: &str,
dst: &mut Connection,
cur_client: Arc<Mutex<Client>>,
cur_client: u64,
clients: Arc<Mutex<HashMap<u64, Arc<Mutex<Client>>>>>,
) -> crate::Result<()> {
let response = self.do_apply(command, cur_client, clients).await;
@@ -63,7 +63,7 @@
async fn do_apply(
self,
command: &str,
cur_client: Arc<Mutex<Client>>,
cur_client: u64,
clients: Arc<Mutex<HashMap<u64, Arc<Mutex<Client>>>>>,
) -> Frame {
if !self.valid {
@@ -73,16 +73,14 @@
"READWRITE" => resp_ok(),
"READONLY" => resp_ok(),
"CLIENT" => {
// TODO: more client management will be added later
match self.args[0].clone().to_uppercase().as_str() {
"ID" => resp_int(cur_client.lock().await.id() as i64),
"ID" => resp_int(cur_client as i64),
"LIST" => {
if self.args.len() == 1 {
let clients_guard = clients.lock().await;
return resp_bulk(
encode_clients_info(
clients.lock().await.clone().into_values().collect(),
)
.await,
encode_clients_info(clients_guard.clone().into_values().collect())
.await,
);
}

@@ -115,23 +113,19 @@
// three arguments format (old format)
if self.args.len() == 2 {
let mut target_client = None;
{
let lk_clients = clients.lock().await;
for client in lk_clients.values() {
let lk_client = client.lock().await;
if lk_client.peer_addr() == self.args[1] {
target_client = Some(client.clone());
break;
}
let clients_guard = clients.lock().await;
for client in clients_guard.values() {
// make sure the client guard is taken only while the clients guard is held
let client = client.lock().await;
if client.peer_addr() == self.args[1] {
target_client = Some(client.clone());
break;
}
}

return match target_client {
Some(client) => {
let lk_client = client.lock().await;
let mut lk_clients = clients.lock().await;
lk_client.kill().await;
lk_clients.remove(&lk_client.id());
client.kill().await;
resp_ok()
}
None => resp_err(REDIS_NO_SUCH_CLIENT_ERR),
@@ -169,39 +163,34 @@ impl Fake {
}

// retrieve the current client id in advance to prevent deadlock while traversing clients
let cur_client_id = cur_client.lock().await.id();
let mut eligible_clients: Vec<Arc<Mutex<Client>>> = vec![];
{
let lk_clients = clients.lock().await;
for client in lk_clients.values() {
let lk_client = client.lock().await;
if !filter_peer_addr.is_empty()
&& lk_client.peer_addr() != filter_peer_addr
{
continue;
}
if !filter_local_addr.is_empty()
&& lk_client.local_addr() != filter_local_addr
{
continue;
}
if filter_id != 0 && lk_client.id() != filter_id {
continue;
}
if cur_client_id == lk_client.id() && filter_skipme {
continue;
}

eligible_clients.push(client.clone());
let clients_guard = clients.lock().await;
for client in clients_guard.values() {
let client_guard = client.lock().await;
if !filter_peer_addr.is_empty()
&& client_guard.peer_addr() != filter_peer_addr
{
continue;
}
if !filter_local_addr.is_empty()
&& client_guard.local_addr() != filter_local_addr
{
continue;
}
if filter_id != 0 && client_guard.id() != filter_id {
continue;
}
if cur_client == client_guard.id() && filter_skipme {
continue;
}

eligible_clients.push(client.clone());
}

let killed = eligible_clients.len() as i64;
let mut lk_clients = clients.lock().await;
for eligible_client in eligible_clients {
let lk_eligible_client = eligible_client.lock().await;
lk_eligible_client.kill().await;
lk_clients.remove(&lk_eligible_client.id());
// make sure the client guard is taken only while the clients guard is held
eligible_client.lock().await.kill().await;
}

resp_int(killed)
@@ -211,18 +200,30 @@
return resp_invalid_arguments();
}

let mut w_cur_client = cur_client.lock().await;
w_cur_client.set_name(&self.args[1]);
let clients_guard = clients.lock().await;
clients_guard
.get(&cur_client)
.unwrap()
.lock()
.await
.set_name(&self.args[1]);

resp_ok()
}
"GETNAME" => {
let r_cur_client = cur_client.lock().await;
let name = r_cur_client.name();
let clients_guard = clients.lock().await;
let name = clients_guard
.get(&cur_client)
.unwrap()
.lock()
.await
.name()
.to_owned();
if name.is_empty() {
return resp_nil();
}

resp_str(name)
resp_bulk(name.into_bytes())
}
_ => resp_err(REDIS_UNKNOWN_SUBCOMMAND),
}
@@ -248,8 +249,8 @@ impl Fake {
async fn encode_clients_info(clients: Vec<Arc<Mutex<Client>>>) -> Vec<u8> {
let mut resp_list = String::new();
for client in clients {
let r_client = client.lock().await;
resp_list.push_str(&r_client.to_string());
let client = client.lock().await;
resp_list.push_str(&client.to_string());
resp_list.push('\n');
}

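
The shape the fake.rs hunks converge on: a command handler receives only the client id (`u64`) and resolves the client through the shared map, so the map lock is always the outer lock. A condensed sketch of the CLIENT SETNAME / GETNAME paths under that discipline follows; the direct field access stands in for the real `name()` / `set_name()` accessors, and the `Option` return replaces the diff's `unwrap()` (presumably safe there because a connection's map entry is only removed by its own task after the handler returns).

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

#[derive(Default)]
struct Client {
    name: String,
}

type Clients = Arc<Mutex<HashMap<u64, Arc<Mutex<Client>>>>>;

// CLIENT SETNAME: map guard first, client guard second.
async fn client_setname(cur_client: u64, clients: &Clients, name: &str) {
    let map = clients.lock().await;
    if let Some(client) = map.get(&cur_client) {
        client.lock().await.name = name.to_owned();
    }
}

// CLIENT GETNAME: same order; the name is cloned out so both guards are
// dropped before the reply is encoded.
async fn client_getname(cur_client: u64, clients: &Clients) -> Option<String> {
    let map = clients.lock().await;
    let name = map.get(&cur_client)?.lock().await.name.clone();
    if name.is_empty() { None } else { Some(name) }
}

#[tokio::main]
async fn main() {
    let clients: Clients = Arc::new(Mutex::new(HashMap::new()));
    clients.lock().await.insert(1, Arc::new(Mutex::new(Client::default())));
    client_setname(1, &clients, "conn-1").await;
    assert_eq!(client_getname(1, &clients).await.as_deref(), Some("conn-1"));
}
```
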
10 changes: 5 additions & 5 deletions src/cmd/mod.rs
@@ -681,7 +681,7 @@ impl Command {
db: &Db,
topo: &Topo,
dst: &mut Connection,
cur_client: Arc<Mutex<Client>>,
cur_client_id: u64,
clients: Arc<Mutex<HashMap<u64, Arc<Mutex<Client>>>>>,
lua: &mut Option<Lua>,
shutdown: &mut Shutdown,
@@ -766,10 +766,10 @@ impl Command {
Debug(cmd) => cmd.apply(dst).await,

Cluster(cmd) => cmd.apply(topo, dst).await,
ReadWrite(cmd) => cmd.apply("readwrite", dst, cur_client, clients).await,
ReadOnly(cmd) => cmd.apply("readonly", dst, cur_client, clients).await,
Client(cmd) => cmd.apply("client", dst, cur_client, clients).await,
Info(cmd) => cmd.apply("info", dst, cur_client, clients).await,
ReadWrite(cmd) => cmd.apply("readwrite", dst, cur_client_id, clients).await,
ReadOnly(cmd) => cmd.apply("readonly", dst, cur_client_id, clients).await,
Client(cmd) => cmd.apply("client", dst, cur_client_id, clients).await,
Info(cmd) => cmd.apply("info", dst, cur_client_id, clients).await,

Scan(cmd) => cmd.apply(dst).await,
Xscan(cmd) => cmd.apply(dst).await,
28 changes: 18 additions & 10 deletions src/server.rs
@@ -132,7 +132,7 @@ struct Handler {
db: Db,

topo: Cluster,
cur_client: Arc<Mutex<Client>>,
cur_client_id: u64,
clients: Arc<Mutex<HashMap<u64, Arc<Mutex<Client>>>>>,

/// The TCP connection decorated with the redis protocol encoder / decoder
@@ -449,7 +449,7 @@ impl Listener {
db: self.db_holder.db(),

topo: self.topo_holder.clone(),
cur_client: arc_client.clone(),
cur_client_id: client_id,
clients: self.clients.clone(),

// Initialize the connection state. This allocates read/write
@@ -475,13 +475,15 @@ impl Listener {
// dropped.
_shutdown_complete: self.shutdown_complete_tx.clone(),
};
local_pool.spawn_pinned(|| async move {

local_pool.spawn_pinned(move || async move {
// Process the connection. If an error is encountered, log it.
CURRENT_CONNECTION_COUNTER.inc();
TOTAL_CONNECTION_PROCESSED.inc();
if let Err(err) = handler.run().await {
error!(LOGGER, "connection error {:?}", err);
}
handler.clients.lock().await.remove(&client_id);
CURRENT_CONNECTION_COUNTER.dec();
});
}
@@ -564,7 +566,7 @@ impl TlsListener {
let mut handler = Handler {
db: self.db_holder.db(),
topo: self.topo_holder.clone(),
cur_client: arc_client.clone(),
cur_client_id: client_id,
clients: self.clients.clone(),
connection: Connection::new_tls(&local_addr, &peer_addr, tls_stream),
inner_txn: false,
@@ -575,13 +577,14 @@ impl TlsListener {
_shutdown_complete: self.tls_shutdown_complete_tx.clone(),
};

local_pool.spawn_pinned(|| async move {
local_pool.spawn_pinned(move || async move {
// Process the connection. If an error is encountered, log it.
CURRENT_TLS_CONNECTION_COUNTER.inc();
TOTAL_CONNECTION_PROCESSED.inc();
if let Err(err) = handler.run().await {
error!(LOGGER, "tls connection error {:?}", err);
}
handler.clients.lock().await.remove(&client_id);
CURRENT_TLS_CONNECTION_COUNTER.dec();
});
}
@@ -706,10 +709,15 @@ impl Handler {
let cmd = Command::from_frame(frame)?;
let cmd_name = cmd.get_name().to_owned();

{
let mut w_client = self.cur_client.lock().await;
w_client.interact(&cmd_name);
}
let clients_guard = self.clients.lock().await;
// make sure the client guard is taken only while the clients guard is held
clients_guard
.get(&self.cur_client_id)
.unwrap()
.lock()
.await
.interact(&cmd_name);
drop(clients_guard);

let start_at = Instant::now();
REQUEST_COUNTER.inc();
@@ -819,7 +827,7 @@ impl Handler {
&self.db,
&self.topo,
&mut self.connection,
self.cur_client.clone(),
self.cur_client_id,
self.clients.clone(),
&mut self.lua,
&mut self.shutdown,
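
The server.rs change in one sketch: the connection task owns its own cleanup, so the map entry is removed exactly once, whenever the handler returns (normal disconnect, error, or kill), which is why CLIENT KILL no longer has to touch the map. This is a sketch only: plain `tokio::spawn` stands in for the pool's `spawn_pinned`, and the metrics counters and structured logging are elided.

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

struct Client;

type Clients = Arc<Mutex<HashMap<u64, Arc<Mutex<Client>>>>>;

// Stand-in for Handler::run(): reads and executes frames until the peer
// disconnects or kill() is signalled.
async fn handle_connection(
    _client_id: u64,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    Ok(())
}

fn spawn_handler(client_id: u64, clients: Clients) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        if let Err(err) = handle_connection(client_id).await {
            eprintln!("connection error {:?}", err);
        }
        // The leak fix: however the connection ended, the task removes its
        // own entry from the shared map on the way out.
        clients.lock().await.remove(&client_id);
    })
}

#[tokio::main]
async fn main() {
    let clients: Clients = Arc::new(Mutex::new(HashMap::new()));
    clients.lock().await.insert(7, Arc::new(Mutex::new(Client)));
    spawn_handler(7, clients.clone()).await.unwrap();
    assert!(clients.lock().await.is_empty()); // no leaked entry
}
```

The diff's switch from `||` to `move ||` serves the same need as the `async move` here: the spawned task must own `client_id` so the removal can run after the handler future completes.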
