diff --git a/transport/grpc/dial.go b/transport/grpc/dial.go index cb7fa5aef98..2f6359f2921 100644 --- a/transport/grpc/dial.go +++ b/transport/grpc/dial.go @@ -53,6 +53,9 @@ var logRateLimiter = rate.Sometimes{Interval: 1 * time.Second} // Assign to var for unit test replacement var dialContext = grpc.DialContext +// Assign to var for unit test replacement +var dialContextNewAuth = grpctransport.Dial + // otelStatsHandler is a singleton otelgrpc.clientHandler to be used across // all dial connections to avoid the memory leak documented in // https://github.com/open-telemetry/opentelemetry-go-contrib/issues/4226 @@ -218,12 +221,12 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna defaultEndpointTemplate = ds.DefaultEndpoint } - pool, err := grpctransport.Dial(ctx, secure, &grpctransport.Options{ + pool, err := dialContextNewAuth(ctx, secure, &grpctransport.Options{ DisableTelemetry: ds.TelemetryDisabled, DisableAuthentication: ds.NoAuth, Endpoint: ds.Endpoint, Metadata: metadata, - GRPCDialOpts: ds.GRPCDialOpts, + GRPCDialOpts: prepareDialOptsNewAuth(ds), PoolSize: poolSize, Credentials: creds, APIKey: ds.APIKey, @@ -248,6 +251,15 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna return pool, err } +func prepareDialOptsNewAuth(ds *internal.DialSettings) []grpc.DialOption { + var opts []grpc.DialOption + if ds.UserAgent != "" { + opts = append(opts, grpc.WithUserAgent(ds.UserAgent)) + } + + return append(opts, ds.GRPCDialOpts...) +} + func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.ClientConn, error) { if o.HTTPClient != nil { return nil, errors.New("unsupported HTTP client specified") diff --git a/transport/grpc/dial_test.go b/transport/grpc/dial_test.go index b767854d220..bf19d02cd37 100644 --- a/transport/grpc/dial_test.go +++ b/transport/grpc/dial_test.go @@ -11,6 +11,7 @@ import ( "strings" "testing" + "cloud.google.com/go/auth/grpctransport" "cloud.google.com/go/compute/metadata" "github.com/google/go-cmp/cmp" "golang.org/x/oauth2/google" @@ -35,6 +36,73 @@ func TestDial(t *testing.T) { dial(context.Background(), false, &o) } +func TestDialPoolNewAuthDialOptions(t *testing.T) { + oldDialContextNewAuth := dialContextNewAuth + var wantNumOpts int + // Replace package var in order to assert DialContext args. + dialContextNewAuth = func(ctx context.Context, secure bool, opts *grpctransport.Options) (grpctransport.GRPCClientConnPool, error) { + if len(opts.GRPCDialOpts) != wantNumOpts { + t.Fatalf("got: %d, want: %d", len(opts.GRPCDialOpts), wantNumOpts) + } + return nil, nil + } + defer func() { + dialContextNewAuth = oldDialContextNewAuth + }() + + for _, testcase := range []struct { + name string + ds *internal.DialSettings + wantNumOpts int + }{ + { + name: "no dial options", + ds: &internal.DialSettings{}, + wantNumOpts: 0, + }, + { + name: "with user agent", + ds: &internal.DialSettings{ + UserAgent: "test", + }, + wantNumOpts: 1, + }, + } { + t.Run(testcase.name, func(t *testing.T) { + wantNumOpts = testcase.wantNumOpts + dialPoolNewAuth(context.Background(), false, 1, testcase.ds) + }) + } +} + +func TestPrepareDialOptsNewAuth(t *testing.T) { + for _, testcase := range []struct { + name string + ds *internal.DialSettings + wantNumOpts int + }{ + { + name: "empty", + ds: &internal.DialSettings{}, + wantNumOpts: 0, + }, + { + name: "user agent", + ds: &internal.DialSettings{ + UserAgent: "test", + }, + wantNumOpts: 1, + }, + } { + t.Run(testcase.name, func(t *testing.T) { + got := prepareDialOptsNewAuth(testcase.ds) + if len(got) != testcase.wantNumOpts { + t.Fatalf("got %d options, want %d options", len(got), testcase.wantNumOpts) + } + }) + } +} + func TestCheckDirectPathEndPoint(t *testing.T) { for _, testcase := range []struct { name string