Skip to content

Commit

Permalink
Revamp gRPC test server for api tests (#4819)
Browse files Browse the repository at this point in the history
* Revamp gRPC test server for api tests

Unifies server implementation. The test server ensures that all handlers
invocations are done when the server is cleaned up. This allows tests
that want to check post-streaming RPC conditions deterministically.

Signed-off-by: Andrew Harding <azdagron@gmail.com>
  • Loading branch information
azdagron authored Jan 20, 2024
1 parent 113a666 commit b23550a
Show file tree
Hide file tree
Showing 24 changed files with 384 additions and 296 deletions.
2 changes: 1 addition & 1 deletion pkg/agent/api/debug/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
)

// RegisterService registers debug service on provided server
func RegisterService(s *grpc.Server, service *Service) {
func RegisterService(s grpc.ServiceRegistrar, service *Service) {
debugv1.RegisterDebugServer(s, service)
}

Expand Down
12 changes: 5 additions & 7 deletions pkg/agent/api/debug/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/spiffe/spire/pkg/agent/manager/cache"
"github.com/spiffe/spire/pkg/agent/svid"
"github.com/spiffe/spire/test/clock"
"github.com/spiffe/spire/test/grpctest"
"github.com/spiffe/spire/test/spiretest"
"github.com/spiffe/spire/test/testca"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -250,15 +251,12 @@ func setupServiceTest(t *testing.T) *serviceTest {
uptime: fakeUptime,
}

registerFn := func(s *grpc.Server) {
registerFn := func(s grpc.ServiceRegistrar) {
debug.RegisterService(s, service)
}
contextFn := func(ctx context.Context) context.Context {
return ctx
}
conn, done := spiretest.NewAPIServer(t, registerFn, contextFn)
test.done = done
test.client = debugv1.NewDebugClient(conn)
server := grpctest.StartServer(t, registerFn)
test.done = server.Stop
test.client = debugv1.NewDebugClient(server.Dial(t))

return test
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/api/health/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// RegisterService registers the service on the gRPC server.
func RegisterService(s *grpc.Server, service *Service) {
func RegisterService(s grpc.ServiceRegistrar, service *Service) {
grpc_health_v1.RegisterHealthServer(s, service)
}

Expand Down
15 changes: 7 additions & 8 deletions pkg/agent/api/health/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/spiffe/spire/pkg/agent/api/health/v1"
"github.com/spiffe/spire/pkg/agent/api/rpccontext"
"github.com/spiffe/spire/pkg/common/x509util"
"github.com/spiffe/spire/test/grpctest"
"github.com/spiffe/spire/test/spiretest"
"github.com/spiffe/spire/test/testca"

Expand Down Expand Up @@ -97,17 +98,15 @@ func TestServiceCheck(t *testing.T) {
Addr: spiretest.StartWorkloadAPI(t, wlAPI),
})

conn, done := spiretest.NewAPIServer(t,
func(s *grpc.Server) {
health.RegisterService(s, service)
},
func(ctx context.Context) context.Context {
server := grpctest.StartServer(t, func(s grpc.ServiceRegistrar) {
health.RegisterService(s, service)
},
grpctest.OverrideContext(func(ctx context.Context) context.Context {
return rpccontext.WithLogger(ctx, log)
},
}),
)
defer done()

client := grpc_health_v1.NewHealthClient(conn)
client := grpc_health_v1.NewHealthClient(server.Dial(t))
resp, err := client.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{
Service: tt.service,
})
Expand Down
35 changes: 9 additions & 26 deletions pkg/agent/endpoints/workload/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ import (
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/pkg/common/x509util"
"github.com/spiffe/spire/proto/spire/common"
"github.com/spiffe/spire/test/grpctest"
"github.com/spiffe/spire/test/spiretest"
"github.com/spiffe/spire/test/testca"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -1508,45 +1508,28 @@ func runTest(t *testing.T, params testParams, fn func(ctx context.Context, clien
AllowedForeignJWTClaims: params.AllowedForeignJWTClaims,
})

drainHandler := spiretest.NewDrainHandlerMiddleware()
unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain(
drainHandler,
server := grpctest.StartServer(t, func(s grpc.ServiceRegistrar) {
workloadPB.RegisterSpiffeWorkloadAPIServer(s, handler)
}, grpctest.Middleware(
middleware.WithLogger(log),
middleware.Preprocess(func(ctx context.Context, fullMethod string, req any) (context.Context, error) {
return rpccontext.WithCallerPID(ctx, params.AsPID), nil
}),
))

server := grpc.NewServer(
grpc.UnaryInterceptor(unaryInterceptor),
grpc.StreamInterceptor(streamInterceptor),
), grpctest.OverUDS(),
)
workloadPB.RegisterSpiffeWorkloadAPIServer(server, handler)
addr := spiretest.ServeGRPCServerOnTempUDSSocket(t, server)

conn := server.Dial(t)

// Provide a cancelable context to ensure the stream is always
// closed when the test case is done, and also to ensure that
// any unexpected blocking call is timed out.
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

conn, err := grpc.DialContext(ctx, "unix:"+addr.String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })

fn(ctx, workloadPB.NewSpiffeWorkloadAPIClient(conn))

cancel()

// Cancelling the stream context above causes the streaming RPC to
// technically "finish", which results in GracefulStop returning before the
// handler implementation finishes executing. We want to ensure that all
// callers of SubscribeToCacheChanges call Finish on the returned
// subscription, so after calling GracefulStop, wait until the statsHandler
// reports that all RPCs are complete before checking that Finish was
// called.
server.GracefulStop()
drainHandler.Wait()
// Stop the server (draining the handlers)
server.Stop()

assert.Equal(t, 0, manager.Subscribers(), "there should be no more subscribers")

Expand Down
2 changes: 1 addition & 1 deletion pkg/server/api/agent/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func New(config Config) *Service {
}

// RegisterService registers the agent service on the gRPC server/
func RegisterService(s *grpc.Server, service *Service) {
func RegisterService(s grpc.ServiceRegistrar, service *Service) {
agentv1.RegisterAgentServer(s, service)
}

Expand Down
31 changes: 14 additions & 17 deletions pkg/server/api/agent/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/spiffe/spire/test/fakes/fakeserverca"
"github.com/spiffe/spire/test/fakes/fakeservercatalog"
"github.com/spiffe/spire/test/fakes/fakeservernodeattestor"
"github.com/spiffe/spire/test/grpctest"
"github.com/spiffe/spire/test/spiretest"
"github.com/spiffe/spire/test/testkey"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -3203,9 +3204,6 @@ func setupServiceTest(t *testing.T, agentSVIDTTL time.Duration) *serviceTest {

log, logHook := test.NewNullLogger()
log.Level = logrus.DebugLevel
registerFn := func(s *grpc.Server) {
agent.RegisterService(s, service)
}

rateLimiter := &fakeRateLimiter{}

Expand All @@ -3218,27 +3216,26 @@ func setupServiceTest(t *testing.T, agentSVIDTTL time.Duration) *serviceTest {
rateLimiter: rateLimiter,
}

ppMiddleware := middleware.Preprocess(func(ctx context.Context, fullMethod string, req any) (context.Context, error) {
overrideContext := func(ctx context.Context) context.Context {
ctx = rpccontext.WithLogger(ctx, log)
ctx = rpccontext.WithRateLimiter(ctx, rateLimiter)
if test.withCallerID {
ctx = rpccontext.WithCallerID(ctx, agentID)
}
return ctx, nil
})
unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain(
ppMiddleware,
// Add audit log with local tracking disabled
middleware.WithAuditLog(false),
))

server := grpc.NewServer(
grpc.UnaryInterceptor(unaryInterceptor),
grpc.StreamInterceptor(streamInterceptor),
return ctx
}

server := grpctest.StartServer(t, func(s grpc.ServiceRegistrar) {
agent.RegisterService(s, service)
},
grpctest.OverrideContext(overrideContext),
grpctest.Middleware(middleware.WithAuditLog(false)),
)
conn, done := spiretest.NewAPIServerWithMiddleware(t, registerFn, server)
test.done = done

conn := server.Dial(t)

test.client = agentv1.NewAgentClient(conn)
test.done = server.Stop

return test
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/api/bundle/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func New(config Config) *Service {
}

// RegisterService registers the bundle service on the gRPC server.
func RegisterService(s *grpc.Server, service *Service) {
func RegisterService(s grpc.ServiceRegistrar, service *Service) {
bundlev1.RegisterBundleServer(s, service)
}

Expand Down
29 changes: 13 additions & 16 deletions pkg/server/api/bundle/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/spiffe/spire/pkg/server/datastore"
"github.com/spiffe/spire/proto/spire/common"
"github.com/spiffe/spire/test/fakes/fakedatastore"
"github.com/spiffe/spire/test/grpctest"
"github.com/spiffe/spire/test/spiretest"
"github.com/spiffe/spire/test/testca"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -2905,9 +2906,6 @@ func setupServiceTest(t *testing.T) *serviceTest {

log, logHook := test.NewNullLogger()
log.Level = logrus.DebugLevel
registerFn := func(s *grpc.Server) {
bundle.RegisterService(s, service)
}

test := &serviceTest{
ds: ds,
Expand All @@ -2916,7 +2914,7 @@ func setupServiceTest(t *testing.T) *serviceTest {
rateLimiter: rateLimiter,
}

ppMiddleware := middleware.Preprocess(func(ctx context.Context, fullMethod string, req any) (context.Context, error) {
overrideContext := func(ctx context.Context) context.Context {
ctx = rpccontext.WithLogger(ctx, log)
if test.isAdmin {
ctx = rpccontext.WithAdminCaller(ctx)
Expand All @@ -2932,21 +2930,20 @@ func setupServiceTest(t *testing.T) *serviceTest {
}

ctx = rpccontext.WithRateLimiter(ctx, rateLimiter)
return ctx, nil
})
return ctx
}

unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain(
ppMiddleware,
// Add audit log with local tracking disabled
middleware.WithAuditLog(false),
))
server := grpc.NewServer(
grpc.UnaryInterceptor(unaryInterceptor),
grpc.StreamInterceptor(streamInterceptor),
server := grpctest.StartServer(t, func(s grpc.ServiceRegistrar) {
bundle.RegisterService(s, service)
},
grpctest.OverrideContext(overrideContext),
grpctest.Middleware(middleware.WithAuditLog(false)),
)
conn, done := spiretest.NewAPIServerWithMiddleware(t, registerFn, server)
test.done = done

conn := server.Dial(t)

test.client = bundlev1.NewBundleClient(conn)
test.done = server.Stop

return test
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/api/debug/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const (
)

// RegisterService registers debug service on provided server
func RegisterService(s *grpc.Server, service *Service) {
func RegisterService(s grpc.ServiceRegistrar, service *Service) {
debugv1.RegisterDebugServer(s, service)
}

Expand Down
12 changes: 8 additions & 4 deletions pkg/server/api/debug/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/spiffe/spire/pkg/server/svid"
"github.com/spiffe/spire/proto/spire/common"
"github.com/spiffe/spire/test/fakes/fakedatastore"
"github.com/spiffe/spire/test/grpctest"
"github.com/spiffe/spire/test/spiretest"
"github.com/spiffe/spire/test/testca"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -464,16 +465,19 @@ func setupServiceTest(t *testing.T) *serviceTest {
uptime: fakeUptime,
}

registerFn := func(s *grpc.Server) {
registerFn := func(s grpc.ServiceRegistrar) {
debug.RegisterService(s, service)
}
contextFn := func(ctx context.Context) context.Context {
overrideContext := func(ctx context.Context) context.Context {
ctx = rpccontext.WithLogger(ctx, log)
return ctx
}

conn, done := spiretest.NewAPIServer(t, registerFn, contextFn)
test.done = done
server := grpctest.StartServer(t, registerFn, grpctest.OverrideContext(overrideContext))

conn := server.Dial(t)

test.done = server.Stop
test.client = debugv1.NewDebugClient(conn)

return test
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/api/entry/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func New(config Config) *Service {
}

// RegisterService registers the entry service on the gRPC server.
func RegisterService(s *grpc.Server, service *Service) {
func RegisterService(s grpc.ServiceRegistrar, service *Service) {
entryv1.RegisterEntryServer(s, service)
}

Expand Down
35 changes: 12 additions & 23 deletions pkg/server/api/entry/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/spiffe/spire/pkg/server/datastore"
"github.com/spiffe/spire/proto/spire/common"
"github.com/spiffe/spire/test/fakes/fakedatastore"
"github.com/spiffe/spire/test/grpctest"
"github.com/spiffe/spire/test/spiretest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -4751,43 +4752,31 @@ func setupServiceTest(t *testing.T, ds datastore.DataStore, options ...serviceTe
})

log, logHook := test.NewNullLogger()
registerFn := func(s *grpc.Server) {
entry.RegisterService(s, service)
}

test := &serviceTest{
ds: ds,
logHook: logHook,
ef: ef,
}

ppMiddleware := middleware.Preprocess(func(ctx context.Context, fullMethod string, req any) (context.Context, error) {
overrideContext := func(ctx context.Context) context.Context {
ctx = rpccontext.WithLogger(ctx, log)
if !test.omitCallerID {
ctx = rpccontext.WithCallerID(ctx, agentID)
}
return ctx, nil
})
return ctx
}

drainHandler := spiretest.NewDrainHandlerMiddleware()

unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain(
drainHandler,
ppMiddleware,
// Add audit log with local tracking disabled
middleware.WithAuditLog(false),
))
server := grpc.NewServer(
grpc.UnaryInterceptor(unaryInterceptor),
grpc.StreamInterceptor(streamInterceptor),
server := grpctest.StartServer(t, func(s grpc.ServiceRegistrar) {
entry.RegisterService(s, service)
},
grpctest.OverrideContext(overrideContext),
grpctest.Middleware(middleware.WithAuditLog(false)),
)

conn, done := spiretest.NewAPIServerWithMiddleware(t, registerFn, server)
test.done = func() {
done()
drainHandler.Wait()
}
conn := server.Dial(t)

test.client = entryv1.NewEntryClient(conn)
test.done = server.Stop

return test
}
Expand Down
Loading

0 comments on commit b23550a

Please sign in to comment.