Skip to content

Commit 6ad9c3b

Browse files
committed
feat(plugin): add conn plugin api
Signed-off-by: monkey92t <golang@88.com>
1 parent 31ba855 commit 6ad9c3b

File tree

2 files changed

+105
-19
lines changed

2 files changed

+105
-19
lines changed

conn_plugin.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package redis
2+
3+
import "context"
4+
5+
type (
6+
// PreInitConnPlugin plugin executed before connection initialization. At this point,
7+
// the network connection has been established, but Redis authentication has not yet
8+
// taken place. You can perform specific operations before the Redis authentication,
9+
// such as third-party Redis proxy authentication or executing any necessary commands.
10+
// Please note that the `HELLO` command has not been executed yet. If you invoke any Redis
11+
// commands, the default RESP version of the Redis server will be used.
12+
PreInitConnPlugin func(ctx context.Context, conn *Conn) error
13+
14+
// InitConnPlugin redis connection authentication plugin. go-redis sets a default
15+
// authentication plugin, but if you need to implement a special authentication
16+
// mechanism for your Redis server, you can use this plugin instead of the default one.
17+
// This plugin can only be set once, and if set multiple times,
18+
// only the last set plugin will be executed.
19+
InitConnPlugin func(ctx context.Context, conn *Conn) error
20+
21+
// PostInitConnPlugin Plugin executed after connection initialization. At this point,
22+
// Redis authentication has been completed, and you can execute commands related to
23+
// the connection status, such as `SELECT DB`, `CLIENT SETNAME`.
24+
PostInitConnPlugin func(ctx context.Context, conn *Conn) error
25+
)
26+
27+
// ---------------------------------------------------------------------------------------
28+
29+
type plugin struct {
30+
preInitConnPlugins []PreInitConnPlugin
31+
initConnPlugin InitConnPlugin
32+
postInitConnPlugin []PostInitConnPlugin
33+
}
34+
35+
// RegistryPreInitConnPlugin register a PreInitConnPlugin plugin, which can be registered
36+
// multiple times. It will be executed in the order of registration.
37+
func (p *plugin) RegistryPreInitConnPlugin(pre PreInitConnPlugin) {
38+
p.preInitConnPlugins = append(p.preInitConnPlugins, pre)
39+
}
40+
41+
// RegistryInitConnPlugin register an InitConnPlugin plugin, which will override the default
42+
// authentication mechanism of go-redis. If registered multiple times, only the plugin
43+
// registered last will be executed.
44+
func (p *plugin) RegistryInitConnPlugin(init InitConnPlugin) {
45+
p.initConnPlugin = init
46+
}
47+
48+
// RegistryPostInitConnPlugin register a PostInitConnPlugin plugin, which can be registered
49+
// multiple times. It will be executed in the order of registration.
50+
func (p *plugin) RegistryPostInitConnPlugin(post PostInitConnPlugin) {
51+
p.postInitConnPlugin = append(p.postInitConnPlugin, post)
52+
}

redis.go

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
183183
//------------------------------------------------------------------------------
184184

185185
type baseClient struct {
186+
plugin
186187
opt *Options
187188
connPool pool.Pooler
188189

@@ -264,22 +265,14 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
264265
return cn, nil
265266
}
266267

267-
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
268-
if cn.Inited {
269-
return nil
270-
}
271-
cn.Inited = true
268+
func (c *baseClient) authentication(ctx context.Context, conn *Conn) error {
269+
var auth bool
272270

273271
username, password := c.opt.Username, c.opt.Password
274272
if c.opt.CredentialsProvider != nil {
275273
username, password = c.opt.CredentialsProvider()
276274
}
277275

278-
connPool := pool.NewSingleConnPool(c.connPool, cn)
279-
conn := newConn(c.opt, connPool)
280-
281-
var auth bool
282-
283276
// for redis-server versions that do not support the HELLO command,
284277
// RESP2 will continue to be used.
285278
if err := conn.Hello(ctx, 3, username, password, "").Err(); err == nil {
@@ -295,15 +288,49 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
295288
return err
296289
}
297290

298-
_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
299-
if !auth && password != "" {
300-
if username != "" {
301-
pipe.AuthACL(ctx, username, password)
302-
} else {
303-
pipe.Auth(ctx, password)
304-
}
291+
if !auth && password != "" {
292+
var authErr error
293+
if username != "" {
294+
authErr = conn.AuthACL(ctx, username, password).Err()
295+
} else {
296+
authErr = conn.Auth(ctx, password).Err()
297+
}
298+
299+
if authErr != nil {
300+
return authErr
301+
}
302+
}
303+
304+
return nil
305+
}
306+
307+
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
308+
if cn.Inited {
309+
return nil
310+
}
311+
312+
connPool := pool.NewSingleConnPool(c.connPool, cn)
313+
conn := newConn(c.opt, connPool, c.plugin)
314+
315+
for _, p := range c.plugin.preInitConnPlugins {
316+
if err := p(ctx, conn); err != nil {
317+
return err
318+
}
319+
}
320+
321+
cn.Inited = true
322+
323+
if c.plugin.initConnPlugin != nil {
324+
if err := c.plugin.initConnPlugin(ctx, conn); err != nil {
325+
return err
326+
}
327+
} else {
328+
if err := c.authentication(ctx, conn); err != nil {
329+
return err
305330
}
331+
}
306332

333+
_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
307334
if c.opt.DB > 0 {
308335
pipe.Select(ctx, c.opt.DB)
309336
}
@@ -322,6 +349,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
322349
return err
323350
}
324351

352+
for _, p := range c.plugin.postInitConnPlugin {
353+
if err = p(ctx, conn); err != nil {
354+
return err
355+
}
356+
}
357+
325358
if c.opt.OnConnect != nil {
326359
return c.opt.OnConnect(ctx, conn)
327360
}
@@ -631,7 +664,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
631664
}
632665

633666
func (c *Client) Conn() *Conn {
634-
return newConn(c.opt, pool.NewStickyConnPool(c.connPool))
667+
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.baseClient.plugin)
635668
}
636669

637670
// Do create a Cmd from the args and processes the cmd.
@@ -767,11 +800,12 @@ type Conn struct {
767800
hooksMixin
768801
}
769802

770-
func newConn(opt *Options, connPool pool.Pooler) *Conn {
803+
func newConn(opt *Options, connPool pool.Pooler, plugin plugin) *Conn {
771804
c := Conn{
772805
baseClient: baseClient{
773806
opt: opt,
774807
connPool: connPool,
808+
plugin: plugin,
775809
},
776810
}
777811

0 commit comments

Comments
 (0)