Skip to content

Commit

Permalink
Handshaker component, part 3: Handshaker newConn impl and tests (#42)
Browse files Browse the repository at this point in the history
* created documents for handshaker

* added starter code

* removed HS test, added skeleton

* fixed merge conflicts

* fixed PR issues, changed comments, implemented new methods

* added unit tests for handshaker.go, added testutil

* fixed comments to handshaker, fixed pr issues

* fixed handshaker_test pr issues.

* removed testutil, added fakeconn

* fixed misspellings, pr issues.

* fixed comments, pr issues

* fixed pr issues, comments

* started impl methods

* updated comments to match part 1

* synced new code with merged

* fixed comments, pr issues for handshaker

* added tests for client/server handshaker

* added tests to handshaker_test

* added comments, fixed PeerNotRespondingError test

* changed frameLimit, added tests

* some pr issues resolved

* added error checks

* added fake s2a, fixed comments, names

* fixed comments, refactored variables, fixed pr issues

* fixed test variables

* added authinfo check, fixed pr issues

* added LocalCertFingerprint and PeerCertFingerprint checks

* added newConn impl and tests to handshaker

* removed grpc from import

* fixed comments, pr issues

* fixed merge conflict

* fixed comments, pr issues

* rearranged functions, changed tests, fixed pr issues

* fixed formatting in record

* fixed pr issues, comments

* added descriptions to exported fields

* added hsaddr to s2ahandshaker struct and new

* fixed pr issues
  • Loading branch information
davisgu authored Jun 25, 2020
1 parent 24d5b31 commit b639256
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 90 deletions.
54 changes: 36 additions & 18 deletions security/s2a/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/security/s2a/internal/authinfo"
s2apb "google.golang.org/grpc/security/s2a/internal/proto"
"google.golang.org/grpc/security/s2a/internal/record"
)

var (
Expand Down Expand Up @@ -87,43 +88,47 @@ type s2aHandshaker struct {
serverOpts *ServerHandshakerOptions
// isClient determines if the handshaker is client or server side
isClient bool
// HandshakerServiceAddress stores the address of the S2A handshaker service.
hsAddr string
}

// NewClientHandshaker creates an s2aHandshaker instance that performs a
// client-side TLS handshake using the S2A handshaker service.
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (*s2aHandshaker, error) {
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) (*s2aHandshaker, error) {
stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
if err != nil {
return nil, err
}
return newClientHandshaker(stream, c, opts), err
return newClientHandshaker(stream, c, hsAddr, opts), err
}

func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, opts *ClientHandshakerOptions) *s2aHandshaker {
func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) *s2aHandshaker {
return &s2aHandshaker{
stream: stream,
conn: c,
clientOpts: opts,
isClient: true,
hsAddr: hsAddr,
}
}

// NewServerHandshaker creates an s2aHandshaker instance that performs a
// server-side TLS handshake using the S2A handshaker service.
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (*s2aHandshaker, error) {
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) (*s2aHandshaker, error) {
stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
if err != nil {
return nil, err
}
return newServerHandshaker(stream, c, opts), err
return newServerHandshaker(stream, c, hsAddr, opts), err
}

func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, opts *ServerHandshakerOptions) *s2aHandshaker {
func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) *s2aHandshaker {
return &s2aHandshaker{
stream: stream,
conn: c,
serverOpts: opts,
isClient: false,
hsAddr: hsAddr,
}
}

Expand Down Expand Up @@ -206,8 +211,8 @@ func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.Se
return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
}
}
// Calculate the extra unread bytes from the Session. Attempting to consume more
// than the bytes sent will throw an error.
// Calculate the extra unread bytes from the Session. Attempting to consume
// more than the bytes sent will throw an error.
var extra []byte
if req.GetServerStart() != nil {
if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
Expand All @@ -219,9 +224,22 @@ func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.Se
if err != nil {
return nil, nil, err
}
// TODO(gud): use the NewConn API to construct the record protocol when PR#29
// is merged.
return h.conn, result, nil
// Create a new TLS record protocol using the Session Result.
newConn, err := record.NewConn(&record.ConnOptions{
NetConn: h.conn,
Ciphersuite: result.GetState().GetTlsCiphersuite(),
TLSVersion: result.GetState().GetTlsVersion(),
InTrafficSecret: result.GetState().GetInKey(),
OutTrafficSecret: result.GetState().GetOutKey(),
UnusedBuf: extra,
InSequence: result.GetState().GetInSequence(),
OutSequence: result.GetState().GetOutSequence(),
HsAddr: h.hsAddr,
})
if err != nil {
return nil, nil, err
}
return newConn, result, nil
}

// accessHandshakerService sends the session request to the S2A Handshaker
Expand Down Expand Up @@ -255,16 +273,16 @@ func (h *s2aHandshaker) processUntilDone(resp *s2apb.SessionResp, unusedBytes []
if err != nil && err != io.EOF {
return nil, nil, err
}
// If there is nothing to send to the handshaker service and
// nothing is received from the peer, then we are stuck.
// This covers the case when the peer is not responding. Note
// that handshaker service connection issues are caught in
// accessHandshakerService before we even get here.
// If there is nothing to send to the handshaker service and nothing is
// received from the peer, then we are stuck. This covers the case when
// the peer is not responding. Note that handshaker service connection
// issues are caught in accessHandshakerService before we even get
// here.
if len(resp.OutFrames) == 0 && n == 0 {
return nil, nil, peerNotRespondingError
}
// Append extra bytes from the previous interaction with the
// handshaker service with the current buffer read from conn.
// Append extra bytes from the previous interaction with the handshaker
// service with the current buffer read from conn.
p := append(unusedBytes, buf[:n]...)
// From here on, p and unusedBytes point to the same slice.
resp, err = h.accessHandshakerService(&s2apb.SessionReq{
Expand Down
27 changes: 18 additions & 9 deletions security/s2a/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ import (

"golang.org/x/sync/errgroup"
"google.golang.org/grpc"

s2apb "google.golang.org/grpc/security/s2a/internal/proto"
)

var (
// testHSAddr is the handshaker service address used for testing
testHSAddr = "handshaker_address"

// testClientHandshakerOptions are the client-side handshaker options used for
// testing.
testClientHandshakerOptions = &ClientHandshakerOptions{
Expand Down Expand Up @@ -142,7 +144,7 @@ type fakeStream struct {
grpc.ClientStream
t *testing.T
// expectedResp is the expected SessionResp message from the handshaker
// service.
// service.
expectedResp *s2apb.SessionResp
// isFirstAccess indicates whether the first call to the handshaker service
// has been made or not.
Expand Down Expand Up @@ -219,7 +221,7 @@ func (fc *fakeConn) Read(b []byte) (n int, err error) { return fc.in.Read(b) }
func (fc *fakeConn) Write(b []byte) (n int, err error) { return fc.out.Write(b) }
func (fc *fakeConn) Close() error { return nil }

// fakeInvalidConn is a fake implementation of a invalid net.Conn interface//
// fakeInvalidConn is a fake implementation of a invalid net.Conn interface
// that is used for testing.
type fakeInvalidConn struct {
net.Conn
Expand All @@ -239,7 +241,7 @@ func TestNewClientHandshaker(t *testing.T) {
in: in,
out: new(bytes.Buffer),
}
chs := newClientHandshaker(stream, c, testClientHandshakerOptions)
chs := newClientHandshaker(stream, c, testHSAddr, testClientHandshakerOptions)
if chs.clientOpts != testClientHandshakerOptions || chs.conn != c {
t.Errorf("handshaker parameters incorrect")
}
Expand All @@ -255,7 +257,7 @@ func TestNewServerHandshaker(t *testing.T) {
in: in,
out: new(bytes.Buffer),
}
shs := newServerHandshaker(stream, c, testServerHandshakerOptions)
shs := newServerHandshaker(stream, c, testHSAddr, testServerHandshakerOptions)
if shs.serverOpts != testServerHandshakerOptions || shs.conn != c {
t.Errorf("handshaker parameters incorrect")
}
Expand All @@ -280,12 +282,12 @@ func TestClientHandshake(t *testing.T) {
conn: c,
clientOpts: testClientHandshakerOptions,
isClient: true,
hsAddr: testHSAddr,
}
result := testClientSessionResult
errg.Go(func() error {
// Returned conn is ignored until record protocol is implemented.
// TODO(gud): Add tests for returned conn.
_, auth, err := chs.ClientHandshake(context.Background())
newConn, auth, err := chs.ClientHandshake(context.Background())
if err != nil {
return err
}
Expand All @@ -298,6 +300,9 @@ func TestClientHandshake(t *testing.T) {
!bytes.Equal(auth.PeerCertFingerprint(), result.GetPeerCertFingerprint()) {
return errors.New("Authinfo s2a context incorrect")
}
if newConn == nil {
return errors.New("Expected non-nil net.Conn")
}
chs.Close()
return nil
})
Expand Down Expand Up @@ -326,13 +331,13 @@ func TestServerHandshake(t *testing.T) {
conn: c,
serverOpts: testServerHandshakerOptions,
isClient: false,
hsAddr: testHSAddr,
}
result := testServerSessionResult
errg.Go(func() error {
// The conn returned by ServerHandshake is ignored until record protocol
// is implemented.
// TODO(gud): Add tests for returned conn.
_, auth, err := shs.ServerHandshake(context.Background())
newConn, auth, err := shs.ServerHandshake(context.Background())
if err != nil {
return err
}
Expand All @@ -345,6 +350,9 @@ func TestServerHandshake(t *testing.T) {
!bytes.Equal(auth.PeerCertFingerprint(), result.GetPeerCertFingerprint()) {
return errors.New("Authinfo s2a context incorrect")
}
if newConn == nil {
return errors.New("Expected non-nil net.Conn")
}
shs.Close()
return nil
})
Expand Down Expand Up @@ -381,6 +389,7 @@ func TestPeerNotResponding(t *testing.T) {
conn: c,
clientOpts: testClientHandshakerOptions,
isClient: true,
hsAddr: testHSAddr,
}
_, context, err := chs.ClientHandshake(context.Background())
chs.Close()
Expand Down
46 changes: 31 additions & 15 deletions security/s2a/internal/record/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package record
import (
"errors"
"fmt"
"net"

s2apb "google.golang.org/grpc/security/s2a/internal/proto"
"google.golang.org/grpc/security/s2a/internal/record/internal/halfconn"
"net"
)

const (
Expand Down Expand Up @@ -55,27 +56,42 @@ type conn struct {

// ConnOptions holds the options used for creating a new conn object.
type ConnOptions struct {
netConn net.Conn
ciphersuite s2apb.Ciphersuite
tlsVersion s2apb.TLSVersion
inTrafficSecret, outTrafficSecret, unusedBuf []byte
inSequence, outSequence uint64
hsAddr string
// NetConn is the current TLS record.
NetConn net.Conn
// Ciphersuite is the TLS ciphersuite negotiated by the S2A's handshaker
// module.
Ciphersuite s2apb.Ciphersuite
// TLSVersion is the TLS version number that the S2A's handshaker module
// used to set up the session.
TLSVersion s2apb.TLSVersion
// InTrafficSecret is the key for the in bound direction.
InTrafficSecret []byte
// OutTrafficSecret is the key for the out bound direction.
OutTrafficSecret []byte
// UnusedBuf is the data read from the network that has not yet been
// decrypted.
UnusedBuf []byte
// InSequence is the sequence number of the next, incoming, TLS record.
InSequence uint64
// OutSequence is the sequence number of the next, outgoing, TLS record.
OutSequence uint64
// hsAddr stores the address of the S2A handshaker service.
HsAddr string
}

func NewConn(o *ConnOptions) (net.Conn, error) {
if o == nil {
return nil, errors.New("conn options must not be nil")
}
if o.tlsVersion != s2apb.TLSVersion_TLS1_3 {
if o.TLSVersion != s2apb.TLSVersion_TLS1_3 {
return nil, errors.New("TLS version must be TLS 1.3")
}

inConn, err := halfconn.New(o.ciphersuite, o.inTrafficSecret, o.inSequence)
inConn, err := halfconn.New(o.Ciphersuite, o.InTrafficSecret, o.InSequence)
if err != nil {
return nil, fmt.Errorf("failed to create inbound half connection: %v", err)
}
outConn, err := halfconn.New(o.ciphersuite, o.outTrafficSecret, o.outSequence)
outConn, err := halfconn.New(o.Ciphersuite, o.OutTrafficSecret, o.OutSequence)
if err != nil {
return nil, fmt.Errorf("failed to create outbound half connection: %v", err)
}
Expand All @@ -84,20 +100,20 @@ func NewConn(o *ConnOptions) (net.Conn, error) {
overheadSize := tlsRecordHeaderSize + tlsRecordTypeSize + inConn.TagSize()
var unusedBuf []byte
// TODO(gud): Potentially optimize unusedBuf with pre-allocation.
if o.unusedBuf != nil {
unusedBuf = make([]byte, len(o.unusedBuf))
copy(unusedBuf, o.unusedBuf)
if o.UnusedBuf != nil {
unusedBuf = make([]byte, len(o.UnusedBuf))
copy(unusedBuf, o.UnusedBuf)
}

s2aConn := &conn{
Conn: o.netConn,
Conn: o.NetConn,
inConn: inConn,
outConn: outConn,
unusedBuf: unusedBuf,
outRecordsBuf: make([]byte, outBufSize),
nextRecord: unusedBuf,
overheadSize: overheadSize,
hsAddr: o.hsAddr,
hsAddr: o.HsAddr,
}
return s2aConn, nil
}
Expand Down
Loading

0 comments on commit b639256

Please sign in to comment.