Skip to content

Commit

Permalink
chore: refactor cache (zeromicro#1532)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevwan authored Feb 13, 2022
1 parent e8c307e commit 2732d3c
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 122 deletions.
62 changes: 31 additions & 31 deletions core/stores/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,32 @@ type (
// DelCtx deletes cached values with keys.
DelCtx(ctx context.Context, keys ...string) error
// Get gets the cache with key and fills into v.
Get(key string, v interface{}) error
Get(key string, val interface{}) error
// GetCtx gets the cache with key and fills into v.
GetCtx(ctx context.Context, key string, v interface{}) error
GetCtx(ctx context.Context, key string, val interface{}) error
// IsNotFound checks if the given error is the defined errNotFound.
IsNotFound(err error) bool
// Set sets the cache with key and v, using c.expiry.
Set(key string, v interface{}) error
Set(key string, val interface{}) error
// SetCtx sets the cache with key and v, using c.expiry.
SetCtx(ctx context.Context, key string, v interface{}) error
SetCtx(ctx context.Context, key string, val interface{}) error
// SetWithExpire sets the cache with key and v, using given expire.
SetWithExpire(key string, v interface{}, expire time.Duration) error
SetWithExpire(key string, val interface{}, expire time.Duration) error
// SetWithExpireCtx sets the cache with key and v, using given expire.
SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error
SetWithExpireCtx(ctx context.Context, key string, val interface{}, expire time.Duration) error
// Take takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
Take(v interface{}, key string, query func(v interface{}) error) error
Take(val interface{}, key string, query func(val interface{}) error) error
// TakeCtx takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error
TakeCtx(ctx context.Context, val interface{}, key string, query func(val interface{}) error) error
// TakeWithExpire takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error
TakeWithExpire(val interface{}, key string, query func(val interface{}, expire time.Duration) error) error
// TakeWithExpireCtx takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error
TakeWithExpireCtx(ctx context.Context, val interface{}, key string,
query func(val interface{}, expire time.Duration) error) error
}

cacheCluster struct {
Expand Down Expand Up @@ -117,18 +118,18 @@ func (cc cacheCluster) DelCtx(ctx context.Context, keys ...string) error {
}

// Get gets the cache with key and fills into v.
func (cc cacheCluster) Get(key string, v interface{}) error {
return cc.GetCtx(context.Background(), key, v)
func (cc cacheCluster) Get(key string, val interface{}) error {
return cc.GetCtx(context.Background(), key, val)
}

// GetCtx gets the cache with key and fills into v.
func (cc cacheCluster) GetCtx(ctx context.Context, key string, v interface{}) error {
func (cc cacheCluster) GetCtx(ctx context.Context, key string, val interface{}) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}

return c.(Cache).GetCtx(ctx, key, v)
return c.(Cache).GetCtx(ctx, key, val)
}

// IsNotFound checks if the given error is the defined errNotFound.
Expand All @@ -137,66 +138,65 @@ func (cc cacheCluster) IsNotFound(err error) bool {
}

// Set sets the cache with key and v, using c.expiry.
func (cc cacheCluster) Set(key string, v interface{}) error {
return cc.SetCtx(context.Background(), key, v)
func (cc cacheCluster) Set(key string, val interface{}) error {
return cc.SetCtx(context.Background(), key, val)
}

// SetCtx sets the cache with key and v, using c.expiry.
func (cc cacheCluster) SetCtx(ctx context.Context, key string, v interface{}) error {
func (cc cacheCluster) SetCtx(ctx context.Context, key string, val interface{}) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}

return c.(Cache).SetCtx(ctx, key, v)
return c.(Cache).SetCtx(ctx, key, val)
}

// SetWithExpire sets the cache with key and v, using given expire.
func (cc cacheCluster) SetWithExpire(key string, v interface{}, expire time.Duration) error {
return cc.SetWithExpireCtx(context.Background(), key, v, expire)
func (cc cacheCluster) SetWithExpire(key string, val interface{}, expire time.Duration) error {
return cc.SetWithExpireCtx(context.Background(), key, val, expire)
}

// SetWithExpireCtx sets the cache with key and v, using given expire.
func (cc cacheCluster) SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error {
func (cc cacheCluster) SetWithExpireCtx(ctx context.Context, key string, val interface{}, expire time.Duration) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}

return c.(Cache).SetWithExpireCtx(ctx, key, v, expire)
return c.(Cache).SetWithExpireCtx(ctx, key, val, expire)
}

// Take takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
func (cc cacheCluster) Take(v interface{}, key string, query func(v interface{}) error) error {
return cc.TakeCtx(context.Background(), v, key, query)
func (cc cacheCluster) Take(val interface{}, key string, query func(val interface{}) error) error {
return cc.TakeCtx(context.Background(), val, key, query)
}

// TakeCtx takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
func (cc cacheCluster) TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error {
func (cc cacheCluster) TakeCtx(ctx context.Context, val interface{}, key string, query func(val interface{}) error) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}

return c.(Cache).TakeCtx(ctx, v, key, query)
return c.(Cache).TakeCtx(ctx, val, key, query)
}

// TakeWithExpire takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
func (cc cacheCluster) TakeWithExpire(v interface{}, key string,
query func(v interface{}, expire time.Duration) error) error {
return cc.TakeWithExpireCtx(context.Background(), v, key, query)
func (cc cacheCluster) TakeWithExpire(val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
return cc.TakeWithExpireCtx(context.Background(), val, key, query)
}

// TakeWithExpireCtx takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
func (cc cacheCluster) TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
func (cc cacheCluster) TakeWithExpireCtx(ctx context.Context, val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}

return c.(Cache).TakeWithExpireCtx(ctx, v, key, query)
return c.(Cache).TakeWithExpireCtx(ctx, val, key, query)
}
118 changes: 59 additions & 59 deletions core/stores/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type mockedNode struct {
}

func (mc *mockedNode) Del(keys ...string) error {
return mc.DelCtx(context.Background(), keys...)
}

func (mc *mockedNode) DelCtx(_ context.Context, keys ...string) error {
var be errorx.BatchError

for _, key := range keys {
Expand All @@ -39,10 +43,14 @@ func (mc *mockedNode) Del(keys ...string) error {
return be.Err()
}

func (mc *mockedNode) Get(key string, v interface{}) error {
func (mc *mockedNode) Get(key string, val interface{}) error {
return mc.GetCtx(context.Background(), key, val)
}

func (mc *mockedNode) GetCtx(ctx context.Context, key string, val interface{}) error {
bs, ok := mc.vals[key]
if ok {
return json.Unmarshal(bs, v)
return json.Unmarshal(bs, val)
}

return mc.errNotFound
Expand All @@ -52,8 +60,12 @@ func (mc *mockedNode) IsNotFound(err error) bool {
return errors.Is(err, mc.errNotFound)
}

func (mc *mockedNode) Set(key string, v interface{}) error {
data, err := json.Marshal(v)
func (mc *mockedNode) Set(key string, val interface{}) error {
return mc.SetCtx(context.Background(), key, val)
}

func (mc *mockedNode) SetCtx(ctx context.Context, key string, val interface{}) error {
data, err := json.Marshal(val)
if err != nil {
return err
}
Expand All @@ -62,50 +74,38 @@ func (mc *mockedNode) Set(key string, v interface{}) error {
return nil
}

func (mc *mockedNode) SetWithExpire(key string, v interface{}, _ time.Duration) error {
return mc.Set(key, v)
}

func (mc *mockedNode) Take(v interface{}, key string, query func(v interface{}) error) error {
if _, ok := mc.vals[key]; ok {
return mc.Get(key, v)
}

if err := query(v); err != nil {
return err
}

return mc.Set(key, v)
func (mc *mockedNode) SetWithExpire(key string, val interface{}, expire time.Duration) error {
return mc.SetWithExpireCtx(context.Background(), key, val, expire)
}

func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
return mc.Take(v, key, func(v interface{}) error {
return query(v, 0)
})
func (mc *mockedNode) SetWithExpireCtx(ctx context.Context, key string, val interface{}, expire time.Duration) error {
return mc.Set(key, val)
}

func (mc *mockedNode) DelCtx(_ context.Context, keys ...string) error {
return mc.Del(keys...)
func (mc *mockedNode) Take(val interface{}, key string, query func(val interface{}) error) error {
return mc.TakeCtx(context.Background(), val, key, query)
}

func (mc *mockedNode) GetCtx(_ context.Context, key string, v interface{}) error {
return mc.Get(key, v)
}
func (mc *mockedNode) TakeCtx(ctx context.Context, val interface{}, key string, query func(val interface{}) error) error {
if _, ok := mc.vals[key]; ok {
return mc.GetCtx(ctx, key, val)
}

func (mc *mockedNode) SetCtx(_ context.Context, key string, v interface{}) error {
return mc.Set(key, v)
}
if err := query(val); err != nil {
return err
}

func (mc *mockedNode) SetWithExpireCtx(_ context.Context, key string, v interface{}, expire time.Duration) error {
return mc.SetWithExpire(key, v, expire)
return mc.SetCtx(ctx, key, val)
}

func (mc *mockedNode) TakeCtx(_ context.Context, v interface{}, key string, query func(v interface{}) error) error {
return mc.Take(v, key, query)
func (mc *mockedNode) TakeWithExpire(val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
return mc.TakeWithExpireCtx(context.Background(), val, key, query)
}

func (mc *mockedNode) TakeWithExpireCtx(_ context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
return mc.TakeWithExpire(v, key, query)
func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
return mc.Take(val, key, func(val interface{}) error {
return query(val, 0)
})
}

func TestCache_SetDel(t *testing.T) {
Expand Down Expand Up @@ -141,18 +141,18 @@ func TestCache_SetDel(t *testing.T) {
}
}
for i := 0; i < total; i++ {
var v int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &v))
assert.Equal(t, i, v)
var val int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
assert.Equal(t, i, val)
}
assert.Nil(t, c.Del())
for i := 0; i < total; i++ {
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
}
for i := 0; i < total; i++ {
var v int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &v)))
assert.Equal(t, 0, v)
var val int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
assert.Equal(t, 0, val)
}
}

Expand All @@ -179,18 +179,18 @@ func TestCache_OneNode(t *testing.T) {
}
}
for i := 0; i < total; i++ {
var v int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &v))
assert.Equal(t, i, v)
var val int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
assert.Equal(t, i, val)
}
assert.Nil(t, c.Del())
for i := 0; i < total; i++ {
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
}
for i := 0; i < total; i++ {
var v int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &v)))
assert.Equal(t, 0, v)
var val int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
assert.Equal(t, 0, val)
}
}

Expand Down Expand Up @@ -230,9 +230,9 @@ func TestCache_Balance(t *testing.T) {
assert.True(t, entropy > .95, fmt.Sprintf("entropy should be greater than 0.95, but got %.2f", entropy))

for i := 0; i < total; i++ {
var v int
assert.Nil(t, c.Get(strconv.Itoa(i), &v))
assert.Equal(t, i, v)
var val int
assert.Nil(t, c.Get(strconv.Itoa(i), &val))
assert.Equal(t, i, val)
}

for i := 0; i < total/10; i++ {
Expand All @@ -244,14 +244,14 @@ func TestCache_Balance(t *testing.T) {
for i := 0; i < total/10; i++ {
var val int
if i%2 == 0 {
assert.Nil(t, c.Take(&val, strconv.Itoa(i*10), func(v interface{}) error {
*v.(*int) = i
assert.Nil(t, c.Take(&val, strconv.Itoa(i*10), func(val interface{}) error {
*val.(*int) = i
count++
return nil
}))
} else {
assert.Nil(t, c.TakeWithExpire(&val, strconv.Itoa(i*10), func(v interface{}, expire time.Duration) error {
*v.(*int) = i
assert.Nil(t, c.TakeWithExpire(&val, strconv.Itoa(i*10), func(val interface{}, expire time.Duration) error {
*val.(*int) = i
count++
return nil
}))
Expand All @@ -272,19 +272,19 @@ func TestCacheNoNode(t *testing.T) {
assert.NotNil(t, c.Get("foo", nil))
assert.NotNil(t, c.Set("foo", nil))
assert.NotNil(t, c.SetWithExpire("foo", nil, time.Second))
assert.NotNil(t, c.Take(nil, "foo", func(v interface{}) error {
assert.NotNil(t, c.Take(nil, "foo", func(val interface{}) error {
return nil
}))
assert.NotNil(t, c.TakeWithExpire(nil, "foo", func(v interface{}, duration time.Duration) error {
assert.NotNil(t, c.TakeWithExpire(nil, "foo", func(val interface{}, duration time.Duration) error {
return nil
}))
}

func calcEntropy(m map[int]int, total int) float64 {
var entropy float64

for _, v := range m {
proba := float64(v) / float64(total)
for _, val := range m {
proba := float64(val) / float64(total)
entropy -= proba * math.Log2(proba)
}

Expand Down
Loading

0 comments on commit 2732d3c

Please sign in to comment.