From 99bb602106ab54ea4a7acd317e0932f76177e656 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 25 May 2017 13:38:04 +0300 Subject: [PATCH 1/2] Embed Cmdable into StatefulCmdable --- cluster.go | 8 +++----- commands.go | 12 +++++++++++- pipeline.go | 2 -- redis.go | 10 ++++------ ring.go | 5 ++--- sentinel.go | 2 +- tx.go | 7 ++----- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/cluster.go b/cluster.go index 99d8d7da61..6a4bbe8f17 100644 --- a/cluster.go +++ b/cluster.go @@ -349,7 +349,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { opt: opt, nodes: newClusterNodes(opt), } - c.cmdable.process = c.Process + c.setProcessor(c.Process) // Add initial nodes. for _, addr := range opt.Addrs { @@ -678,8 +678,7 @@ func (c *ClusterClient) Pipeline() Pipeliner { pipe := Pipeline{ exec: c.pipelineExec, } - pipe.cmdable.process = pipe.Process - pipe.statefulCmdable.process = pipe.Process + pipe.setProcessor(pipe.Process) return &pipe } @@ -801,8 +800,7 @@ func (c *ClusterClient) TxPipeline() Pipeliner { pipe := Pipeline{ exec: c.txPipelineExec, } - pipe.cmdable.process = pipe.Process - pipe.statefulCmdable.process = pipe.Process + pipe.setProcessor(pipe.Process) return &pipe } diff --git a/commands.go b/commands.go index a0984321bd..51dcf941fa 100644 --- a/commands.go +++ b/commands.go @@ -238,6 +238,7 @@ type Cmdable interface { } type StatefulCmdable interface { + Cmdable Auth(password string) *StatusCmd Select(index int) *StatusCmd ClientSetName(name string) *BoolCmd @@ -255,10 +256,20 @@ type cmdable struct { process func(cmd Cmder) error } +func (c *cmdable) setProcessor(fn func(Cmder) error) { + c.process = fn +} + type statefulCmdable struct { + cmdable process func(cmd Cmder) error } +func (c *statefulCmdable) setProcessor(fn func(Cmder) error) { + c.process = fn + c.cmdable.setProcessor(fn) +} + //------------------------------------------------------------------------------ func (c *statefulCmdable) Auth(password string) *StatusCmd { @@ -280,7 +291,6 @@ func (c *cmdable) Ping() *StatusCmd { } func (c *cmdable) Wait(numSlaves int, timeout time.Duration) *IntCmd { - cmd := NewIntCmd("wait", numSlaves, int(timeout/time.Millisecond)) c.process(cmd) return cmd diff --git a/pipeline.go b/pipeline.go index 977f5eb3dc..de99f12459 100644 --- a/pipeline.go +++ b/pipeline.go @@ -10,7 +10,6 @@ import ( type pipelineExecer func([]Cmder) error type Pipeliner interface { - Cmdable StatefulCmdable Process(cmd Cmder) error Close() error @@ -26,7 +25,6 @@ var _ Pipeliner = (*Pipeline)(nil) // http://redis.io/topics/pipelining. It's safe for concurrent use // by multiple goroutines. type Pipeline struct { - cmdable statefulCmdable exec pipelineExecer diff --git a/redis.go b/redis.go index ca88df0d16..89f985ee7d 100644 --- a/redis.go +++ b/redis.go @@ -294,7 +294,7 @@ func newClient(opt *Options, pool pool.Pooler) *Client { connPool: pool, }, } - client.cmdable.process = client.Process + client.setProcessor(client.Process) return &client } @@ -307,7 +307,7 @@ func NewClient(opt *Options) *Client { func (c *Client) copy() *Client { c2 := new(Client) *c2 = *c - c2.cmdable.process = c2.Process + c2.setProcessor(c2.Process) return c2 } @@ -332,8 +332,7 @@ func (c *Client) Pipeline() Pipeliner { pipe := Pipeline{ exec: c.pipelineExecer(c.pipelineProcessCmds), } - pipe.cmdable.process = pipe.Process - pipe.statefulCmdable.process = pipe.Process + pipe.setProcessor(pipe.Process) return &pipe } @@ -346,8 +345,7 @@ func (c *Client) TxPipeline() Pipeliner { pipe := Pipeline{ exec: c.pipelineExecer(c.txPipelineProcessCmds), } - pipe.cmdable.process = pipe.Process - pipe.statefulCmdable.process = pipe.Process + pipe.setProcessor(pipe.Process) return &pipe } diff --git a/ring.go b/ring.go index 270a81f978..9c57430d02 100644 --- a/ring.go +++ b/ring.go @@ -148,7 +148,7 @@ func NewRing(opt *RingOptions) *Ring { cmdsInfoOnce: new(sync.Once), } - ring.cmdable.process = ring.Process + ring.setProcessor(ring.Process) for name, addr := range opt.Addrs { clopt := opt.clientOptions() clopt.Addr = addr @@ -385,8 +385,7 @@ func (c *Ring) Pipeline() Pipeliner { pipe := Pipeline{ exec: c.pipelineExec, } - pipe.cmdable.process = pipe.Process - pipe.statefulCmdable.process = pipe.Process + pipe.setProcessor(pipe.Process) return &pipe } diff --git a/sentinel.go b/sentinel.go index 799f530fca..da3a4312bb 100644 --- a/sentinel.go +++ b/sentinel.go @@ -82,7 +82,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { }, }, } - client.cmdable.process = client.Process + client.setProcessor(client.Process) return &client } diff --git a/tx.go b/tx.go index 21c5c70f5e..5ef89619ba 100644 --- a/tx.go +++ b/tx.go @@ -13,7 +13,6 @@ const TxFailedErr = internal.RedisError("redis: transaction failed") // by multiple goroutines, because Exec resets list of watched keys. // If you don't need WATCH it is better to use Pipeline. type Tx struct { - cmdable statefulCmdable baseClient } @@ -25,8 +24,7 @@ func (c *Client) newTx() *Tx { connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), }, } - tx.cmdable.process = tx.Process - tx.statefulCmdable.process = tx.Process + tx.setProcessor(tx.Process) return &tx } @@ -80,8 +78,7 @@ func (c *Tx) Pipeline() Pipeliner { pipe := Pipeline{ exec: c.pipelineExecer(c.txPipelineProcessCmds), } - pipe.cmdable.process = pipe.Process - pipe.statefulCmdable.process = pipe.Process + pipe.setProcessor(pipe.Process) return &pipe } From 63cdc6182c05d9d05fb0d3cff38d18cb304df2d3 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 25 May 2017 14:16:39 +0300 Subject: [PATCH 2/2] Add Options.OnConnect --- cluster.go | 5 +++- commands.go | 4 +-- options.go | 3 +++ redis.go | 72 +++++++++++++++++++++++++++++++++++++++++++-------- redis_test.go | 23 ++++++++++++++++ ring.go | 4 +++ sentinel.go | 4 +++ 7 files changed, 101 insertions(+), 14 deletions(-) diff --git a/cluster.go b/cluster.go index 6a4bbe8f17..e3c5832fcc 100644 --- a/cluster.go +++ b/cluster.go @@ -35,6 +35,8 @@ type ClusterOptions struct { // Following options are copied from Options struct. + OnConnect func(*Conn) error + MaxRetries int Password string @@ -65,6 +67,8 @@ func (opt *ClusterOptions) clientOptions() *Options { const disableIdleCheck = -1 return &Options{ + OnConnect: opt.OnConnect, + MaxRetries: opt.MaxRetries, Password: opt.Password, ReadOnly: opt.ReadOnly, @@ -77,7 +81,6 @@ func (opt *ClusterOptions) clientOptions() *Options { PoolTimeout: opt.PoolTimeout, IdleTimeout: opt.IdleTimeout, - // IdleCheckFrequency is not copied to disable reaper IdleCheckFrequency: disableIdleCheck, } } diff --git a/commands.go b/commands.go index 51dcf941fa..3956cf74d8 100644 --- a/commands.go +++ b/commands.go @@ -42,6 +42,7 @@ type Cmdable interface { Pipeline() Pipeliner Pipelined(fn func(Pipeliner) error) ([]Cmder, error) + ClientGetName() *StringCmd Echo(message interface{}) *StringCmd Ping() *StatusCmd Quit() *StatusCmd @@ -242,7 +243,6 @@ type StatefulCmdable interface { Auth(password string) *StatusCmd Select(index int) *StatusCmd ClientSetName(name string) *BoolCmd - ClientGetName() *StringCmd ReadOnly() *StatusCmd ReadWrite() *StatusCmd } @@ -1649,7 +1649,7 @@ func (c *statefulCmdable) ClientSetName(name string) *BoolCmd { } // ClientGetName returns the name of the connection. -func (c *statefulCmdable) ClientGetName() *StringCmd { +func (c *cmdable) ClientGetName() *StringCmd { cmd := NewStringCmd("client", "getname") c.process(cmd) return cmd diff --git a/options.go b/options.go index d2aefb4755..1695c0b84c 100644 --- a/options.go +++ b/options.go @@ -24,6 +24,9 @@ type Options struct { // Network and Addr options. Dialer func() (net.Conn, error) + // Hook that is called when new connection is established. + OnConnect func(*Conn) error + // Optional password. Must match the password specified in the // requirepass server configuration option. Password string diff --git a/redis.go b/redis.go index 89f985ee7d..303877fcdb 100644 --- a/redis.go +++ b/redis.go @@ -21,11 +21,6 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } -// Options returns read-only Options that were used to create the client. -func (c *baseClient) Options() *Options { - return c.opt -} - func (c *baseClient) conn() (*pool.Conn, bool, error) { cn, isNew, err := c.connPool.Get() if err != nil { @@ -55,13 +50,23 @@ func (c *baseClient) putConn(cn *pool.Conn, err error) bool { func (c *baseClient) initConn(cn *pool.Conn) error { cn.Inited = true - if c.opt.Password == "" && c.opt.DB == 0 && !c.opt.ReadOnly { + if c.opt.Password == "" && + c.opt.DB == 0 && + !c.opt.ReadOnly && + c.opt.OnConnect == nil { return nil } - // Temp client for Auth and Select. - client := newClient(c.opt, pool.NewSingleConnPool(cn)) - _, err := client.Pipelined(func(pipe Pipeliner) error { + // Temp client to initialize connection. + conn := &Conn{ + baseClient: baseClient{ + opt: c.opt, + connPool: pool.NewSingleConnPool(cn), + }, + } + conn.setProcessor(conn.Process) + + _, err := conn.Pipelined(func(pipe Pipeliner) error { if c.opt.Password != "" { pipe.Auth(c.opt.Password) } @@ -76,7 +81,14 @@ func (c *baseClient) initConn(cn *pool.Conn) error { return nil }) - return err + if err != nil { + return err + } + + if c.opt.OnConnect != nil { + return c.opt.OnConnect(conn) + } + return nil } func (c *baseClient) Process(cmd Cmder) error { @@ -182,7 +194,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer { } } -func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { +func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmds...); err != nil { setCmdsErr(cmds, err) @@ -311,6 +323,11 @@ func (c *Client) copy() *Client { return c2 } +// Options returns read-only Options that were used to create the client. +func (c *Client) Options() *Options { + return c.opt +} + // PoolStats returns connection pool stats. func (c *Client) PoolStats() *PoolStats { s := c.connPool.Stats() @@ -375,3 +392,36 @@ func (c *Client) PSubscribe(channels ...string) *PubSub { } return pubsub } + +//------------------------------------------------------------------------------ + +// Conn is like Client, but its pool contains single connection. +type Conn struct { + baseClient + statefulCmdable +} + +func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { + return c.Pipeline().pipelined(fn) +} + +func (c *Conn) Pipeline() Pipeliner { + pipe := Pipeline{ + exec: c.pipelineExecer(c.pipelineProcessCmds), + } + pipe.setProcessor(pipe.Process) + return &pipe +} + +func (c *Conn) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) { + return c.TxPipeline().pipelined(fn) +} + +// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. +func (c *Conn) TxPipeline() Pipeliner { + pipe := Pipeline{ + exec: c.pipelineExecer(c.txPipelineProcessCmds), + } + pipe.setProcessor(pipe.Process) + return &pipe +} diff --git a/redis_test.go b/redis_test.go index a27e3bc14a..407d378406 100644 --- a/redis_test.go +++ b/redis_test.go @@ -296,3 +296,26 @@ var _ = Describe("Client timeout", func() { testTimeout() }) }) + +var _ = Describe("Client OnConnect", func() { + var client *redis.Client + + BeforeEach(func() { + opt := redisOptions() + opt.OnConnect = func(cn *redis.Conn) error { + return cn.ClientSetName("on_connect").Err() + } + + client = redis.NewClient(opt) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("calls OnConnect", func() { + name, err := client.ClientGetName().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(name).To(Equal("on_connect")) + }) +}) diff --git a/ring.go b/ring.go index 9c57430d02..a9666bc72f 100644 --- a/ring.go +++ b/ring.go @@ -29,6 +29,8 @@ type RingOptions struct { // Following options are copied from Options struct. + OnConnect func(*Conn) error + DB int Password string @@ -52,6 +54,8 @@ func (opt *RingOptions) init() { func (opt *RingOptions) clientOptions() *Options { return &Options{ + OnConnect: opt.OnConnect, + DB: opt.DB, Password: opt.Password, diff --git a/sentinel.go b/sentinel.go index da3a4312bb..b28c3706ef 100644 --- a/sentinel.go +++ b/sentinel.go @@ -23,6 +23,8 @@ type FailoverOptions struct { // Following options are copied from Options struct. + OnConnect func(*Conn) error + Password string DB int @@ -42,6 +44,8 @@ func (opt *FailoverOptions) options() *Options { return &Options{ Addr: "FailoverClient", + OnConnect: opt.OnConnect, + DB: opt.DB, Password: opt.Password,