@@ -183,6 +183,7 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
183
183
//------------------------------------------------------------------------------
184
184
185
185
type baseClient struct {
186
+ plugin
186
187
opt * Options
187
188
connPool pool.Pooler
188
189
@@ -264,22 +265,14 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
264
265
return cn , nil
265
266
}
266
267
267
- func (c * baseClient ) initConn (ctx context.Context , cn * pool.Conn ) error {
268
- if cn .Inited {
269
- return nil
270
- }
271
- cn .Inited = true
268
+ func (c * baseClient ) authentication (ctx context.Context , conn * Conn ) error {
269
+ var auth bool
272
270
273
271
username , password := c .opt .Username , c .opt .Password
274
272
if c .opt .CredentialsProvider != nil {
275
273
username , password = c .opt .CredentialsProvider ()
276
274
}
277
275
278
- connPool := pool .NewSingleConnPool (c .connPool , cn )
279
- conn := newConn (c .opt , connPool )
280
-
281
- var auth bool
282
-
283
276
// for redis-server versions that do not support the HELLO command,
284
277
// RESP2 will continue to be used.
285
278
if err := conn .Hello (ctx , 3 , username , password , "" ).Err (); err == nil {
@@ -295,15 +288,49 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
295
288
return err
296
289
}
297
290
298
- _ , err := conn .Pipelined (ctx , func (pipe Pipeliner ) error {
299
- if ! auth && password != "" {
300
- if username != "" {
301
- pipe .AuthACL (ctx , username , password )
302
- } else {
303
- pipe .Auth (ctx , password )
304
- }
291
+ if ! auth && password != "" {
292
+ var authErr error
293
+ if username != "" {
294
+ authErr = conn .AuthACL (ctx , username , password ).Err ()
295
+ } else {
296
+ authErr = conn .Auth (ctx , password ).Err ()
297
+ }
298
+
299
+ if authErr != nil {
300
+ return authErr
301
+ }
302
+ }
303
+
304
+ return nil
305
+ }
306
+
307
+ func (c * baseClient ) initConn (ctx context.Context , cn * pool.Conn ) error {
308
+ if cn .Inited {
309
+ return nil
310
+ }
311
+
312
+ connPool := pool .NewSingleConnPool (c .connPool , cn )
313
+ conn := newConn (c .opt , connPool , c .plugin )
314
+
315
+ for _ , p := range c .plugin .preInitConnPlugins {
316
+ if err := p (ctx , conn ); err != nil {
317
+ return err
318
+ }
319
+ }
320
+
321
+ cn .Inited = true
322
+
323
+ if c .plugin .initConnPlugin != nil {
324
+ if err := c .plugin .initConnPlugin (ctx , conn ); err != nil {
325
+ return err
326
+ }
327
+ } else {
328
+ if err := c .authentication (ctx , conn ); err != nil {
329
+ return err
305
330
}
331
+ }
306
332
333
+ _ , err := conn .Pipelined (ctx , func (pipe Pipeliner ) error {
307
334
if c .opt .DB > 0 {
308
335
pipe .Select (ctx , c .opt .DB )
309
336
}
@@ -322,6 +349,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
322
349
return err
323
350
}
324
351
352
+ for _ , p := range c .plugin .postInitConnPlugin {
353
+ if err = p (ctx , conn ); err != nil {
354
+ return err
355
+ }
356
+ }
357
+
325
358
if c .opt .OnConnect != nil {
326
359
return c .opt .OnConnect (ctx , conn )
327
360
}
@@ -631,7 +664,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
631
664
}
632
665
633
666
func (c * Client ) Conn () * Conn {
634
- return newConn (c .opt , pool .NewStickyConnPool (c .connPool ))
667
+ return newConn (c .opt , pool .NewStickyConnPool (c .connPool ), c . baseClient . plugin )
635
668
}
636
669
637
670
// Do create a Cmd from the args and processes the cmd.
@@ -767,11 +800,12 @@ type Conn struct {
767
800
hooksMixin
768
801
}
769
802
770
- func newConn (opt * Options , connPool pool.Pooler ) * Conn {
803
+ func newConn (opt * Options , connPool pool.Pooler , plugin plugin ) * Conn {
771
804
c := Conn {
772
805
baseClient : baseClient {
773
806
opt : opt ,
774
807
connPool : connPool ,
808
+ plugin : plugin ,
775
809
},
776
810
}
777
811
0 commit comments