Skip to content

Commit 1a0eb21

Browse files
Add MinDNSResolutionRate Option
1 parent 03e76b3 commit 1a0eb21

File tree

6 files changed

+48
-29
lines changed

6 files changed

+48
-29
lines changed

dialoptions.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ type dialOptions struct {
7979
resolvers []resolver.Builder
8080
idleTimeout time.Duration
8181
recvBufferPool SharedBufferPool
82+
minDNSResolutionRate *time.Duration
8283
}
8384

8485
// DialOption configures how we set up the connection.
@@ -711,6 +712,17 @@ func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
711712
return withRecvBufferPool(bufferPool)
712713
}
713714

715+
// WithDNSMinResolutionRate sets the default minimum rate at which DNS re-resolutions are
716+
// allowed. This helps to prevent excessive re-resolution.
717+
//
718+
// Using this option overwrites the default [minResolutionRate] specified
719+
// in the dns resolver.
720+
func WithMinDNSResolutionRate(d time.Duration) DialOption {
721+
return newFuncDialOption(func(o *dialOptions) {
722+
o.minDNSResolutionRate = &d
723+
})
724+
}
725+
714726
func withRecvBufferPool(bufferPool SharedBufferPool) DialOption {
715727
return newFuncDialOption(func(o *dialOptions) {
716728
o.recvBufferPool = bufferPool

internal/resolver/dns/dns_resolver.go

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,17 @@ const (
6666
txtAttribute = "grpc_config="
6767
)
6868

69-
var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
70-
return func(ctx context.Context, network, _ string) (net.Conn, error) {
71-
var dialer net.Dialer
72-
return dialer.DialContext(ctx, network, address)
69+
var (
70+
addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
71+
return func(ctx context.Context, network, _ string) (net.Conn, error) {
72+
var dialer net.Dialer
73+
return dialer.DialContext(ctx, network, address)
74+
}
7375
}
74-
}
76+
// minResolutionRate is the minimum rate at which re-resolutions are
77+
// allowed. This helps to prevent excessive re-resolution.
78+
minResolutionRate = 30 * time.Second // this is the default value and can be changed via BuildOptions
79+
)
7580

7681
var newNetResolver = func(authority string) (internal.NetResolver, error) {
7782
if authority == "" {
@@ -113,6 +118,10 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
113118
return deadResolver{}, nil
114119
}
115120

121+
if opts.MinDNSResolutionRate != nil {
122+
minResolutionRate = *opts.MinDNSResolutionRate
123+
}
124+
116125
// DNS address (non-IP).
117126
ctx, cancel := context.WithCancel(context.Background())
118127
d := &dnsResolver{
@@ -123,6 +132,7 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
123132
cc: cc,
124133
rn: make(chan struct{}, 1),
125134
disableServiceConfig: opts.DisableServiceConfig,
135+
minResolutionRate: minResolutionRate,
126136
}
127137

128138
d.resolver, err = internal.NewNetResolver(target.URL.Host)
@@ -167,6 +177,7 @@ type dnsResolver struct {
167177
// replaceNetFunc (WRITE the lookup function pointers).
168178
wg sync.WaitGroup
169179
disableServiceConfig bool
180+
minResolutionRate time.Duration
170181
}
171182

172183
// ResolveNow invoke an immediate resolution of the target that this
@@ -198,10 +209,10 @@ func (d *dnsResolver) watcher() {
198209

199210
var waitTime time.Duration
200211
if err == nil {
201-
// Success resolving, wait for the next ResolveNow. However, also wait 30
202-
// seconds at the very least to prevent constantly re-resolving.
212+
// Success resolving, wait for the next ResolveNow. However, also wait for
213+
// [minResolutionRate] seconds at the very least to prevent constantly re-resolving.
203214
backoffIndex = 1
204-
waitTime = internal.MinResolutionRate
215+
waitTime = d.minResolutionRate
205216
select {
206217
case <-d.ctx.Done():
207218
return

internal/resolver/dns/dns_resolver_test.go

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,6 @@ func overrideNetResolver(t *testing.T, r *testNetResolver) {
6767
t.Cleanup(func() { dnsinternal.NewNetResolver = origNetResolver })
6868
}
6969

70-
// Override the DNS Min Res Rate used by the resolver.
71-
func overrideResolutionRate(t *testing.T, d time.Duration) {
72-
origMinResRate := dnsinternal.MinResolutionRate
73-
dnsinternal.MinResolutionRate = d
74-
t.Cleanup(func() { dnsinternal.MinResolutionRate = origMinResRate })
75-
}
76-
7770
// Override the timer used by the DNS resolver to fire after a duration of d.
7871
func overrideTimeAfterFunc(t *testing.T, d time.Duration) {
7972
origTimeAfter := dnsinternal.TimeAfterFunc
@@ -109,7 +102,7 @@ func enableSRVLookups(t *testing.T) {
109102

110103
// Builds a DNS resolver for target and returns a couple of channels to read the
111104
// state and error pushed by the resolver respectively.
112-
func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Resolver, chan resolver.State, chan error) {
105+
func buildResolverWithTestClientConn(t *testing.T, target string, buildOptions resolver.BuildOptions) (resolver.Resolver, chan resolver.State, chan error) {
113106
t.Helper()
114107

115108
b := resolver.Get("dns")
@@ -135,7 +128,7 @@ func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Reso
135128
}
136129

137130
tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF, ReportErrorF: reportErrorF}
138-
r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, resolver.BuildOptions{})
131+
r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, buildOptions)
139132
if err != nil {
140133
t.Fatalf("Failed to build DNS resolver for target %q: %v\n", target, err)
141134
}
@@ -504,7 +497,7 @@ func (s) TestDNSResolver_Basic(t *testing.T) {
504497
txtLookupTable: test.txtLookupTable,
505498
})
506499
enableSRVLookups(t)
507-
_, stateCh, _ := buildResolverWithTestClientConn(t, test.target)
500+
_, stateCh, _ := buildResolverWithTestClientConn(t, test.target, resolver.BuildOptions{})
508501

509502
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
510503
defer cancel()
@@ -635,7 +628,6 @@ func (s) TestDNSResolver_ExponentialBackoff(t *testing.T) {
635628
func (s) TestDNSResolver_ResolveNow(t *testing.T) {
636629
const target = "foo.bar.com"
637630

638-
overrideResolutionRate(t, 0)
639631
overrideTimeAfterFunc(t, 0)
640632
tr := &testNetResolver{
641633
hostLookupTable: map[string][]string{
@@ -647,7 +639,8 @@ func (s) TestDNSResolver_ResolveNow(t *testing.T) {
647639
}
648640
overrideNetResolver(t, tr)
649641

650-
r, stateCh, _ := buildResolverWithTestClientConn(t, target)
642+
var minResolutionRate time.Duration = 0
643+
r, stateCh, _ := buildResolverWithTestClientConn(t, target, resolver.BuildOptions{MinDNSResolutionRate: &minResolutionRate})
651644

652645
// Verify that the first update pushed by the resolver matches expectations.
653646
wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
@@ -738,9 +731,10 @@ func (s) TestIPResolver(t *testing.T) {
738731

739732
for _, test := range tests {
740733
t.Run(test.name, func(t *testing.T) {
741-
overrideResolutionRate(t, 0)
742734
overrideTimeAfterFunc(t, 2*defaultTestTimeout)
743-
r, stateCh, _ := buildResolverWithTestClientConn(t, test.target)
735+
736+
var minResolutionRate time.Duration = 0
737+
r, stateCh, _ := buildResolverWithTestClientConn(t, test.target, resolver.BuildOptions{MinDNSResolutionRate: &minResolutionRate})
744738

745739
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
746740
defer cancel()
@@ -943,7 +937,7 @@ func (s) TestTXTError(t *testing.T) {
943937
// There is no entry for "ipv4.single.fake" in the txtLookupTbl
944938
// maintained by the fake net.Resolver. So, a TXT lookup for this
945939
// name will return an error.
946-
_, stateCh, _ := buildResolverWithTestClientConn(t, "ipv4.single.fake")
940+
_, stateCh, _ := buildResolverWithTestClientConn(t, "ipv4.single.fake", resolver.BuildOptions{})
947941

948942
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
949943
defer cancel()
@@ -1092,7 +1086,7 @@ func (s) TestRateLimitedResolve(t *testing.T) {
10921086
}
10931087
overrideNetResolver(t, tr)
10941088

1095-
r, stateCh, _ := buildResolverWithTestClientConn(t, target)
1089+
r, stateCh, _ := buildResolverWithTestClientConn(t, target, resolver.BuildOptions{})
10961090

10971091
// Wait for the first resolution request to be done. This happens as part
10981092
// of the first iteration of the for loop in watcher().
@@ -1171,7 +1165,7 @@ func (s) TestReportError(t *testing.T) {
11711165
overrideNetResolver(t, &testNetResolver{})
11721166

11731167
const target = "notfoundaddress"
1174-
_, _, errorCh := buildResolverWithTestClientConn(t, target)
1168+
_, _, errorCh := buildResolverWithTestClientConn(t, target, resolver.BuildOptions{})
11751169

11761170
// Should receive first error.
11771171
ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)

internal/resolver/dns/internal/internal.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ var (
5050

5151
// The following vars are overridden from tests.
5252
var (
53-
// MinResolutionRate is the minimum rate at which re-resolutions are
54-
// allowed. This helps to prevent excessive re-resolution.
55-
MinResolutionRate = 30 * time.Second
56-
5753
// TimeAfterFunc is used by the DNS resolver to wait for the given duration
5854
// to elapse. In non-test code, this is implemented by time.After. In test
5955
// code, this can be used to control the amount of time the resolver is

resolver/resolver.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"net"
2727
"net/url"
2828
"strings"
29+
"time"
2930

3031
"google.golang.org/grpc/attributes"
3132
"google.golang.org/grpc/credentials"
@@ -168,6 +169,10 @@ type BuildOptions struct {
168169
// field. In most cases though, it is not appropriate, and this field may
169170
// be ignored.
170171
Dialer func(context.Context, string) (net.Conn, error)
172+
// MinDNSResolutionRate is the minimum rate at which re-resolutions are
173+
// allowed. This helps to prevent excessive re-resolution.
174+
// Pointer was used to differentiate not-given from default value
175+
MinDNSResolutionRate *time.Duration
171176
}
172177

173178
// An Endpoint is one network endpoint, or server, which may have multiple

resolver_wrapper.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ func (ccr *ccResolverWrapper) start() error {
7575
DialCreds: ccr.cc.dopts.copts.TransportCredentials,
7676
CredsBundle: ccr.cc.dopts.copts.CredsBundle,
7777
Dialer: ccr.cc.dopts.copts.Dialer,
78+
MinDNSResolutionRate: ccr.cc.dopts.minDNSResolutionRate,
7879
}
7980
var err error
8081
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)

0 commit comments

Comments
 (0)