Skip to content

Commit

Permalink
fix(pool): use of closed network connection
Browse files Browse the repository at this point in the history
  • Loading branch information
cococolanosugar committed Mar 7, 2024
1 parent 9c324f0 commit 304a547
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 21 deletions.
77 changes: 56 additions & 21 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,26 @@ type Adapter struct {
username string
password string
tlsConfig *tls.Config
conn redis.Conn
_conn redis.Conn
_pool *redis.Pool
isFiltered bool
}

func (a *Adapter) getConn() redis.Conn {
if a._pool != nil {
return a._pool.Get()
}
return a._conn
}

// finalizer is the destructor for Adapter.
func finalizer(a *Adapter) {
a.conn.Close()
if a._conn != nil {
a._conn.Close()
}
if a._pool != nil {
a._pool.Close()
}
}

func newAdapter(network string, address string, key string,
Expand Down Expand Up @@ -98,7 +111,24 @@ func NewAdapterWithKey(network string, address string, key string) (*Adapter, er
func NewAdapterWithPool(pool *redis.Pool) (*Adapter, error) {
a := &Adapter{}
a.key = "casbin_rules"
a.conn = pool.Get()
a._conn = pool.Get()
a._pool = pool

// Call the destructor when the object is released.
runtime.SetFinalizer(a, finalizer)

return a, nil
}

// NewAdapterWithPoolAndOptions is the constructor for Adapter.
func NewAdapterWithPoolAndOptions(pool *redis.Pool, options ...Option) (*Adapter, error) {
a := &Adapter{}
a.key = "casbin_rules"
for _, option := range options {
option(a)
}
a._conn = pool.Get()
a._pool = pool

// Call the destructor when the object is released.
runtime.SetFinalizer(a, finalizer)
Expand Down Expand Up @@ -166,34 +196,39 @@ func (a *Adapter) open() error {
return err
}

a.conn = conn
a._conn = conn
} else if a.password == "" {
conn, err := redis.Dial(a.network, a.address, redis.DialTLSConfig(a.tlsConfig), redis.DialUseTLS(useTls))
if err != nil {
return err
}

a.conn = conn
a._conn = conn
} else {
conn, err := redis.Dial(a.network, a.address, redis.DialPassword(a.password), redis.DialTLSConfig(a.tlsConfig), redis.DialUseTLS(useTls))
if err != nil {
return err
}

a.conn = conn
a._conn = conn
}
return nil
}

func (a *Adapter) close() {
a.conn.Close()
if a._conn != nil {
a._conn.Close()
}
if a._pool != nil {
a._pool.Close()
}
}

func (a *Adapter) createTable() {
}

func (a *Adapter) dropTable() {
_, _ = a.conn.Do("DEL", a.key)
_, _ = a.getConn().Do("DEL", a.key)
}

func (c *CasbinRule) toStringPolicy() []string {
Expand Down Expand Up @@ -230,14 +265,14 @@ func loadPolicyLine(line CasbinRule, model model.Model) {

// LoadPolicy loads policy from database.
func (a *Adapter) LoadPolicy(model model.Model) error {
num, err := redis.Int(a.conn.Do("LLEN", a.key))
num, err := redis.Int(a.getConn().Do("LLEN", a.key))
if err == redis.ErrNil {
return nil
}
if err != nil {
return err
}
values, err := redis.Values(a.conn.Do("LRANGE", a.key, 0, num))
values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num))
if err != nil {
return err
}
Expand Down Expand Up @@ -314,7 +349,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
}
}

_, err := a.conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)
_, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)
return err
}

Expand All @@ -325,7 +360,7 @@ func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
if err != nil {
return err
}
_, err = a.conn.Do("RPUSH", a.key, text)
_, err = a.getConn().Do("RPUSH", a.key, text)
return err
}

Expand All @@ -336,7 +371,7 @@ func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
if err != nil {
return err
}
_, err = a.conn.Do("LREM", a.key, 1, text)
_, err = a.getConn().Do("LREM", a.key, 1, text)
return err
}

Expand All @@ -351,7 +386,7 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error
}
texts = append(texts, text)
}
_, err := a.conn.Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)
_, err := a.getConn().Do("RPUSH", redis.Args{}.Add(a.key).AddFlat(texts)...)
return err
}

Expand All @@ -363,7 +398,7 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err
if err != nil {
return err
}
_, err = a.conn.Do("LREM", a.key, 1, text)
_, err = a.getConn().Do("LREM", a.key, 1, text)
if err != nil {
return err
}
Expand Down Expand Up @@ -449,14 +484,14 @@ func filterFieldToLuaPattern(sec string, ptype string, fieldIndex int, fieldValu
}

func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter) error {
num, err := redis.Int(a.conn.Do("LLEN", a.key))
num, err := redis.Int(a.getConn().Do("LLEN", a.key))
if err == redis.ErrNil {
return nil
}
if err != nil {
return err
}
values, err := redis.Values(a.conn.Do("LRANGE", a.key, 0, num))
values, err := redis.Values(a.getConn().Do("LRANGE", a.key, 0, num))
if err != nil {
return err
}
Expand Down Expand Up @@ -524,7 +559,7 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
redis.call('lrem', key, 0, '__CASBIN_DELETED__')
return
`)
_, err := getScript.Do(a.conn, a.key, pattern)
_, err := getScript.Do(a.getConn(), a.key, pattern)
return err
}

Expand Down Expand Up @@ -557,7 +592,7 @@ func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []st
end
return false
`)
_, err = getScript.Do(a.conn, a.key, textOld, textNew)
_, err = getScript.Do(a.getConn(), a.key, textOld, textNew)
return err
}

Expand Down Expand Up @@ -605,7 +640,7 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules []
return false
`)
args := redis.Args{}.Add(a.key).AddFlat(oldPolicies).AddFlat(newPolicies)
_, err := getScript.Do(a.conn, args...)
_, err := getScript.Do(a.getConn(), args...)
return err
}

Expand Down Expand Up @@ -649,7 +684,7 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
args := redis.Args{}.Add(a.key).Add(pattern).AddFlat(newP)
//r, err := getScript.Do(a.conn, args...)
//reply, err := redis.Values(r, err)
reply, err := redis.Values(getScript.Do(a.conn, args...))
reply, err := redis.Values(getScript.Do(a.getConn(), args...))
if err != nil {
return nil, err
}
Expand Down
19 changes: 19 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,22 @@ func TestPoolAdapters(t *testing.T) {
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)
}

func TestPoolAndOptionsAdapters(t *testing.T) {
a, err := NewAdapterWithPoolAndOptions(&redis.Pool{
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", "127.0.0.1:6379")
},
}, WithKey("casbin:policy:test"))
if err != nil {
t.Fatal(err)
}

testSaveLoad(t, a)
testAutoSave(t, a)
testFilteredPolicy(t, a)
testAddPolicies(t, a)
testRemovePolicies(t, a)
testUpdatePolicies(t, a)
testUpdateFilteredPolicies(t, a)
}

0 comments on commit 304a547

Please sign in to comment.