Skip to content

Commit

Permalink
Use GRPC interceptors instead of explicit context wrappers (#6133)
Browse files Browse the repository at this point in the history
## Which problem is this PR solving?
- Resolves #6035 

## Description of the changes
- 

## How was this change tested?
- 

## Checklist
- [ ] I have read
https://github.com/jaegertracing/jaeger/blob/master/CONTRIBUTING_GUIDELINES.md
- [ ] I have signed all commits
- [ ] I have added unit tests for the new functionality
- [ ] I have run lint and test steps successfully
  - for `jaeger`: `make lint test`
  - for `jaeger-ui`: `yarn lint` and `yarn test`

---------

Signed-off-by: chahatsagarmain <chahatsagar2003@gmail.com>
  • Loading branch information
chahatsagarmain authored Oct 29, 2024
1 parent 3c1e85d commit 0af8d35
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 67 deletions.
20 changes: 15 additions & 5 deletions cmd/query/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,23 @@ func createGRPCServerOTEL(
telset telemetery.Setting,
) (*grpc.Server, error) {
var grpcOpts []configgrpc.ToServerOption
unaryInterceptors := []grpc.UnaryServerInterceptor{
bearertoken.NewUnaryServerInterceptor(),
}
streamInterceptors := []grpc.StreamServerInterceptor{
bearertoken.NewStreamServerInterceptor(),
}

//nolint:contextcheck
if tm.Enabled {
//nolint:contextcheck
grpcOpts = append(grpcOpts,
configgrpc.WithGrpcServerOption(grpc.StreamInterceptor(tenancy.NewGuardingStreamInterceptor(tm))),
configgrpc.WithGrpcServerOption(grpc.UnaryInterceptor(tenancy.NewGuardingUnaryInterceptor(tm))),
)
unaryInterceptors = append(unaryInterceptors, tenancy.NewGuardingUnaryInterceptor(tm))
streamInterceptors = append(streamInterceptors, tenancy.NewGuardingStreamInterceptor(tm))
}

grpcOpts = append(grpcOpts,
configgrpc.WithGrpcServerOption(grpc.ChainUnaryInterceptor(unaryInterceptors...)),
configgrpc.WithGrpcServerOption(grpc.ChainStreamInterceptor(streamInterceptors...)),
)
return options.GRPC.ToServer(
ctx,
telset.Host,
Expand Down
2 changes: 1 addition & 1 deletion plugin/metrics/prometheus/metricsstore/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ func getHTTPRoundTripper(c *config.Configuration, logger *zap.Logger) (rt http.R
}
token = tokenFromFile
}
return bearertoken.RoundTripper{
return &bearertoken.RoundTripper{
Transport: httpTransport,
OverrideFromCtx: c.TokenOverrideFromContext,
StaticToken: token,
Expand Down
4 changes: 2 additions & 2 deletions plugin/storage/grpc/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ func (f *Factory) newRemoteStorage(
if c.Auth != nil {
return nil, fmt.Errorf("authenticator is not supported")
}
tenancyMgr := tenancy.NewManager(&c.Tenancy)
unaryInterceptors := []grpc.UnaryClientInterceptor{
bearertoken.NewUnaryClientInterceptor(),
}
streamInterceptors := []grpc.StreamClientInterceptor{
tenancy.NewClientStreamInterceptor(tenancyMgr),
bearertoken.NewStreamClientInterceptor(),
}
tenancyMgr := tenancy.NewManager(&c.Tenancy)
if tenancyMgr.Enabled {
unaryInterceptors = append(unaryInterceptors, tenancy.NewClientUnaryInterceptor(tenancyMgr))
streamInterceptors = append(streamInterceptors, tenancy.NewClientStreamInterceptor(tenancyMgr))
Expand Down
2 changes: 1 addition & 1 deletion plugin/storage/grpc/shared/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type archiveWriter struct {

// GetTrace takes a traceID and returns a Trace associated with that traceID from Archive Storage
func (r *archiveReader) GetTrace(ctx context.Context, traceID model.TraceID) (*model.Trace, error) {
stream, err := r.client.GetArchiveTrace(upgradeContext(ctx), &storage_v1.GetTraceRequest{
stream, err := r.client.GetArchiveTrace(ctx, &storage_v1.GetTraceRequest{
TraceID: traceID,
})
if status.Code(err) == codes.NotFound {
Expand Down
45 changes: 5 additions & 40 deletions plugin/storage/grpc/shared/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ import (

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/jaegertracing/jaeger/model"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
_ "github.com/jaegertracing/jaeger/pkg/gogocodec" // force gogo codec registration
"github.com/jaegertracing/jaeger/proto-gen/storage_v1"
"github.com/jaegertracing/jaeger/storage/dependencystore"
Expand All @@ -30,9 +28,6 @@ var (
_ StoragePlugin = (*GRPCClient)(nil)
_ ArchiveStoragePlugin = (*GRPCClient)(nil)
_ PluginCapabilities = (*GRPCClient)(nil)

// upgradeContext composites several steps of upgrading context
upgradeContext = composeContextUpgradeFuncs(upgradeContextWithBearerToken)
)

// GRPCClient implements shared.StoragePlugin and reads/writes spans and dependencies
Expand All @@ -58,36 +53,6 @@ func NewGRPCClient(tracedConn *grpc.ClientConn, untracedConn *grpc.ClientConn) *
}
}

// ContextUpgradeFunc is a functional type that can be composed to upgrade context
type ContextUpgradeFunc func(ctx context.Context) context.Context

// composeContextUpgradeFuncs composes ContextUpgradeFunc and returns a composed function
// to run the given func in strict order.
func composeContextUpgradeFuncs(funcs ...ContextUpgradeFunc) ContextUpgradeFunc {
return func(ctx context.Context) context.Context {
for _, fun := range funcs {
ctx = fun(ctx)
}
return ctx
}
}

// upgradeContextWithBearerToken turns the context into a gRPC outgoing context with bearer token
// in the request metadata, if the original context has bearer token attached.
// Otherwise returns original context.
func upgradeContextWithBearerToken(ctx context.Context) context.Context {
bearerToken, hasToken := bearertoken.GetBearerToken(ctx)
if hasToken {
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
md = metadata.New(nil)
}
md.Set(BearerTokenKey, bearerToken)
return metadata.NewOutgoingContext(ctx, md)
}
return ctx
}

// DependencyReader implements shared.StoragePlugin.
func (c *GRPCClient) DependencyReader() dependencystore.Reader {
return c
Expand Down Expand Up @@ -117,7 +82,7 @@ func (c *GRPCClient) ArchiveSpanWriter() spanstore.Writer {

// GetTrace takes a traceID and returns a Trace associated with that traceID
func (c *GRPCClient) GetTrace(ctx context.Context, traceID model.TraceID) (*model.Trace, error) {
stream, err := c.readerClient.GetTrace(upgradeContext(ctx), &storage_v1.GetTraceRequest{
stream, err := c.readerClient.GetTrace(ctx, &storage_v1.GetTraceRequest{
TraceID: traceID,
})
if status.Code(err) == codes.NotFound {
Expand All @@ -132,7 +97,7 @@ func (c *GRPCClient) GetTrace(ctx context.Context, traceID model.TraceID) (*mode

// GetServices returns a list of all known services
func (c *GRPCClient) GetServices(ctx context.Context) ([]string, error) {
resp, err := c.readerClient.GetServices(upgradeContext(ctx), &storage_v1.GetServicesRequest{})
resp, err := c.readerClient.GetServices(ctx, &storage_v1.GetServicesRequest{})
if err != nil {
return nil, fmt.Errorf("plugin error: %w", err)
}
Expand All @@ -145,7 +110,7 @@ func (c *GRPCClient) GetOperations(
ctx context.Context,
query spanstore.OperationQueryParameters,
) ([]spanstore.Operation, error) {
resp, err := c.readerClient.GetOperations(upgradeContext(ctx), &storage_v1.GetOperationsRequest{
resp, err := c.readerClient.GetOperations(ctx, &storage_v1.GetOperationsRequest{
Service: query.ServiceName,
SpanKind: query.SpanKind,
})
Expand Down Expand Up @@ -173,7 +138,7 @@ func (c *GRPCClient) GetOperations(

// FindTraces retrieves traces that match the traceQuery
func (c *GRPCClient) FindTraces(ctx context.Context, query *spanstore.TraceQueryParameters) ([]*model.Trace, error) {
stream, err := c.readerClient.FindTraces(upgradeContext(ctx), &storage_v1.FindTracesRequest{
stream, err := c.readerClient.FindTraces(ctx, &storage_v1.FindTracesRequest{
Query: &storage_v1.TraceQueryParameters{
ServiceName: query.ServiceName,
OperationName: query.OperationName,
Expand Down Expand Up @@ -212,7 +177,7 @@ func (c *GRPCClient) FindTraces(ctx context.Context, query *spanstore.TraceQuery

// FindTraceIDs retrieves traceIDs that match the traceQuery
func (c *GRPCClient) FindTraceIDs(ctx context.Context, query *spanstore.TraceQueryParameters) ([]model.TraceID, error) {
resp, err := c.readerClient.FindTraceIDs(upgradeContext(ctx), &storage_v1.FindTraceIDsRequest{
resp, err := c.readerClient.FindTraceIDs(ctx, &storage_v1.FindTraceIDsRequest{
Query: &storage_v1.TraceQueryParameters{
ServiceName: query.ServiceName,
OperationName: query.OperationName,
Expand Down
18 changes: 0 additions & 18 deletions plugin/storage/grpc/shared/grpc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/jaegertracing/jaeger/model"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/proto-gen/storage_v1"
grpcMocks "github.com/jaegertracing/jaeger/proto-gen/storage_v1/mocks"
"github.com/jaegertracing/jaeger/storage/spanstore"
Expand Down Expand Up @@ -116,22 +114,6 @@ func TestNewGRPCClient(t *testing.T) {
assert.Implements(t, (*storage_v1.StreamingSpanWriterPluginClient)(nil), client.streamWriterClient)
}

func TestContextUpgradeWithToken(t *testing.T) {
testBearerToken := "test-bearer-token"
ctx := bearertoken.ContextWithBearerToken(context.Background(), testBearerToken)
upgradedToken := upgradeContextWithBearerToken(ctx)
md, ok := metadata.FromOutgoingContext(upgradedToken)
assert.Truef(t, ok, "Expected metadata in context")
bearerTokenFromMetadata := md.Get(BearerTokenKey)
assert.Equal(t, []string{testBearerToken}, bearerTokenFromMetadata)
}

func TestContextUpgradeWithoutToken(t *testing.T) {
upgradedToken := upgradeContextWithBearerToken(context.Background())
_, ok := metadata.FromOutgoingContext(upgradedToken)
assert.Falsef(t, ok, "Expected no metadata in context")
}

func TestGRPCClientGetServices(t *testing.T) {
withGRPCClient(func(r *grpcClientTest) {
r.spanReader.On("GetServices", mock.Anything, &storage_v1.GetServicesRequest{}).
Expand Down

0 comments on commit 0af8d35

Please sign in to comment.