diff --git a/client/v3/client.go b/client/v3/client.go index 10c00f612fa1..4c6f75f3699e 100644 --- a/client/v3/client.go +++ b/client/v3/client.go @@ -69,7 +69,7 @@ type Client struct { Username string // Password is a password for authentication. Password string - authTokenBundle credentials.Bundle + authTokenBundle credentials.PerRPCCredentialsBundle callOpts []grpc.CallOption @@ -338,7 +338,7 @@ func (c *Client) credentialsForEndpoint(ep string) grpccredentials.TransportCred if c.creds != nil { return c.creds } - return credentials.NewBundle(credentials.Config{}).TransportCredentials() + return credentials.NewTransportCredential(nil) default: panic(fmt.Errorf("unsupported CredsRequirement: %v", r)) } @@ -350,7 +350,7 @@ func newClient(cfg *Config) (*Client, error) { } var creds grpccredentials.TransportCredentials if cfg.TLS != nil { - creds = credentials.NewBundle(credentials.Config{TLSConfig: cfg.TLS}).TransportCredentials() + creds = credentials.NewTransportCredential(cfg.TLS) } // use a temporary skeleton client to bootstrap first connection @@ -389,7 +389,7 @@ func newClient(cfg *Config) (*Client, error) { if cfg.Username != "" && cfg.Password != "" { client.Username = cfg.Username client.Password = cfg.Password - client.authTokenBundle = credentials.NewBundle(credentials.Config{}) + client.authTokenBundle = credentials.NewPerRPCCredentialBundle() } if cfg.MaxCallSendMsgSize > 0 || cfg.MaxCallRecvMsgSize > 0 { if cfg.MaxCallRecvMsgSize > 0 && cfg.MaxCallSendMsgSize > cfg.MaxCallRecvMsgSize { diff --git a/client/v3/credentials/credentials.go b/client/v3/credentials/credentials.go index 024c16b6048d..d98ce9ff1c97 100644 --- a/client/v3/credentials/credentials.go +++ b/client/v3/credentials/credentials.go @@ -19,7 +19,6 @@ package credentials import ( "context" "crypto/tls" - "net" "sync" grpccredentials "google.golang.org/grpc/credentials" @@ -27,85 +26,43 @@ import ( "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" ) -// Config defines gRPC credential configuration. -type Config struct { - TLSConfig *tls.Config +func NewTransportCredential(cfg *tls.Config) grpccredentials.TransportCredentials { + return grpccredentials.NewTLS(cfg) } -// Bundle defines gRPC credential interface. -type Bundle interface { - grpccredentials.Bundle +// PerRPCCredentialsBundle defines gRPC credential interface. +type PerRPCCredentialsBundle interface { UpdateAuthToken(token string) + PerRPCCredentials() grpccredentials.PerRPCCredentials } -// NewBundle constructs a new gRPC credential bundle. -func NewBundle(cfg Config) Bundle { - return &bundle{ - tc: newTransportCredential(cfg.TLSConfig), - rc: newPerRPCCredential(), +func NewPerRPCCredentialBundle() PerRPCCredentialsBundle { + return &perRPCCredentialBundle{ + rc: &perRPCCredential{}, } } -// bundle implements "grpccredentials.Bundle" interface. -type bundle struct { - tc *transportCredential +type perRPCCredentialBundle struct { rc *perRPCCredential } -func (b *bundle) TransportCredentials() grpccredentials.TransportCredentials { - return b.tc -} - -func (b *bundle) PerRPCCredentials() grpccredentials.PerRPCCredentials { - return b.rc -} - -func (b *bundle) NewWithMode(mode string) (grpccredentials.Bundle, error) { - // no-op - return nil, nil -} - -// transportCredential implements "grpccredentials.TransportCredentials" interface. -type transportCredential struct { - gtc grpccredentials.TransportCredentials -} - -func newTransportCredential(cfg *tls.Config) *transportCredential { - return &transportCredential{ - gtc: grpccredentials.NewTLS(cfg), - } -} - -func (tc *transportCredential) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) { - return tc.gtc.ClientHandshake(ctx, authority, rawConn) -} - -func (tc *transportCredential) ServerHandshake(rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) { - return tc.gtc.ServerHandshake(rawConn) -} - -func (tc *transportCredential) Info() grpccredentials.ProtocolInfo { - return tc.gtc.Info() -} - -func (tc *transportCredential) Clone() grpccredentials.TransportCredentials { - return &transportCredential{ - gtc: tc.gtc.Clone(), +func (b *perRPCCredentialBundle) UpdateAuthToken(token string) { + if b.rc == nil { + return } + b.rc.UpdateAuthToken(token) } -func (tc *transportCredential) OverrideServerName(serverNameOverride string) error { - return tc.gtc.OverrideServerName(serverNameOverride) +func (b *perRPCCredentialBundle) PerRPCCredentials() grpccredentials.PerRPCCredentials { + return b.rc } -// perRPCCredential implements "grpccredentials.PerRPCCredentials" interface. +// perRPCCredential implements `PerRPCCredentialsWrapper` interface. type perRPCCredential struct { authToken string authTokenMu sync.RWMutex } -func newPerRPCCredential() *perRPCCredential { return &perRPCCredential{} } - func (rc *perRPCCredential) RequireTransportSecurity() bool { return false } func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) { @@ -118,13 +75,6 @@ func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) return map[string]string{rpctypes.TokenFieldNameGRPC: authToken}, nil } -func (b *bundle) UpdateAuthToken(token string) { - if b.rc == nil { - return - } - b.rc.UpdateAuthToken(token) -} - func (rc *perRPCCredential) UpdateAuthToken(token string) { rc.authTokenMu.Lock() rc.authToken = token diff --git a/client/v3/credentials/credentials_test.go b/client/v3/credentials/credentials_test.go index 5111a2ad5ece..dcf06cf15685 100644 --- a/client/v3/credentials/credentials_test.go +++ b/client/v3/credentials/credentials_test.go @@ -24,7 +24,7 @@ import ( ) func TestUpdateAuthToken(t *testing.T) { - bundle := NewBundle(Config{}) + bundle := NewPerRPCCredentialBundle() ctx := context.TODO() metadataBeforeUpdate, _ := bundle.PerRPCCredentials().GetRequestMetadata(ctx) diff --git a/server/embed/etcd.go b/server/embed/etcd.go index 594e11ec385c..f8563dadeb6d 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -797,8 +797,7 @@ func (e *Etcd) grpcGatewayDial(splitHttp bool) (grpcDial func(ctx context.Contex dtls := tlscfg.Clone() // trust local server dtls.InsecureSkipVerify = true - bundle := credentials.NewBundle(credentials.Config{TLSConfig: dtls}) - opts = append(opts, grpc.WithTransportCredentials(bundle.TransportCredentials())) + opts = append(opts, grpc.WithTransportCredentials(credentials.NewTransportCredential(dtls))) } else { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } diff --git a/server/etcdserver/api/v3rpc/grpc.go b/server/etcdserver/api/v3rpc/grpc.go index 349ebea40074..148914e0f9a9 100644 --- a/server/etcdserver/api/v3rpc/grpc.go +++ b/server/etcdserver/api/v3rpc/grpc.go @@ -39,8 +39,7 @@ func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnarySer var opts []grpc.ServerOption opts = append(opts, grpc.CustomCodec(&codec{})) if tls != nil { - bundle := credentials.NewBundle(credentials.Config{TLSConfig: tls}) - opts = append(opts, grpc.Creds(bundle.TransportCredentials())) + opts = append(opts, grpc.Creds(credentials.NewTransportCredential(tls))) } chainUnaryInterceptors := []grpc.UnaryServerInterceptor{ newLogUnaryInterceptor(s),