Skip to content

Commit 608e03b

Browse files
hasson82printchard
authored andcommitted
scripts: add linter rule for using context.WithTimeout on tests (grpc#7342)
1 parent e88ac1e commit 608e03b

File tree

13 files changed

+105
-28
lines changed

13 files changed

+105
-28
lines changed

credentials/alts/internal/handshaker/handshaker_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,9 @@ func (s) TestNewClientHandshaker(t *testing.T) {
309309
conn := testutil.NewTestConn(nil, nil)
310310
clientConn := &grpc.ClientConn{}
311311
opts := &ClientHandshakerOptions{}
312-
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
312+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
313+
defer cancel()
314+
hs, err := NewClientHandshaker(ctx, clientConn, conn, opts)
313315
if err != nil {
314316
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
315317
}
@@ -341,7 +343,9 @@ func (s) TestNewServerHandshaker(t *testing.T) {
341343
conn := testutil.NewTestConn(nil, nil)
342344
clientConn := &grpc.ClientConn{}
343345
opts := &ServerHandshakerOptions{}
344-
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
346+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
347+
defer cancel()
348+
hs, err := NewServerHandshaker(ctx, clientConn, conn, opts)
345349
if err != nil {
346350
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
347351
}

gcp/observability/observability_test.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ func (s) TestRefuseStartWithInvalidPatterns(t *testing.T) {
184184
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
185185
}()
186186
// If there is at least one invalid pattern, which should not be silently tolerated.
187-
if err := Start(context.Background()); err == nil {
187+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
188+
defer cancel()
189+
if err := Start(ctx); err == nil {
188190
t.Fatalf("Invalid patterns not triggering error")
189191
}
190192
}
@@ -220,7 +222,9 @@ func (s) TestRefuseStartWithExcludeAndWildCardAll(t *testing.T) {
220222
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
221223
}()
222224
// If there is at least one invalid pattern, which should not be silently tolerated.
223-
if err := Start(context.Background()); err == nil {
225+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
226+
defer cancel()
227+
if err := Start(ctx); err == nil {
224228
t.Fatalf("Invalid patterns not triggering error")
225229
}
226230
}
@@ -316,7 +320,9 @@ func (s) TestBothConfigEnvVarsSet(t *testing.T) {
316320
defer func() {
317321
envconfig.ObservabilityConfig = oldObservabilityConfig
318322
}()
319-
if err := Start(context.Background()); err == nil {
323+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
324+
defer cancel()
325+
if err := Start(ctx); err == nil {
320326
t.Fatalf("Invalid patterns not triggering error")
321327
}
322328
}
@@ -331,7 +337,9 @@ func (s) TestErrInFileSystemEnvVar(t *testing.T) {
331337
defer func() {
332338
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
333339
}()
334-
if err := Start(context.Background()); err == nil {
340+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
341+
defer cancel()
342+
if err := Start(ctx); err == nil {
335343
t.Fatalf("Invalid file system path not triggering error")
336344
}
337345
}
@@ -346,7 +354,9 @@ func (s) TestNoEnvSet(t *testing.T) {
346354
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
347355
}()
348356
// If there is no observability config set at all, the Start should return an error.
349-
if err := Start(context.Background()); err == nil {
357+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
358+
defer cancel()
359+
if err := Start(ctx); err == nil {
350360
t.Fatalf("Invalid patterns not triggering error")
351361
}
352362
}
@@ -540,7 +550,9 @@ func (s) TestStartErrorsThenEnd(t *testing.T) {
540550
envconfig.ObservabilityConfig = oldObservabilityConfig
541551
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
542552
}()
543-
if err := Start(context.Background()); err == nil {
553+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
554+
defer cancel()
555+
if err := Start(ctx); err == nil {
544556
t.Fatalf("Invalid patterns not triggering error")
545557
}
546558
End()

internal/binarylog/method_logger_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import (
3333
"google.golang.org/protobuf/types/known/durationpb"
3434
)
3535

36+
const defaultTestTimeout = 10 * time.Second
37+
3638
func (s) TestLog(t *testing.T) {
3739
idGen.reset()
3840
ml := NewTruncatingMethodLogger(10, 10)
@@ -333,10 +335,12 @@ func (s) TestLog(t *testing.T) {
333335
},
334336
},
335337
}
338+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
339+
defer cancel()
336340
for i, tc := range testCases {
337341
buf.Reset()
338342
tc.want.SequenceIdWithinCall = uint64(i + 1)
339-
ml.Log(context.Background(), tc.config)
343+
ml.Log(ctx, tc.config)
340344
inSink := new(binlogpb.GrpcLogEntry)
341345
if err := proto.Unmarshal(buf.Bytes()[4:], inSink); err != nil {
342346
t.Errorf("failed to unmarshal bytes in sink to proto: %v", err)

internal/xds/bootstrap/tlscreds/bundle_ext_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ func (s) TestMTLS(t *testing.T) {
247247
}
248248
defer conn.Close()
249249
client := testgrpc.NewTestServiceClient(conn)
250-
if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
250+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
251+
defer cancel()
252+
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
251253
t.Errorf("EmptyCall(): got error %v when expected to succeed", err)
252254
}
253255
}

internal/xds/bootstrap/tlscreds/bundle_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"fmt"
2626
"strings"
2727
"testing"
28+
"time"
2829

2930
"google.golang.org/grpc"
3031
"google.golang.org/grpc/credentials/tls/certprovider"
@@ -37,6 +38,8 @@ import (
3738
testpb "google.golang.org/grpc/interop/grpc_testing"
3839
)
3940

41+
const defaultTestTimeout = 5 * time.Second
42+
4043
type s struct {
4144
grpctest.Tester
4245
}
@@ -86,7 +89,9 @@ func (s) TestFailingProvider(t *testing.T) {
8689
defer conn.Close()
8790

8891
client := testgrpc.NewTestServiceClient(conn)
89-
_, err = client.EmptyCall(context.Background(), &testpb.Empty{})
92+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
93+
defer cancel()
94+
_, err = client.EmptyCall(ctx, &testpb.Empty{})
9095
if wantErr := "test error"; err == nil || !strings.Contains(err.Error(), wantErr) {
9196
t.Errorf("EmptyCall() got err: %s, want err to contain: %s", err, wantErr)
9297
}

internal/xds/rbac/rbac_engine_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"net/url"
2828
"reflect"
2929
"testing"
30+
"time"
3031

3132
v1xdsudpatypepb "github.com/cncf/xds/go/udpa/type/v1"
3233
v3xdsxdstypepb "github.com/cncf/xds/go/xds/type/v3"
@@ -48,6 +49,8 @@ import (
4849
"google.golang.org/protobuf/types/known/wrapperspb"
4950
)
5051

52+
const defaultTestTimeout = 10 * time.Second
53+
5154
type s struct {
5255
grpctest.Tester
5356
}
@@ -1742,14 +1745,15 @@ func (s) TestChainEngine(t *testing.T) {
17421745
}
17431746
// Query the created chain of RBAC Engines with different args to see
17441747
// if the chain of RBAC Engines configured as such works as intended.
1748+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1749+
defer cancel()
17451750
for _, data := range test.rbacQueries {
17461751
func() {
17471752
// Construct the context with three data points that have enough
17481753
// information to represent incoming RPC's. This will be how a
17491754
// user uses this API. A user will have to put MD, PeerInfo, and
17501755
// the connection the RPC is sent on in the context.
1751-
ctx := metadata.NewIncomingContext(context.Background(), data.rpcData.md)
1752-
1756+
ctx = metadata.NewIncomingContext(ctx, data.rpcData.md)
17531757
// Make a TCP connection with a certain destination port. The
17541758
// address/port of this connection will be used to populate the
17551759
// destination ip/port in RPCData struct. This represents what

metadata/metadata_test.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,11 @@ func (s) TestFromIncomingContext(t *testing.T) {
202202
md := Pairs(
203203
"X-My-Header-1", "42",
204204
)
205+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
206+
defer cancel()
205207
// Verify that we lowercase if callers directly modify md
206208
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
207-
ctx := NewIncomingContext(context.Background(), md)
209+
ctx = NewIncomingContext(ctx, md)
208210

209211
result, found := FromIncomingContext(ctx)
210212
if !found {
@@ -238,9 +240,11 @@ func (s) TestValueFromIncomingContext(t *testing.T) {
238240
"X-My-Header-2", "43-2",
239241
"x-my-header-3", "44",
240242
)
243+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
244+
defer cancel()
241245
// Verify that we lowercase if callers directly modify md
242246
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
243-
ctx := NewIncomingContext(context.Background(), md)
247+
ctx = NewIncomingContext(ctx, md)
244248

245249
for _, test := range []struct {
246250
key string
@@ -376,17 +380,22 @@ func BenchmarkFromOutgoingContext(b *testing.B) {
376380
}
377381

378382
func BenchmarkFromIncomingContext(b *testing.B) {
383+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
384+
defer cancel()
379385
md := Pairs("X-My-Header-1", "42")
380-
ctx := NewIncomingContext(context.Background(), md)
386+
ctx = NewIncomingContext(ctx, md)
387+
381388
b.ResetTimer()
382389
for n := 0; n < b.N; n++ {
383390
FromIncomingContext(ctx)
384391
}
385392
}
386393

387394
func BenchmarkValueFromIncomingContext(b *testing.B) {
395+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
396+
defer cancel()
388397
md := Pairs("X-My-Header-1", "42")
389-
ctx := NewIncomingContext(context.Background(), md)
398+
ctx = NewIncomingContext(ctx, md)
390399

391400
b.Run("key-found", func(b *testing.B) {
392401
for n := 0; n < b.N; n++ {

peer/peer_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ import (
2222
"context"
2323
"fmt"
2424
"testing"
25+
"time"
2526

2627
"google.golang.org/grpc/credentials"
2728
)
2829

30+
const defaultTestTimeout = 10 * time.Second
31+
2932
// A struct that implements AuthInfo interface and implements CommonAuthInfo() method.
3033
type testAuthInfo struct {
3134
credentials.CommonAuthInfo
@@ -80,9 +83,12 @@ func TestPeerStringer(t *testing.T) {
8083
want: "Peer<nil>",
8184
},
8285
}
86+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
87+
defer cancel()
8388
for _, tc := range testCases {
8489
t.Run(tc.name, func(t *testing.T) {
85-
ctx := NewContext(context.Background(), tc.peer)
90+
ctx = NewContext(ctx, tc.peer)
91+
8692
p, ok := FromContext(ctx)
8793
if !ok {
8894
t.Fatalf("Unable to get peer from context")

picker_wrapper_test.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,11 @@ func (s) TestBlockingPick(t *testing.T) {
7878
bp := newPickerWrapper(nil)
7979
// All goroutines should block because picker is nil in bp.
8080
var finishedCount uint64
81+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
82+
defer cancel()
8183
for i := goroutineCount; i > 0; i-- {
8284
go func() {
83-
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
85+
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
8486
t.Errorf("bp.pick returned non-nil error: %v", err)
8587
}
8688
atomic.AddUint64(&finishedCount, 1)
@@ -97,10 +99,12 @@ func (s) TestBlockingPickNoSubAvailable(t *testing.T) {
9799
bp := newPickerWrapper(nil)
98100
var finishedCount uint64
99101
bp.updatePicker(&testingPicker{err: balancer.ErrNoSubConnAvailable, maxCalled: goroutineCount})
102+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
103+
defer cancel()
100104
// All goroutines should block because picker returns no subConn available.
101105
for i := goroutineCount; i > 0; i-- {
102106
go func() {
103-
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
107+
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
104108
t.Errorf("bp.pick returned non-nil error: %v", err)
105109
}
106110
atomic.AddUint64(&finishedCount, 1)
@@ -117,11 +121,13 @@ func (s) TestBlockingPickTransientWaitforready(t *testing.T) {
117121
bp := newPickerWrapper(nil)
118122
bp.updatePicker(&testingPicker{err: balancer.ErrTransientFailure, maxCalled: goroutineCount})
119123
var finishedCount uint64
124+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
125+
defer cancel()
120126
// All goroutines should block because picker returns transientFailure and
121127
// picks are not failfast.
122128
for i := goroutineCount; i > 0; i-- {
123129
go func() {
124-
if tr, _, err := bp.pick(context.Background(), false, balancer.PickInfo{}); err != nil || tr != testT {
130+
if tr, _, err := bp.pick(ctx, false, balancer.PickInfo{}); err != nil || tr != testT {
125131
t.Errorf("bp.pick returned non-nil error: %v", err)
126132
}
127133
atomic.AddUint64(&finishedCount, 1)
@@ -138,10 +144,12 @@ func (s) TestBlockingPickSCNotReady(t *testing.T) {
138144
bp := newPickerWrapper(nil)
139145
bp.updatePicker(&testingPicker{sc: testSCNotReady, maxCalled: goroutineCount})
140146
var finishedCount uint64
147+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
148+
defer cancel()
141149
// All goroutines should block because subConn is not ready.
142150
for i := goroutineCount; i > 0; i-- {
143151
go func() {
144-
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
152+
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
145153
t.Errorf("bp.pick returned non-nil error: %v", err)
146154
}
147155
atomic.AddUint64(&finishedCount, 1)

scripts/vet.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ not git grep "\(import \|^\s*\)\"google.golang.org/grpc/interop/grpc_testing" --
6969
# - Ensure all xds proto imports are renamed to *pb or *grpc.
7070
git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "'
7171

72+
# - Ensure all context usages are done with timeout.
73+
# Context tests under benchmark are excluded as they are testing the performance of context.Background() and context.TODO().
74+
# TODO: Remove the exclusions once the tests are updated to use context.WithTimeout().
75+
# See https://github.com/grpc/grpc-go/issues/7304
76+
git grep -e 'context.Background()' --or -e 'context.TODO()' -- "*_test.go" | grep -v "benchmark/primitives/context_test.go" | grep -v "credential
77+
s/google" | grep -v "internal/transport/" | grep -v "xds/internal/" | grep -v "security/advancedtls" | grep -v 'context.WithTimeout(' | not grep -v 'context.WithCancel('
78+
7279
misspell -error .
7380

7481
# - gofmt, goimports, go vet, go mod tidy.

server_test.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,22 @@ func (s) TestRetryChainedInterceptor(t *testing.T) {
153153
handler := func(ctx context.Context, req any) (any, error) {
154154
return nil, nil
155155
}
156-
ii(context.Background(), nil, nil, handler)
156+
157+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
158+
defer cancel()
159+
160+
ii(ctx, nil, nil, handler)
157161
if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
158162
t.Fatalf("retry failed on chained interceptors: %v", records)
159163
}
160164
}
161165

162166
func (s) TestStreamContext(t *testing.T) {
163167
expectedStream := &transport.Stream{}
164-
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
168+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
169+
defer cancel()
170+
ctx = NewContextWithServerTransportStream(ctx, expectedStream)
171+
165172
s := ServerTransportStreamFromContext(ctx)
166173
stream, ok := s.(*transport.Stream)
167174
if !ok || expectedStream != stream {
@@ -170,6 +177,8 @@ func (s) TestStreamContext(t *testing.T) {
170177
}
171178

172179
func BenchmarkChainUnaryInterceptor(b *testing.B) {
180+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
181+
defer cancel()
173182
for _, n := range []int{1, 3, 5, 10} {
174183
n := n
175184
b.Run(strconv.Itoa(n), func(b *testing.B) {
@@ -186,7 +195,7 @@ func BenchmarkChainUnaryInterceptor(b *testing.B) {
186195
b.ReportAllocs()
187196
b.ResetTimer()
188197
for i := 0; i < b.N; i++ {
189-
if _, err := s.opts.unaryInt(context.Background(), nil, nil,
198+
if _, err := s.opts.unaryInt(ctx, nil, nil,
190199
func(ctx context.Context, req any) (any, error) {
191200
return nil, nil
192201
},

stats/opentelemetry/csm/observability_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,9 @@ func (s) TestXDSLabels(t *testing.T) {
602602
// without error. The actual functionality of this function will be verified in
603603
// interop tests.
604604
func (s) TestObservability(t *testing.T) {
605-
cleanup := EnableObservability(context.Background(), opentelemetry.Options{})
605+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
606+
defer cancel()
607+
608+
cleanup := EnableObservability(ctx, opentelemetry.Options{})
606609
cleanup()
607610
}

0 commit comments

Comments
 (0)