forked from matthewstevenson88/grpc-go
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add S2A Public APIs (matthewstevenson88#48)
- Loading branch information
Showing
4 changed files
with
738 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
package s2a | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"time" | ||
|
||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/security/s2a/internal/handshaker" | ||
"google.golang.org/grpc/security/s2a/internal/handshaker/service" | ||
s2apb "google.golang.org/grpc/security/s2a/internal/proto" | ||
) | ||
|
||
const ( | ||
s2aSecurityProtocol = "s2a" | ||
// defaultTimeout specifies the default server handshake timeout. | ||
defaultTimeout = 30.0 * time.Second | ||
) | ||
|
||
// s2aTransportCreds are the credentials required for establishing a secure | ||
// connection using the S2A handshaker service. It implements the | ||
// credentials.TransportCredentials interface. | ||
type s2aTransportCreds struct { | ||
info *credentials.ProtocolInfo | ||
minTLSVersion s2apb.TLSVersion | ||
maxTLSVersion s2apb.TLSVersion | ||
// tlsCiphersuites contains the ciphersuites used in the S2A connection. | ||
// Note that these are currently unconfigurable. | ||
tlsCiphersuites []s2apb.Ciphersuite | ||
// localIdentity should only be used by the client. | ||
localIdentity *s2apb.Identity | ||
// localIdentities should only be used by the server. | ||
localIdentities []*s2apb.Identity | ||
// targetIdentities should only be used by the client. | ||
targetIdentities []*s2apb.Identity | ||
isClient bool | ||
hsAddr string | ||
} | ||
|
||
// NewClientCreds returns a client-side transport credentials object that uses | ||
// the S2A handshaker service to establish a secure connection with a server. | ||
func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) { | ||
if opts == nil { | ||
return nil, errors.New("nil client options") | ||
} | ||
var targetIdentities []*s2apb.Identity | ||
for _, targetIdentity := range opts.TargetIdentities { | ||
protoTargetIdentity, err := toProtoIdentity(targetIdentity) | ||
if err != nil { | ||
return nil, err | ||
} | ||
targetIdentities = append(targetIdentities, protoTargetIdentity) | ||
} | ||
localIdentity, err := toProtoIdentity(opts.LocalIdentity) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &s2aTransportCreds{ | ||
info: &credentials.ProtocolInfo{ | ||
SecurityProtocol: s2aSecurityProtocol, | ||
}, | ||
minTLSVersion: s2apb.TLSVersion_TLS1_3, | ||
maxTLSVersion: s2apb.TLSVersion_TLS1_3, | ||
tlsCiphersuites: []s2apb.Ciphersuite{ | ||
s2apb.Ciphersuite_AES_128_GCM_SHA256, | ||
s2apb.Ciphersuite_AES_256_GCM_SHA384, | ||
s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256, | ||
}, | ||
localIdentity: localIdentity, | ||
targetIdentities: targetIdentities, | ||
isClient: true, | ||
hsAddr: opts.HandshakerServiceAddress, | ||
}, nil | ||
} | ||
|
||
// NewServerCreds returns a server-side transport credentials object that uses | ||
// the S2A handshaker service to establish a secure connection with a client. | ||
func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) { | ||
if opts == nil { | ||
return nil, errors.New("nil server options") | ||
} | ||
var localIdentities []*s2apb.Identity | ||
for _, localIdentity := range opts.LocalIdentities { | ||
protoLocalIdentity, err := toProtoIdentity(localIdentity) | ||
if err != nil { | ||
return nil, err | ||
} | ||
localIdentities = append(localIdentities, protoLocalIdentity) | ||
} | ||
return &s2aTransportCreds{ | ||
info: &credentials.ProtocolInfo{ | ||
SecurityProtocol: s2aSecurityProtocol, | ||
}, | ||
minTLSVersion: s2apb.TLSVersion_TLS1_3, | ||
maxTLSVersion: s2apb.TLSVersion_TLS1_3, | ||
tlsCiphersuites: []s2apb.Ciphersuite{ | ||
s2apb.Ciphersuite_AES_128_GCM_SHA256, | ||
s2apb.Ciphersuite_AES_256_GCM_SHA384, | ||
s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256, | ||
}, | ||
localIdentities: localIdentities, | ||
isClient: false, | ||
hsAddr: opts.HandshakerServiceAddress, | ||
}, nil | ||
} | ||
|
||
// ClientHandshake performs a client-side TLS handshake using the S2A handshaker | ||
// service. | ||
func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAddr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { | ||
if !c.isClient { | ||
return nil, nil, errors.New("client handshake called using server transport credentials") | ||
} | ||
|
||
// Connect to the S2A handshaker service. | ||
hsConn, err := service.Dial(c.hsAddr) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
var cancel context.CancelFunc | ||
ctx, cancel = context.WithCancel(ctx) | ||
defer cancel() | ||
|
||
opts := &handshaker.ClientHandshakerOptions{ | ||
MinTLSVersion: c.minTLSVersion, | ||
MaxTLSVersion: c.maxTLSVersion, | ||
TLSCiphersuites: c.tlsCiphersuites, | ||
TargetIdentities: c.targetIdentities, | ||
LocalIdentity: c.localIdentity, | ||
TargetName: serverAddr, | ||
} | ||
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.hsAddr, opts) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
defer func() { | ||
if err != nil { | ||
if closeErr := chs.Close(); closeErr != nil { | ||
err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr) | ||
} | ||
} | ||
}() | ||
|
||
secConn, authInfo, err := chs.ClientHandshake(context.Background()) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
return secConn, authInfo, nil | ||
} | ||
|
||
// ServerHandshake performs a server-side TLS handshake using the S2A handshaker | ||
// service. | ||
func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
if c.isClient { | ||
return nil, nil, errors.New("server handshake called using client transport credentials") | ||
} | ||
|
||
// Connect to the S2A handshaker service. | ||
hsConn, err := service.Dial(c.hsAddr) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) | ||
defer cancel() | ||
|
||
opts := &handshaker.ServerHandshakerOptions{ | ||
MinTLSVersion: c.minTLSVersion, | ||
MaxTLSVersion: c.maxTLSVersion, | ||
TLSCiphersuites: c.tlsCiphersuites, | ||
LocalIdentities: c.localIdentities, | ||
} | ||
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.hsAddr, opts) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
defer func() { | ||
if err != nil { | ||
if closeErr := shs.Close(); closeErr != nil { | ||
err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr) | ||
} | ||
} | ||
}() | ||
|
||
secConn, authInfo, err := shs.ServerHandshake(context.Background()) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
return secConn, authInfo, nil | ||
} | ||
|
||
func (c *s2aTransportCreds) Info() credentials.ProtocolInfo { | ||
return *c.info | ||
} | ||
|
||
func (c *s2aTransportCreds) Clone() credentials.TransportCredentials { | ||
info := *c.info | ||
var localIdentity *s2apb.Identity | ||
if c.localIdentity != nil { | ||
v := *c.localIdentity | ||
localIdentity = &v | ||
} | ||
var localIdentities []*s2apb.Identity | ||
if c.localIdentities != nil { | ||
localIdentities = make([]*s2apb.Identity, len(c.localIdentities)) | ||
for i, localIdentity := range c.localIdentities { | ||
v := *localIdentity | ||
localIdentities[i] = &v | ||
} | ||
} | ||
var targetIdentities []*s2apb.Identity | ||
if c.targetIdentities != nil { | ||
targetIdentities = make([]*s2apb.Identity, len(c.targetIdentities)) | ||
for i, targetIdentity := range c.targetIdentities { | ||
v := *targetIdentity | ||
targetIdentities[i] = &v | ||
} | ||
} | ||
return &s2aTransportCreds{ | ||
info: &info, | ||
minTLSVersion: c.minTLSVersion, | ||
maxTLSVersion: c.maxTLSVersion, | ||
tlsCiphersuites: c.tlsCiphersuites, | ||
localIdentity: localIdentity, | ||
localIdentities: localIdentities, | ||
targetIdentities: targetIdentities, | ||
isClient: c.isClient, | ||
hsAddr: c.hsAddr, | ||
} | ||
} | ||
|
||
func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error { | ||
c.info.ServerName = serverNameOverride | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package s2a | ||
|
||
import ( | ||
"errors" | ||
|
||
s2apb "google.golang.org/grpc/security/s2a/internal/proto" | ||
) | ||
|
||
// Identity is the interface for S2A identities. | ||
type Identity interface { | ||
// Name returns the name of the identity. | ||
Name() string | ||
} | ||
|
||
type spiffeID struct { | ||
spiffeID string | ||
} | ||
|
||
func (s *spiffeID) Name() string { return s.spiffeID } | ||
|
||
func NewSpiffeID(id string) Identity { | ||
return &spiffeID{spiffeID: id} | ||
} | ||
|
||
type hostname struct { | ||
hostname string | ||
} | ||
|
||
func (h *hostname) Name() string { return h.hostname } | ||
|
||
func NewHostname(name string) Identity { | ||
return &hostname{hostname: name} | ||
} | ||
|
||
// ClientOptions contains the client-side options used to establish a secure | ||
// channel using the S2A handshaker service. | ||
type ClientOptions struct { | ||
// TargetIdentities contains a list of allowed server identities. One of the | ||
// target identities should match the peer identity in the handshake | ||
// result; otherwise, the handshake fails. | ||
TargetIdentities []Identity | ||
// LocalIdentity is the local identity of the client application. If none is | ||
// provided, then the S2A will choose the default identity. | ||
LocalIdentity Identity | ||
// HandshakerServiceAddress is the address of the S2A handshaker service. | ||
HandshakerServiceAddress string | ||
} | ||
|
||
// DefaultClientOptions returns the default client options. | ||
func DefaultClientOptions(handshakerAddress string) *ClientOptions { | ||
return &ClientOptions{HandshakerServiceAddress: handshakerAddress} | ||
} | ||
|
||
// ServerOptions contains the server-side options used to establish a secure | ||
// channel using the S2A handshaker service. | ||
type ServerOptions struct { | ||
// LocalIdentities is the list of local identities that may be assumed by | ||
// the server. If no local identity is specified, then the S2A chooses a | ||
// default local identity. | ||
LocalIdentities []Identity | ||
// HandshakerServiceAddress is the address of the S2A handshaker service. | ||
HandshakerServiceAddress string | ||
} | ||
|
||
// DefaultServerOptions returns the default server options. | ||
func DefaultServerOptions(handshakerAddress string) *ServerOptions { | ||
return &ServerOptions{HandshakerServiceAddress: handshakerAddress} | ||
} | ||
|
||
func toProtoIdentity(identity Identity) (*s2apb.Identity, error) { | ||
switch id := identity.(type) { | ||
case *spiffeID: | ||
return &s2apb.Identity{IdentityOneof: &s2apb.Identity_SpiffeId{SpiffeId: id.Name()}}, nil | ||
case *hostname: | ||
return &s2apb.Identity{IdentityOneof: &s2apb.Identity_Hostname{Hostname: id.Name()}}, nil | ||
default: | ||
return nil, errors.New("unrecognized identity type") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package s2a | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
s2apb "google.golang.org/grpc/security/s2a/internal/proto" | ||
) | ||
|
||
func TestToProtoIdentity(t *testing.T) { | ||
for _, tc := range []struct { | ||
identity Identity | ||
outIdentity *s2apb.Identity | ||
}{ | ||
{ | ||
identity: NewSpiffeID("test_spiffe_id"), | ||
outIdentity: &s2apb.Identity{ | ||
IdentityOneof: &s2apb.Identity_SpiffeId{SpiffeId: "test_spiffe_id"}, | ||
}, | ||
}, | ||
{ | ||
identity: NewHostname("test_hostname"), | ||
outIdentity: &s2apb.Identity{ | ||
IdentityOneof: &s2apb.Identity_Hostname{Hostname: "test_hostname"}, | ||
}, | ||
}, | ||
} { | ||
t.Run(tc.outIdentity.String(), func(t *testing.T) { | ||
protoSpiffeID, err := toProtoIdentity(tc.identity) | ||
if err != nil { | ||
t.Errorf("toProtoIdentity(%v) failed: %v", tc.identity, err) | ||
} | ||
if got, want := protoSpiffeID, tc.outIdentity; !cmp.Equal(got, want) { | ||
t.Errorf("toProtoIdentity(%v) = %v, want %v", tc.outIdentity, got, want) | ||
} | ||
}) | ||
} | ||
} |
Oops, something went wrong.