Skip to content

Commit

Permalink
Add Options.OnConnect
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed May 26, 2017
1 parent 7e8890b commit 4a3a300
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 14 deletions.
5 changes: 4 additions & 1 deletion cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type ClusterOptions struct {

// Following options are copied from Options struct.

OnConnect func(*Conn) error

MaxRetries int
Password string

Expand Down Expand Up @@ -65,6 +67,8 @@ func (opt *ClusterOptions) clientOptions() *Options {
const disableIdleCheck = -1

return &Options{
OnConnect: opt.OnConnect,

MaxRetries: opt.MaxRetries,
Password: opt.Password,
ReadOnly: opt.ReadOnly,
Expand All @@ -77,7 +81,6 @@ func (opt *ClusterOptions) clientOptions() *Options {
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,

// IdleCheckFrequency is not copied to disable reaper
IdleCheckFrequency: disableIdleCheck,
}
}
Expand Down
4 changes: 2 additions & 2 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Cmdable interface {
Pipeline() Pipeliner
Pipelined(fn func(Pipeliner) error) ([]Cmder, error)

ClientGetName() *StringCmd
Echo(message interface{}) *StringCmd
Ping() *StatusCmd
Quit() *StatusCmd
Expand Down Expand Up @@ -242,7 +243,6 @@ type StatefulCmdable interface {
Auth(password string) *StatusCmd
Select(index int) *StatusCmd
ClientSetName(name string) *BoolCmd
ClientGetName() *StringCmd
ReadOnly() *StatusCmd
ReadWrite() *StatusCmd
}
Expand Down Expand Up @@ -1649,7 +1649,7 @@ func (c *statefulCmdable) ClientSetName(name string) *BoolCmd {
}

// ClientGetName returns the name of the connection.
func (c *statefulCmdable) ClientGetName() *StringCmd {
func (c *cmdable) ClientGetName() *StringCmd {
cmd := NewStringCmd("client", "getname")
c.process(cmd)
return cmd
Expand Down
3 changes: 3 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ type Options struct {
// Network and Addr options.
Dialer func() (net.Conn, error)

// Hook that is called when new connection is established.
OnConnect func(*Conn) error

// Optional password. Must match the password specified in the
// requirepass server configuration option.
Password string
Expand Down
72 changes: 61 additions & 11 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}

// Options returns read-only Options that were used to create the client.
func (c *baseClient) Options() *Options {
return c.opt
}

func (c *baseClient) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.connPool.Get()
if err != nil {
Expand Down Expand Up @@ -55,13 +50,23 @@ func (c *baseClient) putConn(cn *pool.Conn, err error) bool {
func (c *baseClient) initConn(cn *pool.Conn) error {
cn.Inited = true

if c.opt.Password == "" && c.opt.DB == 0 && !c.opt.ReadOnly {
if c.opt.Password == "" &&
c.opt.DB == 0 &&
!c.opt.ReadOnly &&
c.opt.OnConnect == nil {
return nil
}

// Temp client for Auth and Select.
client := newClient(c.opt, pool.NewSingleConnPool(cn))
_, err := client.Pipelined(func(pipe Pipeliner) error {
// Temp client to initialize connection.
conn := &Conn{
baseClient: baseClient{
opt: c.opt,
connPool: pool.NewSingleConnPool(cn),
},
}
conn.setProcessor(conn.Process)

_, err := conn.Pipelined(func(pipe Pipeliner) error {
if c.opt.Password != "" {
pipe.Auth(c.opt.Password)
}
Expand All @@ -76,7 +81,14 @@ func (c *baseClient) initConn(cn *pool.Conn) error {

return nil
})
return err
if err != nil {
return err
}

if c.opt.OnConnect != nil {
return c.opt.OnConnect(conn)
}
return nil
}

func (c *baseClient) Process(cmd Cmder) error {
Expand Down Expand Up @@ -182,7 +194,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer {
}
}

func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
Expand Down Expand Up @@ -311,6 +323,11 @@ func (c *Client) copy() *Client {
return c2
}

// Options returns read-only Options that were used to create the client.
func (c *Client) Options() *Options {
return c.opt
}

// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
s := c.connPool.Stats()
Expand Down Expand Up @@ -375,3 +392,36 @@ func (c *Client) PSubscribe(channels ...string) *PubSub {
}
return pubsub
}

//------------------------------------------------------------------------------

// Conn is like Client, but its pool contains single connection.
type Conn struct {
baseClient
statefulCmdable
}

func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}

func (c *Conn) Pipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.pipelineProcessCmds),
}
pipe.setProcessor(pipe.Process)
return &pipe
}

func (c *Conn) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().pipelined(fn)
}

// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Conn) TxPipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.txPipelineProcessCmds),
}
pipe.setProcessor(pipe.Process)
return &pipe
}
23 changes: 23 additions & 0 deletions redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,26 @@ var _ = Describe("Client timeout", func() {
testTimeout()
})
})

var _ = Describe("Client OnConnect", func() {
var client *redis.Client

BeforeEach(func() {
opt := redisOptions()
opt.OnConnect = func(cn *redis.Conn) error {
return cn.ClientSetName("on_connect").Err()
}

client = redis.NewClient(opt)
})

AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})

It("calls OnConnect", func() {
name, err := client.ClientGetName().Result()
Expect(err).NotTo(HaveOccurred())
Expect(name).To(Equal("on_connect"))
})
})
4 changes: 4 additions & 0 deletions ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ type RingOptions struct {

// Following options are copied from Options struct.

OnConnect func(*Conn) error

DB int
Password string

Expand All @@ -52,6 +54,8 @@ func (opt *RingOptions) init() {

func (opt *RingOptions) clientOptions() *Options {
return &Options{
OnConnect: opt.OnConnect,

DB: opt.DB,
Password: opt.Password,

Expand Down
4 changes: 4 additions & 0 deletions sentinel.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type FailoverOptions struct {

// Following options are copied from Options struct.

OnConnect func(*Conn) error

Password string
DB int

Expand All @@ -42,6 +44,8 @@ func (opt *FailoverOptions) options() *Options {
return &Options{
Addr: "FailoverClient",

OnConnect: opt.OnConnect,

DB: opt.DB,
Password: opt.Password,

Expand Down

0 comments on commit 4a3a300

Please sign in to comment.