Skip to content

Commit

Permalink
Add S2A Public APIs (matthewstevenson88#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryanfsdf authored Jul 15, 2020
1 parent c6a7197 commit b1ce5d4
Show file tree
Hide file tree
Showing 4 changed files with 738 additions and 0 deletions.
237 changes: 237 additions & 0 deletions security/s2a/s2a.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
package s2a

import (
"context"
"errors"
"fmt"
"net"
"time"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/security/s2a/internal/handshaker"
"google.golang.org/grpc/security/s2a/internal/handshaker/service"
s2apb "google.golang.org/grpc/security/s2a/internal/proto"
)

const (
s2aSecurityProtocol = "s2a"
// defaultTimeout specifies the default server handshake timeout.
defaultTimeout = 30.0 * time.Second
)

// s2aTransportCreds are the credentials required for establishing a secure
// connection using the S2A handshaker service. It implements the
// credentials.TransportCredentials interface.
type s2aTransportCreds struct {
info *credentials.ProtocolInfo
minTLSVersion s2apb.TLSVersion
maxTLSVersion s2apb.TLSVersion
// tlsCiphersuites contains the ciphersuites used in the S2A connection.
// Note that these are currently unconfigurable.
tlsCiphersuites []s2apb.Ciphersuite
// localIdentity should only be used by the client.
localIdentity *s2apb.Identity
// localIdentities should only be used by the server.
localIdentities []*s2apb.Identity
// targetIdentities should only be used by the client.
targetIdentities []*s2apb.Identity
isClient bool
hsAddr string
}

// NewClientCreds returns a client-side transport credentials object that uses
// the S2A handshaker service to establish a secure connection with a server.
func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
if opts == nil {
return nil, errors.New("nil client options")
}
var targetIdentities []*s2apb.Identity
for _, targetIdentity := range opts.TargetIdentities {
protoTargetIdentity, err := toProtoIdentity(targetIdentity)
if err != nil {
return nil, err
}
targetIdentities = append(targetIdentities, protoTargetIdentity)
}
localIdentity, err := toProtoIdentity(opts.LocalIdentity)
if err != nil {
return nil, err
}
return &s2aTransportCreds{
info: &credentials.ProtocolInfo{
SecurityProtocol: s2aSecurityProtocol,
},
minTLSVersion: s2apb.TLSVersion_TLS1_3,
maxTLSVersion: s2apb.TLSVersion_TLS1_3,
tlsCiphersuites: []s2apb.Ciphersuite{
s2apb.Ciphersuite_AES_128_GCM_SHA256,
s2apb.Ciphersuite_AES_256_GCM_SHA384,
s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256,
},
localIdentity: localIdentity,
targetIdentities: targetIdentities,
isClient: true,
hsAddr: opts.HandshakerServiceAddress,
}, nil
}

// NewServerCreds returns a server-side transport credentials object that uses
// the S2A handshaker service to establish a secure connection with a client.
func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
if opts == nil {
return nil, errors.New("nil server options")
}
var localIdentities []*s2apb.Identity
for _, localIdentity := range opts.LocalIdentities {
protoLocalIdentity, err := toProtoIdentity(localIdentity)
if err != nil {
return nil, err
}
localIdentities = append(localIdentities, protoLocalIdentity)
}
return &s2aTransportCreds{
info: &credentials.ProtocolInfo{
SecurityProtocol: s2aSecurityProtocol,
},
minTLSVersion: s2apb.TLSVersion_TLS1_3,
maxTLSVersion: s2apb.TLSVersion_TLS1_3,
tlsCiphersuites: []s2apb.Ciphersuite{
s2apb.Ciphersuite_AES_128_GCM_SHA256,
s2apb.Ciphersuite_AES_256_GCM_SHA384,
s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256,
},
localIdentities: localIdentities,
isClient: false,
hsAddr: opts.HandshakerServiceAddress,
}, nil
}

// ClientHandshake performs a client-side TLS handshake using the S2A handshaker
// service.
func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAddr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
if !c.isClient {
return nil, nil, errors.New("client handshake called using server transport credentials")
}

// Connect to the S2A handshaker service.
hsConn, err := service.Dial(c.hsAddr)
if err != nil {
return nil, nil, err
}

var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()

opts := &handshaker.ClientHandshakerOptions{
MinTLSVersion: c.minTLSVersion,
MaxTLSVersion: c.maxTLSVersion,
TLSCiphersuites: c.tlsCiphersuites,
TargetIdentities: c.targetIdentities,
LocalIdentity: c.localIdentity,
TargetName: serverAddr,
}
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.hsAddr, opts)
if err != nil {
return nil, nil, err
}
defer func() {
if err != nil {
if closeErr := chs.Close(); closeErr != nil {
err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
}
}
}()

secConn, authInfo, err := chs.ClientHandshake(context.Background())
if err != nil {
return nil, nil, err
}
return secConn, authInfo, nil
}

// ServerHandshake performs a server-side TLS handshake using the S2A handshaker
// service.
func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if c.isClient {
return nil, nil, errors.New("server handshake called using client transport credentials")
}

// Connect to the S2A handshaker service.
hsConn, err := service.Dial(c.hsAddr)
if err != nil {
return nil, nil, err
}

ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()

opts := &handshaker.ServerHandshakerOptions{
MinTLSVersion: c.minTLSVersion,
MaxTLSVersion: c.maxTLSVersion,
TLSCiphersuites: c.tlsCiphersuites,
LocalIdentities: c.localIdentities,
}
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.hsAddr, opts)
if err != nil {
return nil, nil, err
}
defer func() {
if err != nil {
if closeErr := shs.Close(); closeErr != nil {
err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
}
}
}()

secConn, authInfo, err := shs.ServerHandshake(context.Background())
if err != nil {
return nil, nil, err
}
return secConn, authInfo, nil
}

func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
return *c.info
}

func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
info := *c.info
var localIdentity *s2apb.Identity
if c.localIdentity != nil {
v := *c.localIdentity
localIdentity = &v
}
var localIdentities []*s2apb.Identity
if c.localIdentities != nil {
localIdentities = make([]*s2apb.Identity, len(c.localIdentities))
for i, localIdentity := range c.localIdentities {
v := *localIdentity
localIdentities[i] = &v
}
}
var targetIdentities []*s2apb.Identity
if c.targetIdentities != nil {
targetIdentities = make([]*s2apb.Identity, len(c.targetIdentities))
for i, targetIdentity := range c.targetIdentities {
v := *targetIdentity
targetIdentities[i] = &v
}
}
return &s2aTransportCreds{
info: &info,
minTLSVersion: c.minTLSVersion,
maxTLSVersion: c.maxTLSVersion,
tlsCiphersuites: c.tlsCiphersuites,
localIdentity: localIdentity,
localIdentities: localIdentities,
targetIdentities: targetIdentities,
isClient: c.isClient,
hsAddr: c.hsAddr,
}
}

func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
c.info.ServerName = serverNameOverride
return nil
}
79 changes: 79 additions & 0 deletions security/s2a/s2a_options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package s2a

import (
"errors"

s2apb "google.golang.org/grpc/security/s2a/internal/proto"
)

// Identity is the interface for S2A identities.
type Identity interface {
// Name returns the name of the identity.
Name() string
}

type spiffeID struct {
spiffeID string
}

func (s *spiffeID) Name() string { return s.spiffeID }

func NewSpiffeID(id string) Identity {
return &spiffeID{spiffeID: id}
}

type hostname struct {
hostname string
}

func (h *hostname) Name() string { return h.hostname }

func NewHostname(name string) Identity {
return &hostname{hostname: name}
}

// ClientOptions contains the client-side options used to establish a secure
// channel using the S2A handshaker service.
type ClientOptions struct {
// TargetIdentities contains a list of allowed server identities. One of the
// target identities should match the peer identity in the handshake
// result; otherwise, the handshake fails.
TargetIdentities []Identity
// LocalIdentity is the local identity of the client application. If none is
// provided, then the S2A will choose the default identity.
LocalIdentity Identity
// HandshakerServiceAddress is the address of the S2A handshaker service.
HandshakerServiceAddress string
}

// DefaultClientOptions returns the default client options.
func DefaultClientOptions(handshakerAddress string) *ClientOptions {
return &ClientOptions{HandshakerServiceAddress: handshakerAddress}
}

// ServerOptions contains the server-side options used to establish a secure
// channel using the S2A handshaker service.
type ServerOptions struct {
// LocalIdentities is the list of local identities that may be assumed by
// the server. If no local identity is specified, then the S2A chooses a
// default local identity.
LocalIdentities []Identity
// HandshakerServiceAddress is the address of the S2A handshaker service.
HandshakerServiceAddress string
}

// DefaultServerOptions returns the default server options.
func DefaultServerOptions(handshakerAddress string) *ServerOptions {
return &ServerOptions{HandshakerServiceAddress: handshakerAddress}
}

func toProtoIdentity(identity Identity) (*s2apb.Identity, error) {
switch id := identity.(type) {
case *spiffeID:
return &s2apb.Identity{IdentityOneof: &s2apb.Identity_SpiffeId{SpiffeId: id.Name()}}, nil
case *hostname:
return &s2apb.Identity{IdentityOneof: &s2apb.Identity_Hostname{Hostname: id.Name()}}, nil
default:
return nil, errors.New("unrecognized identity type")
}
}
38 changes: 38 additions & 0 deletions security/s2a/s2a_options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package s2a

import (
"testing"

"github.com/google/go-cmp/cmp"
s2apb "google.golang.org/grpc/security/s2a/internal/proto"
)

func TestToProtoIdentity(t *testing.T) {
for _, tc := range []struct {
identity Identity
outIdentity *s2apb.Identity
}{
{
identity: NewSpiffeID("test_spiffe_id"),
outIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_SpiffeId{SpiffeId: "test_spiffe_id"},
},
},
{
identity: NewHostname("test_hostname"),
outIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_Hostname{Hostname: "test_hostname"},
},
},
} {
t.Run(tc.outIdentity.String(), func(t *testing.T) {
protoSpiffeID, err := toProtoIdentity(tc.identity)
if err != nil {
t.Errorf("toProtoIdentity(%v) failed: %v", tc.identity, err)
}
if got, want := protoSpiffeID, tc.outIdentity; !cmp.Equal(got, want) {
t.Errorf("toProtoIdentity(%v) = %v, want %v", tc.outIdentity, got, want)
}
})
}
}
Loading

0 comments on commit b1ce5d4

Please sign in to comment.