Skip to content

scripts: add linter rule for using context.WithTimeout on tests #7342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
58ff83f
add context with timeout check to vet script
hasson82 Jun 19, 2024
7da7920
expand lint rule to avoid dealing with context.WithCancel cases
hasson82 Jun 19, 2024
d47051e
add context.WithTimeout to handeshaker tests
hasson82 Jun 19, 2024
5acc6e3
avoid context_test.go in linting
hasson82 Jun 19, 2024
c47300f
Change the lint rule to avoid context_test.go file
hasson82 Jun 19, 2024
d595842
add context with timeout to observability tests
hasson82 Jun 19, 2024
a1e7035
ignore credentials/google package as this package doesnt have default…
hasson82 Jun 19, 2024
4a7ec24
Change all context in bundle_ext_test.go to context with timeout
hasson82 Jun 19, 2024
912de44
Change all context in bundle_test.go to context with timeout
hasson82 Jun 19, 2024
150515f
Change rbac_engine_test context to context with timeout
hasson82 Jun 19, 2024
9cb967d
add context with timeout to method_logger_test in binary log
hasson82 Jun 20, 2024
53cc270
adding internal/transport package to avoid list in linter rule
hasson82 Jun 20, 2024
befaa3d
reduce rule in linter to avoid catching cases were context.WithTimeou…
hasson82 Jun 20, 2024
eaccdda
change all context usage in metadata_test to context.WithTimeout
hasson82 Jun 20, 2024
a4e7742
add xds/internal to avoid list of linter rule
hasson82 Jun 20, 2024
aa71ff0
add context.WithTimeout to peer_test
hasson82 Jun 20, 2024
d1aefe7
change all context usage in picker_wrapper_test to context.WithTimeout
hasson82 Jun 20, 2024
40a54fb
add security/advancedtls to avoid list in linter rule
hasson82 Jun 20, 2024
96345b1
change all usage of context in server_test.go to context.WithTimeout
hasson82 Jun 20, 2024
59a3855
change missing context usage in observability_test to context.WithTim…
hasson82 Jun 20, 2024
cfc815b
change all context usage to context.WithTimeout in gracefulstop_test
hasson82 Jun 20, 2024
c33cbef
remove added lint from vet.sh
hasson82 Jun 20, 2024
e7192b1
add lint rule of all context should have timeouts back
hasson82 Jun 20, 2024
689ae5b
change lint rule in all avoided packages to not grep
hasson82 Jun 20, 2024
6a51d0d
change lint rule to not to enforce that the grep returns nothing
hasson82 Jun 20, 2024
c29c45c
change lint rule to not to enforce that the grep returns nothing
hasson82 Jun 20, 2024
7f744cd
Merge branch 'hasson82/add-context-with-timeout-checks-to-lint' of gi…
hasson82 Jun 20, 2024
f0f0952
add clarifications to the comment above the lint rule
hasson82 Jun 27, 2024
0430481
revert const change in test
hasson82 Jun 27, 2024
832789b
move context creation to improve test readability
hasson82 Jun 27, 2024
9719207
extracted context with timeout setup to central location to improve t…
hasson82 Jun 28, 2024
3081de9
improve test readability
hasson82 Jun 28, 2024
cb43ba0
improve stream context test readability
hasson82 Jun 28, 2024
6c25072
rename ctxWithTimeout to ctx
hasson82 Jun 28, 2024
4b89735
move ctx to inside each test
hasson82 Jun 28, 2024
fdca263
move context creation to outside of the loops
hasson82 Jul 3, 2024
d76f29c
extract context creation from some more loops
hasson82 Jul 3, 2024
136b4ba
extract outside of all loops in bench mark server_test
hasson82 Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ func (s) TestNewClientHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ClientHandshakerOptions{}
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
hs, err := NewClientHandshaker(ctx, clientConn, conn, opts)
if err != nil {
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
}
Expand Down Expand Up @@ -341,7 +343,9 @@ func (s) TestNewServerHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ServerHandshakerOptions{}
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
hs, err := NewServerHandshaker(ctx, clientConn, conn, opts)
if err != nil {
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
}
Expand Down
24 changes: 18 additions & 6 deletions gcp/observability/observability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ func (s) TestRefuseStartWithInvalidPatterns(t *testing.T) {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
// If there is at least one invalid pattern, which should not be silently tolerated.
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand Down Expand Up @@ -220,7 +222,9 @@ func (s) TestRefuseStartWithExcludeAndWildCardAll(t *testing.T) {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
// If there is at least one invalid pattern, which should not be silently tolerated.
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand Down Expand Up @@ -316,7 +320,9 @@ func (s) TestBothConfigEnvVarsSet(t *testing.T) {
defer func() {
envconfig.ObservabilityConfig = oldObservabilityConfig
}()
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand All @@ -331,7 +337,9 @@ func (s) TestErrInFileSystemEnvVar(t *testing.T) {
defer func() {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid file system path not triggering error")
}
}
Expand All @@ -346,7 +354,9 @@ func (s) TestNoEnvSet(t *testing.T) {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
// If there is no observability config set at all, the Start should return an error.
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand Down Expand Up @@ -540,7 +550,9 @@ func (s) TestStartErrorsThenEnd(t *testing.T) {
envconfig.ObservabilityConfig = oldObservabilityConfig
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
End()
Expand Down
6 changes: 5 additions & 1 deletion internal/binarylog/method_logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"google.golang.org/protobuf/types/known/durationpb"
)

const defaultTestTimeout = 10 * time.Second

func (s) TestLog(t *testing.T) {
idGen.reset()
ml := NewTruncatingMethodLogger(10, 10)
Expand Down Expand Up @@ -333,10 +335,12 @@ func (s) TestLog(t *testing.T) {
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i, tc := range testCases {
buf.Reset()
tc.want.SequenceIdWithinCall = uint64(i + 1)
ml.Log(context.Background(), tc.config)
ml.Log(ctx, tc.config)
inSink := new(binlogpb.GrpcLogEntry)
if err := proto.Unmarshal(buf.Bytes()[4:], inSink); err != nil {
t.Errorf("failed to unmarshal bytes in sink to proto: %v", err)
Expand Down
4 changes: 3 additions & 1 deletion internal/xds/bootstrap/tlscreds/bundle_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ func (s) TestMTLS(t *testing.T) {
}
defer conn.Close()
client := testgrpc.NewTestServiceClient(conn)
if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Errorf("EmptyCall(): got error %v when expected to succeed", err)
}
}
7 changes: 6 additions & 1 deletion internal/xds/bootstrap/tlscreds/bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials/tls/certprovider"
Expand All @@ -36,6 +37,8 @@ import (
"google.golang.org/grpc/testdata"
)

const defaultTestTimeout = 5 * time.Second

type s struct {
grpctest.Tester
}
Expand Down Expand Up @@ -85,7 +88,9 @@ func (s) TestFailingProvider(t *testing.T) {
defer conn.Close()

client := testgrpc.NewTestServiceClient(conn)
_, err = client.EmptyCall(context.Background(), &testpb.Empty{})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err = client.EmptyCall(ctx, &testpb.Empty{})
if wantErr := "test error"; err == nil || !strings.Contains(err.Error(), wantErr) {
t.Errorf("EmptyCall() got err: %s, want err to contain: %s", err, wantErr)
}
Expand Down
8 changes: 6 additions & 2 deletions internal/xds/rbac/rbac_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net/url"
"reflect"
"testing"
"time"

v1xdsudpatypepb "github.com/cncf/xds/go/udpa/type/v1"
v3xdsxdstypepb "github.com/cncf/xds/go/xds/type/v3"
Expand All @@ -48,6 +49,8 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
)

const defaultTestTimeout = 10 * time.Second

type s struct {
grpctest.Tester
}
Expand Down Expand Up @@ -1742,14 +1745,15 @@ func (s) TestChainEngine(t *testing.T) {
}
// Query the created chain of RBAC Engines with different args to see
// if the chain of RBAC Engines configured as such works as intended.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, data := range test.rbacQueries {
func() {
// Construct the context with three data points that have enough
// information to represent incoming RPC's. This will be how a
// user uses this API. A user will have to put MD, PeerInfo, and
// the connection the RPC is sent on in the context.
ctx := metadata.NewIncomingContext(context.Background(), data.rpcData.md)

ctx = metadata.NewIncomingContext(ctx, data.rpcData.md)
// Make a TCP connection with a certain destination port. The
// address/port of this connection will be used to populate the
// destination ip/port in RPCData struct. This represents what
Expand Down
17 changes: 13 additions & 4 deletions metadata/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,11 @@ func (s) TestFromIncomingContext(t *testing.T) {
md := Pairs(
"X-My-Header-1", "42",
)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Verify that we lowercase if callers directly modify md
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

result, found := FromIncomingContext(ctx)
if !found {
Expand Down Expand Up @@ -239,9 +241,11 @@ func (s) TestValueFromIncomingContext(t *testing.T) {
"X-My-Header-2", "43-2",
"x-my-header-3", "44",
)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Verify that we lowercase if callers directly modify md
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

for _, test := range []struct {
key string
Expand Down Expand Up @@ -397,17 +401,22 @@ func BenchmarkFromOutgoingContext(b *testing.B) {
}

func BenchmarkFromIncomingContext(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
md := Pairs("X-My-Header-1", "42")
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

b.ResetTimer()
for n := 0; n < b.N; n++ {
FromIncomingContext(ctx)
}
}

func BenchmarkValueFromIncomingContext(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
md := Pairs("X-My-Header-1", "42")
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

b.Run("key-found", func(b *testing.B) {
for n := 0; n < b.N; n++ {
Expand Down
8 changes: 7 additions & 1 deletion peer/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ import (
"context"
"fmt"
"testing"
"time"

"google.golang.org/grpc/credentials"
)

const defaultTestTimeout = 10 * time.Second

// A struct that implements AuthInfo interface and implements CommonAuthInfo() method.
type testAuthInfo struct {
credentials.CommonAuthInfo
Expand Down Expand Up @@ -80,9 +83,12 @@ func TestPeerStringer(t *testing.T) {
want: "Peer<nil>",
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx := NewContext(context.Background(), tc.peer)
ctx = NewContext(ctx, tc.peer)

p, ok := FromContext(ctx)
if !ok {
t.Fatalf("Unable to get peer from context")
Expand Down
16 changes: 12 additions & 4 deletions picker_wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ func (s) TestBlockingPick(t *testing.T) {
bp := newPickerWrapper(nil)
// All goroutines should block because picker is nil in bp.
var finishedCount uint64
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand All @@ -97,10 +99,12 @@ func (s) TestBlockingPickNoSubAvailable(t *testing.T) {
bp := newPickerWrapper(nil)
var finishedCount uint64
bp.updatePicker(&testingPicker{err: balancer.ErrNoSubConnAvailable, maxCalled: goroutineCount})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// All goroutines should block because picker returns no subConn available.
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand All @@ -117,11 +121,13 @@ func (s) TestBlockingPickTransientWaitforready(t *testing.T) {
bp := newPickerWrapper(nil)
bp.updatePicker(&testingPicker{err: balancer.ErrTransientFailure, maxCalled: goroutineCount})
var finishedCount uint64
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// All goroutines should block because picker returns transientFailure and
// picks are not failfast.
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), false, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, false, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand All @@ -138,10 +144,12 @@ func (s) TestBlockingPickSCNotReady(t *testing.T) {
bp := newPickerWrapper(nil)
bp.updatePicker(&testingPicker{sc: testSCNotReady, maxCalled: goroutineCount})
var finishedCount uint64
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// All goroutines should block because subConn is not ready.
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand Down
7 changes: 7 additions & 0 deletions scripts/vet.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ not git grep "\(import \|^\s*\)\"google.golang.org/grpc/interop/grpc_testing" --
# - Ensure all xds proto imports are renamed to *pb or *grpc.
git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "'

# - Ensure all context usages are done with timeout.
# Context tests under benchmark are excluded as they are testing the performance of context.Background() and context.TODO().
# TODO: Remove the exclusions once the tests are updated to use context.WithTimeout().
# See https://github.com/grpc/grpc-go/issues/7304
git grep -e 'context.Background()' --or -e 'context.TODO()' -- "*_test.go" | grep -v "benchmark/primitives/context_test.go" | grep -v "credential
s/google" | grep -v "internal/transport/" | grep -v "xds/internal/" | grep -v "security/advancedtls" | grep -v 'context.WithTimeout(' | not grep -v 'context.WithCancel('

misspell -error .

# - gofmt, goimports, go vet, go mod tidy.
Expand Down
15 changes: 12 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,22 @@ func (s) TestRetryChainedInterceptor(t *testing.T) {
handler := func(ctx context.Context, req any) (any, error) {
return nil, nil
}
ii(context.Background(), nil, nil, handler)

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

ii(ctx, nil, nil, handler)
if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
t.Fatalf("retry failed on chained interceptors: %v", records)
}
}

func (s) TestStreamContext(t *testing.T) {
expectedStream := &transport.Stream{}
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = NewContextWithServerTransportStream(ctx, expectedStream)

s := ServerTransportStreamFromContext(ctx)
stream, ok := s.(*transport.Stream)
if !ok || expectedStream != stream {
Expand All @@ -170,6 +177,8 @@ func (s) TestStreamContext(t *testing.T) {
}

func BenchmarkChainUnaryInterceptor(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, n := range []int{1, 3, 5, 10} {
n := n
b.Run(strconv.Itoa(n), func(b *testing.B) {
Expand All @@ -186,7 +195,7 @@ func BenchmarkChainUnaryInterceptor(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := s.opts.unaryInt(context.Background(), nil, nil,
if _, err := s.opts.unaryInt(ctx, nil, nil,
func(ctx context.Context, req any) (any, error) {
return nil, nil
},
Expand Down
5 changes: 4 additions & 1 deletion stats/opentelemetry/csm/observability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ func (s) TestXDSLabels(t *testing.T) {
// without error. The actual functionality of this function will be verified in
// interop tests.
func (s) TestObservability(t *testing.T) {
cleanup := EnableObservability(context.Background(), opentelemetry.Options{})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

cleanup := EnableObservability(ctx, opentelemetry.Options{})
cleanup()
}
Loading