Skip to content

Commit

Permalink
fix: lru-weighted-cache mem leak (apache#480)
Browse files Browse the repository at this point in the history
* fix: lru-weighted-cache mem leak

* return error when cap is 0
  • Loading branch information
jiacai2050 authored Dec 13, 2022
1 parent 9aaa10b commit 15c9074
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 74 deletions.
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 14 additions & 6 deletions analytic_engine/src/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

//! Setup the analytic engine

use std::{path::Path, pin::Pin, sync::Arc};
use std::{num::NonZeroUsize, path::Path, pin::Pin, sync::Arc};

use async_trait::async_trait;
use common_util::define_result;
Expand Down Expand Up @@ -82,6 +82,11 @@ pub enum Error {
OpenKafka {
source: message_queue::kafka::kafka_impl::Error,
},

#[snafu(display("Failed to create mem cache, err:{}", source))]
OpenMemCache {
source: object_store::mem_cache::Error,
},
}

define_result!(Error);
Expand Down Expand Up @@ -423,11 +428,14 @@ fn open_storage(
}

if opts.mem_cache_capacity.as_bytes() > 0 {
store = Arc::new(MemCacheStore::new(
opts.mem_cache_partition_bits,
opts.mem_cache_capacity.as_bytes() as usize,
store,
)) as _;
store = Arc::new(
MemCacheStore::try_new(
opts.mem_cache_partition_bits,
NonZeroUsize::new(opts.mem_cache_capacity.as_bytes() as usize).unwrap(),
store,
)
.context(OpenMemCache)?,
) as _;
}

Ok(store)
Expand Down
146 changes: 84 additions & 62 deletions components/object_store/src/mem_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,85 +5,104 @@
//! 2. Builtin Partition to reduce lock contention

use std::{
collections::hash_map::DefaultHasher,
fmt::Display,
collections::hash_map::{DefaultHasher, RandomState},
fmt::{self, Display},
hash::{Hash, Hasher},
num::NonZeroUsize,
ops::Range,
sync::Arc,
};

use async_trait::async_trait;
use bytes::Bytes;
use clru::{CLruCache, CLruCacheConfig, WeightScale};
use futures::stream::BoxStream;
use lru_weighted_cache::{LruWeightedCache, Weighted};
use snafu::{OptionExt, Snafu};
use tokio::{io::AsyncWrite, sync::Mutex};
use upstream::{path::Path, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result};

struct CachedBytes(Bytes);
/// Errors produced while constructing the in-memory cache.
#[derive(Debug, Snafu)]
pub enum Error {
/// Returned when `MemCache::try_new` cannot derive a valid per-partition
/// capacity from the requested capacity.
#[snafu(display("mem cache cap must large than 0",))]
InvalidCapacity,
}

struct CustomScale;

impl Weighted for CachedBytes {
fn weight(&self) -> usize {
self.0.len()
impl WeightScale<String, Bytes> for CustomScale {
fn weight(&self, _key: &String, value: &Bytes) -> usize {
value.len()
}
}

#[derive(Debug)]
struct Partition {
inner: Mutex<LruWeightedCache<String, CachedBytes>>,
inner: Mutex<CLruCache<String, Bytes, RandomState, CustomScale>>,
}

impl Partition {
fn new(mem_cap: usize) -> Self {
fn new(mem_cap: NonZeroUsize) -> Self {
let cache = CLruCache::with_config(CLruCacheConfig::new(mem_cap).with_scale(CustomScale));

Self {
inner: Mutex::new(LruWeightedCache::new(1, mem_cap).expect("invalid params")),
inner: Mutex::new(cache),
}
}
}

impl Partition {
// TODO(chenxiang): also support `&str`; this needs changes to
// lru_weighted_cache
async fn get(&self, key: &String) -> Option<Bytes> {
async fn get(&self, key: &str) -> Option<Bytes> {
let mut guard = self.inner.lock().await;
guard.get(key).map(|v| v.0.clone())
guard.get(key).cloned()
}

async fn insert(&self, key: String, value: Bytes) {
let mut guard = self.inner.lock().await;
// don't care error now.
_ = guard.insert(key, CachedBytes(value));
_ = guard.put_with_weight(key, value);
}

/// Returns a copy of every key currently stored in this partition.
///
/// Test-only helper; the partition lock is held while the keys are copied.
#[cfg(test)]
async fn keys(&self) -> Vec<String> {
    let cache = self.inner.lock().await;
    let mut names = Vec::new();
    for (name, _) in cache.iter() {
        names.push(name.clone());
    }
    names
}
}

#[derive(Debug)]
struct MemCache {
/// Max memory this store can use
mem_cap: usize,
mem_cap: NonZeroUsize,
partitions: Vec<Arc<Partition>>,
partition_mask: usize,
}

impl MemCache {
fn new(partition_bits: usize, mem_cap: usize) -> Self {
fn try_new(partition_bits: usize, mem_cap: NonZeroUsize) -> std::result::Result<Self, Error> {
let partition_num = 1 << partition_bits;
let cap_per_part = mem_cap / partition_num;
// Split the total budget evenly: every partition gets `mem_cap / partition_num`
// bytes, so all partitions together never exceed `mem_cap`. (Multiplying here
// would hand each partition more than the whole budget.) A share that rounds
// down to zero is reported as `InvalidCapacity` instead of building partitions
// that cannot hold anything.
let cap_per_part =
    NonZeroUsize::new(mem_cap.get() / partition_num).context(InvalidCapacity)?;
let partitions = (0..partition_num)
.map(|_| Arc::new(Partition::new(cap_per_part)))
.collect::<Vec<_>>();

Self {
Ok(Self {
mem_cap,
partitions,
partition_mask: partition_num - 1,
}
})
}

fn locate_partition(&self, key: &String) -> Arc<Partition> {
fn locate_partition(&self, key: &str) -> Arc<Partition> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
self.partitions[hasher.finish() as usize & self.partition_mask].clone()
}

async fn get(&self, key: &String) -> Option<Bytes> {
async fn get(&self, key: &str) -> Option<Bytes> {
let partition = self.locate_partition(key);
partition.get(key).await
}
Expand All @@ -92,37 +111,48 @@ impl MemCache {
let partition = self.locate_partition(&key);
partition.insert(key, value).await;
}

/// Renders the cache contents for tests: one line per partition in the form
/// `<index>: [<comma-separated keys>]`, with lines joined by `\n`.
#[cfg(test)]
async fn to_string(&self) -> String {
    // Gather the keys of all partitions concurrently, one future per partition.
    let per_partition = self
        .partitions
        .iter()
        .map(|part| async move { part.keys().await.join(",") });
    let joined_keys = futures::future::join_all(per_partition).await;

    let mut lines = Vec::with_capacity(joined_keys.len());
    for (part_no, keys) in joined_keys.into_iter().enumerate() {
        lines.push(format!("{}: [{}]", part_no, keys));
    }
    lines.join("\n")
}
}

impl Display for MemCache {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MemCache")
.field("mem_cap", &self.mem_cap)
.field("mask", &self.partition_mask)
.field("partitons", &self.partitions)
.field("partitons", &self.partitions.len())
.finish()
}
}

#[derive(Debug)]
pub struct MemCacheStore {
cache: MemCache,
underlying_store: Arc<dyn ObjectStore>,
}

impl MemCacheStore {
// Note: mem_cap must be larger than 0
pub fn new(
pub fn try_new(
partition_bits: usize,
mem_cap: usize,
mem_cap: NonZeroUsize,
underlying_store: Arc<dyn ObjectStore>,
) -> Self {
assert!(mem_cap > 0);

Self {
cache: MemCache::new(partition_bits, mem_cap),
) -> std::result::Result<Self, Error> {
MemCache::try_new(partition_bits, mem_cap).map(|cache| Self {
cache,
underlying_store,
}
})
}

fn cache_key(location: &Path, range: &Range<usize>) -> String {
Expand All @@ -136,6 +166,12 @@ impl Display for MemCacheStore {
}
}

// `Debug` prints only the type name; no fields of the store are included in
// the output.
impl fmt::Debug for MemCacheStore {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("MemCacheStore")
    }
}

#[async_trait]
impl ObjectStore for MemCacheStore {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
Expand Down Expand Up @@ -216,7 +252,7 @@ mod test {
let local_path = tempdir().unwrap();
let local_store = Arc::new(LocalFileSystem::new_with_prefix(local_path.path()).unwrap());

MemCacheStore::new(bits, mem_cap, local_store)
MemCacheStore::try_new(bits, NonZeroUsize::new(mem_cap).unwrap(), local_store).unwrap()
}

#[tokio::test]
Expand All @@ -239,10 +275,6 @@ mod test {
.get(&MemCacheStore::cache_key(&location, &range0_5))
.await
.is_some());
assert_eq!(
r#"MemCache { mem_cap: 13, mask: 0, partitons: [Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 13, max_total_weight: 13, current_weight: 5 } } }] }"#,
format!("{}", store)
);

// get bytes from [5, 10), insert to cache
let range5_10 = 5..10;
Expand All @@ -257,18 +289,19 @@ mod test {
.get(&MemCacheStore::cache_key(&location, &range5_10))
.await
.is_some());
assert_eq!(
r#"MemCache { mem_cap: 13, mask: 0, partitons: [Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 13, max_total_weight: 13, current_weight: 10 } } }] }"#,
format!("{}", store)
);

// get bytes from [5, 10), insert to cache
// get bytes from [10, 15), insert to cache
// cache is full, evict [0, 5)
let range10_15 = 5..10;
let range10_15 = 10..15;
_ = store
.get_range(&location, range10_15.clone())
.await
.unwrap();
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
.await
.is_none());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range5_10))
Expand All @@ -279,20 +312,6 @@ mod test {
.get(&MemCacheStore::cache_key(&location, &range10_15))
.await
.is_some());
assert_eq!(
r#"MemCache { mem_cap: 13, mask: 0, partitons: [Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 13, max_total_weight: 13, current_weight: 10 } } }] }"#,
format!("{}", store)
);

let range10_13 = 10..13;
_ = store
.get_range(&location, range10_13.clone())
.await
.unwrap();
assert_eq!(
r#"MemCache { mem_cap: 13, mask: 0, partitons: [Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 13, max_total_weight: 13, current_weight: 13 } } }] }"#,
format!("{}", store)
);
}

#[tokio::test]
Expand All @@ -314,8 +333,11 @@ mod test {
.unwrap();

assert_eq!(
r#"MemCache { mem_cap: 100, mask: 3, partitons: [Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 25, max_total_weight: 25, current_weight: 0 } } }, Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 25, max_total_weight: 25, current_weight: 5 } } }, Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 25, max_total_weight: 25, current_weight: 0 } } }, Partition { inner: Mutex { data: LruWeightedCache { max_item_weight: 25, max_total_weight: 25, current_weight: 5 } } }] }"#,
format!("{}", store)
r#"0: []
1: [partition.sst-100-105]
2: []
3: [partition.sst-0-5]"#,
store.cache.to_string().await
);

assert!(store
Expand Down

0 comments on commit 15c9074

Please sign in to comment.