Skip to content

Commit 8481b33

Browse files
committed
dial hook
1 parent df43b28 commit 8481b33

File tree

7 files changed

+62
-31
lines changed

7 files changed

+62
-31
lines changed

example_instrumentation_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package redis_test
33
import (
44
"context"
55
"fmt"
6+
"net"
67

78
"github.com/go-redis/redis/v9"
89
)
@@ -12,8 +13,8 @@ type redisHook struct{}
1213
var _ redis.Hook = redisHook{}
1314

1415
func (redisHook) DialHook(hook redis.DialHook) redis.DialHook {
15-
return func(ctx context.Context) error {
16-
return hook(ctx)
16+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
17+
return hook(ctx, network, addr)
1718
}
1819
}
1920

extra/redisotel/metrics.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package redisotel
22

33
import (
44
"context"
5+
"net"
56
"time"
67

78
"github.com/go-redis/redis/v9"
@@ -75,11 +76,11 @@ type metricsHook struct {
7576
var _ redis.Hook = (*metricsHook)(nil)
7677

7778
func (mh *metricsHook) DialHook(hook redis.DialHook) redis.DialHook {
78-
return func(ctx context.Context) error {
79+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
7980
start := time.Now()
80-
err := hook(ctx)
81+
conn, err := hook(ctx, network, addr)
8182
mh.createTime.Record(ctx, time.Since(start).Microseconds())
82-
return err
83+
return conn, err
8384
}
8485
}
8586

extra/redisotel/tracing.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package redisotel
22

33
import (
44
"context"
5+
"net"
56

67
"go.opentelemetry.io/otel/attribute"
78
"go.opentelemetry.io/otel/codes"
@@ -41,19 +42,26 @@ func NewTracingHook(opts ...Option) *TracingHook {
4142
}
4243

4344
func (th *TracingHook) DialHook(hook redis.DialHook) redis.DialHook {
44-
return func(ctx context.Context) error {
45+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
4546
if !trace.SpanFromContext(ctx).IsRecording() {
46-
return hook(ctx)
47+
return hook(ctx, network, addr)
4748
}
4849

49-
ctx, span := th.conf.tracer.Start(ctx, "redis.dial", th.spanOpts...)
50+
spanOpts := th.spanOpts
51+
spanOpts = append(spanOpts, trace.WithAttributes(
52+
attribute.String("network", network),
53+
attribute.String("addr", addr),
54+
))
55+
56+
ctx, span := th.conf.tracer.Start(ctx, "redis.dial", spanOpts...)
5057
defer span.End()
5158

52-
if err := hook(ctx); err != nil {
59+
conn, err := hook(ctx, network, addr)
60+
if err != nil {
5361
recordError(ctx, span, err)
54-
return err
62+
return nil, err
5563
}
56-
return nil
64+
return conn, nil
5765
}
5866
}
5967

options.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,13 @@ func getUserPassword(u *url.URL) (string, string) {
470470
return user, password
471471
}
472472

473-
func newConnPool(opt *Options) *pool.ConnPool {
473+
func newConnPool(
474+
opt *Options,
475+
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
476+
) *pool.ConnPool {
474477
return pool.NewConnPool(&pool.Options{
475478
Dialer: func(ctx context.Context) (net.Conn, error) {
476-
return opt.Dialer(ctx, opt.Network, opt.Addr)
479+
return dialer(ctx, opt.Network, opt.Addr)
477480
},
478481
PoolFIFO: opt.PoolFIFO,
479482
PoolSize: opt.PoolSize,

redis.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"net"
78
"strings"
89
"sync/atomic"
910
"time"
@@ -30,7 +31,7 @@ type Hook interface {
3031
}
3132

3233
type (
33-
DialHook func(ctx context.Context) error
34+
DialHook func(ctx context.Context, network, addr string) (net.Conn, error)
3435
ProcessHook func(ctx context.Context, cmd Cmder) error
3536
ProcessPipelineHook func(ctx context.Context, cmds []Cmder) error
3637
)
@@ -68,6 +69,15 @@ func (hs *hooks) clone() hooks {
6869
return clone
6970
}
7071

72+
func (hs *hooks) setDial(dial DialHook) {
73+
hs.dial = dial
74+
for _, h := range hs.slice {
75+
if wrapped := h.DialHook(hs.dial); wrapped != nil {
76+
hs.dial = wrapped
77+
}
78+
}
79+
}
80+
7181
func (hs *hooks) setProcess(process ProcessHook) {
7282
hs.process = process
7383
for _, h := range hs.slice {
@@ -124,13 +134,6 @@ type baseClient struct {
124134
onClose func() error // hook called when client is closed
125135
}
126136

127-
func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
128-
return &baseClient{
129-
opt: opt,
130-
connPool: connPool,
131-
}
132-
}
133-
134137
func (c *baseClient) clone() *baseClient {
135138
clone := *c
136139
return &clone
@@ -286,6 +289,10 @@ func (c *baseClient) withConn(
286289
return fn(ctx, cn)
287290
}
288291

292+
func (c *baseClient) dial(ctx context.Context, network, addr string) (net.Conn, error) {
293+
return c.opt.Dialer(ctx, network, addr)
294+
}
295+
289296
func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
290297
var lastErr error
291298
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
@@ -527,15 +534,19 @@ func NewClient(opt *Options) *Client {
527534
opt.init()
528535

529536
c := Client{
530-
baseClient: newBaseClient(opt, newConnPool(opt)),
537+
baseClient: &baseClient{
538+
opt: opt,
539+
},
531540
}
541+
c.connPool = newConnPool(opt, c.baseClient.dial)
532542
c.init()
533543

534544
return &c
535545
}
536546

537547
func (c *Client) init() {
538548
c.cmdable = c.Process
549+
c.hooks.setDial(c.baseClient.dial)
539550
c.hooks.setProcess(c.baseClient.process)
540551
c.hooks.setProcessPipeline(c.baseClient.processPipeline)
541552
c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline)
@@ -696,6 +707,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
696707
c.cmdable = c.Process
697708
c.statefulCmdable = c.Process
698709

710+
c.hooks.setDial(c.baseClient.dial)
699711
c.hooks.setProcess(c.baseClient.process)
700712
c.hooks.setProcessPipeline(c.baseClient.processPipeline)
701713
c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline)

sentinel.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,17 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
198198
opt.Dialer = masterReplicaDialer(failover)
199199
opt.init()
200200

201-
connPool := newConnPool(opt)
201+
var connPool *pool.ConnPool
202+
203+
rdb := &Client{
204+
baseClient: &baseClient{
205+
opt: opt,
206+
},
207+
}
208+
connPool = newConnPool(opt, rdb.baseClient.dial)
209+
rdb.connPool = connPool
210+
rdb.onClose = failover.Close
211+
rdb.init()
202212

203213
failover.mu.Lock()
204214
failover.onFailover = func(ctx context.Context, addr string) {
@@ -208,12 +218,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
208218
}
209219
failover.mu.Unlock()
210220

211-
rdb := &Client{
212-
baseClient: newBaseClient(opt, connPool),
213-
}
214-
rdb.onClose = failover.Close
215-
rdb.init()
216-
217221
return rdb
218222
}
219223

@@ -262,11 +266,12 @@ func NewSentinelClient(opt *Options) *SentinelClient {
262266
opt.init()
263267
c := &SentinelClient{
264268
baseClient: &baseClient{
265-
opt: opt,
266-
connPool: newConnPool(opt),
269+
opt: opt,
267270
},
268271
}
272+
c.connPool = newConnPool(opt, c.baseClient.dial)
269273

274+
c.hooks.setDial(c.baseClient.dial)
270275
c.hooks.setProcess(c.baseClient.process)
271276

272277
return c

tx.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func (c *Tx) init() {
3838
c.cmdable = c.Process
3939
c.statefulCmdable = c.Process
4040

41+
c.hooks.setDial(c.baseClient.dial)
4142
c.hooks.setProcess(c.baseClient.process)
4243
c.hooks.setProcessPipeline(c.baseClient.processPipeline)
4344
c.hooks.setProcessTxPipeline(c.baseClient.processTxPipeline)

0 commit comments

Comments
 (0)