Skip to content

Commit d2ec263

Browse files
committed
DNS resolving with timeout.
1 parent 6ce73bf commit d2ec263

File tree

6 files changed

+94
-22
lines changed

6 files changed

+94
-22
lines changed

dialoptions.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ type dialOptions struct {
7777
defaultServiceConfig *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON.
7878
defaultServiceConfigRawJSON *string
7979
resolvers []resolver.Builder
80+
resolveTimeout time.Duration
8081
idleTimeout time.Duration
8182
recvBufferPool SharedBufferPool
8283
}
@@ -694,6 +695,13 @@ func WithIdleTimeout(d time.Duration) DialOption {
694695
})
695696
}
696697

698+
// WithResolveTimeout returns a DialOption that configures a DNS resolving timeout.
699+
func WithResolveTimeout(d time.Duration) DialOption {
700+
return newFuncDialOption(func(o *dialOptions) {
701+
o.resolveTimeout = d
702+
})
703+
}
704+
697705
// WithRecvBufferPool returns a DialOption that configures the ClientConn
698706
// to use the provided shared buffer pool for parsing incoming messages. Depending
699707
// on the application's workload, this could result in reduced memory allocation.

internal/resolver/dns/dns_resolver.go

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ type dnsBuilder struct{}
100100

101101
// Build creates and starts a DNS resolver that watches the name resolution of
102102
// the target.
103-
func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
103+
func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
104+
resolver.Resolver, error,
105+
) {
104106
host, port, err := parseTarget(target.Endpoint(), defaultPort)
105107
if err != nil {
106108
return nil, err
@@ -118,13 +120,18 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
118120
d := &dnsResolver{
119121
host: host,
120122
port: port,
123+
timeout: opts.Timeout,
121124
ctx: ctx,
122125
cancel: cancel,
123126
cc: cc,
124127
rn: make(chan struct{}, 1),
125128
disableServiceConfig: opts.DisableServiceConfig,
126129
}
127130

131+
if d.timeout == 0 {
132+
d.timeout = 1 * time.Minute
133+
}
134+
128135
d.resolver, err = internal.NewNetResolver(target.URL.Host)
129136
if err != nil {
130137
return nil, err
@@ -152,6 +159,7 @@ type dnsResolver struct {
152159
host string
153160
port string
154161
resolver internal.NetResolver
162+
timeout time.Duration
155163
ctx context.Context
156164
cancel context.CancelFunc
157165
cc resolver.ClientConn
@@ -221,18 +229,18 @@ func (d *dnsResolver) watcher() {
221229
}
222230
}
223231

224-
func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
232+
func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
225233
if !EnableSRVLookups {
226234
return nil, nil
227235
}
228236
var newAddrs []resolver.Address
229-
_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host)
237+
_, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
230238
if err != nil {
231239
err = handleDNSError(err, "SRV") // may become nil
232240
return nil, err
233241
}
234242
for _, s := range srvs {
235-
lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
243+
lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
236244
if err != nil {
237245
err = handleDNSError(err, "A") // may become nil
238246
if err == nil {
@@ -269,8 +277,8 @@ func handleDNSError(err error, lookupType string) error {
269277
return err
270278
}
271279

272-
func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
273-
ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host)
280+
func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
281+
ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
274282
if err != nil {
275283
if envconfig.TXTErrIgnore {
276284
return nil
@@ -297,8 +305,8 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
297305
return d.cc.ParseServiceConfig(sc)
298306
}
299307

300-
func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
301-
addrs, err := d.resolver.LookupHost(d.ctx, d.host)
308+
func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
309+
addrs, err := d.resolver.LookupHost(ctx, d.host)
302310
if err != nil {
303311
err = handleDNSError(err, "A")
304312
return nil, err
@@ -316,8 +324,12 @@ func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
316324
}
317325

318326
func (d *dnsResolver) lookup() (*resolver.State, error) {
319-
srv, srvErr := d.lookupSRV()
320-
addrs, hostErr := d.lookupHost()
327+
ctxSRV, cancelSRV := context.WithTimeout(d.ctx, d.timeout)
328+
defer cancelSRV()
329+
srv, srvErr := d.lookupSRV(ctxSRV)
330+
ctxHost, cancelHost := context.WithTimeout(d.ctx, d.timeout)
331+
defer cancelHost()
332+
addrs, hostErr := d.lookupHost(ctxHost)
321333
if hostErr != nil && (srvErr != nil || len(srv) == 0) {
322334
return nil, hostErr
323335
}
@@ -327,7 +339,9 @@ func (d *dnsResolver) lookup() (*resolver.State, error) {
327339
state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
328340
}
329341
if !d.disableServiceConfig {
330-
state.ServiceConfig = d.lookupTXT()
342+
ctxTXT, cancelTXT := context.WithTimeout(d.ctx, d.timeout)
343+
defer cancelTXT()
344+
state.ServiceConfig = d.lookupTXT(ctxTXT)
331345
}
332346
return &state, nil
333347
}

internal/resolver/dns/dns_resolver_test.go

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Reso
135135
}
136136

137137
tcc := &testutils.ResolverClientConn{Logger: t, UpdateStateF: updateStateF, ReportErrorF: reportErrorF}
138-
r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, resolver.BuildOptions{})
138+
rOpts := resolver.BuildOptions{Timeout: 100 * time.Millisecond}
139+
r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("dns:///%s", target))}, tcc, rOpts)
139140
if err != nil {
140141
t.Fatalf("Failed to build DNS resolver for target %q: %v\n", target, err)
141142
}
@@ -144,6 +145,29 @@ func buildResolverWithTestClientConn(t *testing.T, target string) (resolver.Reso
144145
return r, stateCh, errCh
145146
}
146147

148+
// Test verifies that when the DNS resolver gets timeout error when net.Resolver
149+
// takes too long to resolve a target.
150+
func (s) TestResolveTimeout(t *testing.T) {
151+
const target = "timeoutaddress"
152+
tr := &testNetResolver{}
153+
tr.UpdateHostLookupTable(map[string][]string{target: {"1.2.3.4"}})
154+
overrideNetResolver(t, tr)
155+
156+
r, _, errCh := buildResolverWithTestClientConn(t, target)
157+
r.ResolveNow(resolver.ResolveNowOptions{})
158+
159+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
160+
defer cancel()
161+
select {
162+
case <-ctx.Done():
163+
t.Fatal("Timeout when waiting for an error from the resolver")
164+
case err := <-errCh:
165+
if err == nil || !strings.Contains(err.Error(), "timed out") {
166+
t.Fatalf(`we expect to see timed out error`)
167+
}
168+
}
169+
}
170+
147171
// Waits for a state update from the DNS resolver and verifies the following:
148172
// - wantAddrs matches the list of addresses in the update
149173
// - wantBalancerAddrs matches the list of grpclb addresses in the update
@@ -401,7 +425,9 @@ func (s) TestDNSResolver_Basic(t *testing.T) {
401425
txtLookupTable: map[string][]string{
402426
"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
403427
},
404-
wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
428+
wantAddrs: []resolver.Address{
429+
{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort},
430+
},
405431
wantBalancerAddrs: nil,
406432
wantSC: scJSON,
407433
},
@@ -468,16 +494,23 @@ func (s) TestDNSResolver_Basic(t *testing.T) {
468494
txtLookupTable: map[string][]string{
469495
"_grpc_config.srv.ipv6.single.fake": txtRecordServiceConfig(txtRecordNonMatching),
470496
},
471-
wantAddrs: nil,
472-
wantBalancerAddrs: []resolver.Address{{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake"}},
473-
wantSC: "{}",
497+
wantAddrs: nil,
498+
wantBalancerAddrs: []resolver.Address{
499+
{
500+
Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake",
501+
},
502+
},
503+
wantSC: "{}",
474504
},
475505
{
476506
name: "ipv6_with_SRV_and_multiple_grpclb_address",
477507
target: "srv.ipv6.multi.fake",
478508
hostLookupTable: map[string][]string{
479509
"srv.ipv6.multi.fake": nil,
480-
"ipv6.multi.fake": {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"},
510+
"ipv6.multi.fake": {
511+
"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002",
512+
"2607:f8b0:400a:801::1003",
513+
},
481514
},
482515
srvLookupTable: map[string][]*net.SRV{
483516
"_grpclb._tcp.srv.ipv6.multi.fake": {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}},
@@ -880,8 +913,10 @@ func (s) TestDisableServiceConfig(t *testing.T) {
880913
"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
881914
},
882915
disableServiceConfig: false,
883-
wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
884-
wantSC: scJSON,
916+
wantAddrs: []resolver.Address{
917+
{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort},
918+
},
919+
wantSC: scJSON,
885920
},
886921
{
887922
name: "true",
@@ -891,8 +926,10 @@ func (s) TestDisableServiceConfig(t *testing.T) {
891926
"_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood),
892927
},
893928
disableServiceConfig: true,
894-
wantAddrs: []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}},
895-
wantSC: "{}",
929+
wantAddrs: []resolver.Address{
930+
{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort},
931+
},
932+
wantSC: "{}",
896933
},
897934
}
898935

@@ -1043,7 +1080,9 @@ func (s) TestCustomAuthority(t *testing.T) {
10431080
// Override the address dialer to verify the authority being passed.
10441081
origAddressDialer := dnsinternal.AddressDialer
10451082
errChan := make(chan error, 1)
1046-
dnsinternal.AddressDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
1083+
dnsinternal.AddressDialer = func(authority string) func(ctx context.Context, network, address string) (
1084+
net.Conn, error,
1085+
) {
10471086
if authority != test.wantAuthority {
10481087
errChan <- fmt.Errorf("wrong custom authority passed to resolver. target: %s got authority: %s want authority: %s", test.authority, authority, test.wantAuthority)
10491088
} else {

internal/resolver/dns/fake_net_resolver_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ package dns_test
2020

2121
import (
2222
"context"
23+
"fmt"
2324
"net"
25+
"strings"
2426
"sync"
2527

2628
"google.golang.org/grpc/internal/testutils"
@@ -47,6 +49,10 @@ func (tr *testNetResolver) LookupHost(ctx context.Context, host string) ([]strin
4749
tr.mu.Lock()
4850
defer tr.mu.Unlock()
4951

52+
if strings.Contains(host, "timeoutaddress") {
53+
return nil, fmt.Errorf("timed out")
54+
}
55+
5056
if addrs, ok := tr.hostLookupTable[host]; ok {
5157
return addrs, nil
5258
}

resolver/resolver.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"net"
2727
"net/url"
2828
"strings"
29+
"time"
2930

3031
"google.golang.org/grpc/attributes"
3132
"google.golang.org/grpc/credentials"
@@ -168,6 +169,9 @@ type BuildOptions struct {
168169
// field. In most cases though, it is not appropriate, and this field may
169170
// be ignored.
170171
Dialer func(context.Context, string) (net.Conn, error)
172+
// Timeout specifies time limit for DNS resolver to resolve a target,
173+
// which can be SRV , Host, TXT records.
174+
Timeout time.Duration
171175
}
172176

173177
// An Endpoint is one network endpoint, or server, which may have multiple

resolver_wrapper.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ func (ccr *ccResolverWrapper) start() error {
7575
DialCreds: ccr.cc.dopts.copts.TransportCredentials,
7676
CredsBundle: ccr.cc.dopts.copts.CredsBundle,
7777
Dialer: ccr.cc.dopts.copts.Dialer,
78+
Timeout: ccr.cc.dopts.resolveTimeout,
7879
}
7980
var err error
8081
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)

0 commit comments

Comments
 (0)