From 2f9f7cf9528b1bfbed32b00ab9424657f67adfee Mon Sep 17 00:00:00 2001 From: George Date: Fri, 8 Mar 2024 13:42:15 +0800 Subject: [PATCH] fix(pool): connection pool exhausted (#44) --- adapter.go | 91 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 16 deletions(-) diff --git a/adapter.go b/adapter.go index 8d7d359..56a198a 100644 --- a/adapter.go +++ b/adapter.go @@ -60,6 +60,14 @@ func (a *Adapter) getConn() redis.Conn { return a._conn } +func (a *Adapter) release(conn redis.Conn) { + if a._pool != nil { + if conn != nil { + conn.Close() + } + } +} + // finalizer is the destructor for Adapter. func finalizer(a *Adapter) { if a._conn != nil { @@ -111,7 +119,11 @@ func NewAdapterWithKey(network string, address string, key string) (*Adapter, er func NewAdapterWithPool(pool *redis.Pool) (*Adapter, error) { a := &Adapter{} a.key = "casbin_rules" - a._conn = pool.Get() + + conn := pool.Get() + defer a.release(conn) + + a._conn = conn a._pool = pool // Call the destructor when the object is released. @@ -127,7 +139,11 @@ func NewAdapterWithPoolAndOptions(pool *redis.Pool, options ...Option) (*Adapter for _, option := range options { option(a) } - a._conn = pool.Get() + + conn := pool.Get() + defer a.release(conn) + + a._conn = conn a._pool = pool // Call the destructor when the object is released. @@ -228,7 +244,10 @@ func (a *Adapter) createTable() { } func (a *Adapter) dropTable() { - _, _ = a.getConn().Do("DEL", a.key) + conn := a.getConn() + defer a.release(conn) + + _, _ = conn.Do("DEL", a.key) } func (c *CasbinRule) toStringPolicy() []string { @@ -265,14 +284,17 @@ func loadPolicyLine(line CasbinRule, model model.Model) { // LoadPolicy loads policy from database. func (a *Adapter) LoadPolicy(model model.Model) error { - num, err := redis.Int(a.getConn().Do("LLEN", a.key)) + conn := a.getConn() + defer a.release(conn) + + num, err := redis.Int(conn.Do("LLEN", a.key)) if err == redis.ErrNil { return nil } if err != nil { return err } - values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num)) + values, err := redis.Values(conn.Do("LRANGE", a.key, 0, num)) if err != nil { return err } @@ -349,7 +371,10 @@ func (a *Adapter) SavePolicy(model model.Model) error { } } - _, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) + conn := a.getConn() + defer a.release(conn) + + _, err := conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) return err } @@ -360,7 +385,11 @@ func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { if err != nil { return err } - _, err = a.getConn().Do("RPUSH", a.key, text) + + conn := a.getConn() + defer a.release(conn) + + _, err = conn.Do("RPUSH", a.key, text) return err } @@ -371,7 +400,11 @@ func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { if err != nil { return err } - _, err = a.getConn().Do("LREM", a.key, 1, text) + + conn := a.getConn() + defer a.release(conn) + + _, err = conn.Do("LREM", a.key, 1, text) return err } @@ -386,19 +419,26 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error } texts = append(texts, text) } - _, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) + + conn := a.getConn() + defer a.release(conn) + + _, err := conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...) return err } // RemovePolicies removes policy rules from the storage. func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { + conn := a.getConn() + defer a.release(conn) + for _, rule := range rules { line := savePolicyLine(ptype, rule) text, err := json.Marshal(line) if err != nil { return err } - _, err = a.getConn().Do("LREM", a.key, 1, text) + _, err = conn.Do("LREM", a.key, 1, text) if err != nil { return err } @@ -484,14 +524,17 @@ func filterFieldToLuaPattern(sec string, ptype string, fieldIndex int, fieldValu } func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter) error { - num, err := redis.Int(a.getConn().Do("LLEN", a.key)) + conn := a.getConn() + defer a.release(conn) + + num, err := redis.Int(conn.Do("LLEN", a.key)) if err == redis.ErrNil { return nil } if err != nil { return err } - values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num)) + values, err := redis.Values(conn.Do("LRANGE", a.key, 0, num)) if err != nil { return err } @@ -559,7 +602,11 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, redis.call('lrem', key, 0, '__CASBIN_DELETED__') return `) - _, err := getScript.Do(a.getConn(), a.key, pattern) + + conn := a.getConn() + defer a.release(conn) + + _, err := getScript.Do(conn, a.key, pattern) return err } @@ -592,7 +639,11 @@ func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []st end return false `) - _, err = getScript.Do(a.getConn(), a.key, textOld, textNew) + + conn := a.getConn() + defer a.release(conn) + + _, err = getScript.Do(conn, a.key, textOld, textNew) return err } @@ -640,7 +691,11 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [] return false `) args := redis.Args{}.Add(a.key).AddFlat(oldPolicies).AddFlat(newPolicies) - _, err := getScript.Do(a.getConn(), args...) + + conn := a.getConn() + defer a.release(conn) + + _, err := getScript.Do(conn, args...) return err } @@ -684,7 +739,11 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [ args := redis.Args{}.Add(a.key).Add(pattern).AddFlat(newP) //r, err := getScript.Do(a.conn, args...) //reply, err := redis.Values(r, err) - reply, err := redis.Values(getScript.Do(a.getConn(), args...)) + + conn := a.getConn() + defer a.release(conn) + + reply, err := redis.Values(getScript.Do(conn, args...)) if err != nil { return nil, err }