From 725135dd3af960bb3c1e50838de58d8f9b0aa7d6 Mon Sep 17 00:00:00 2001 From: Ryan Kim Date: Wed, 12 Aug 2020 19:09:21 +0000 Subject: [PATCH] Add support for sending session tickets to S2A --- .../s2a/internal/handshaker/handshaker.go | 2 + security/s2a/internal/record/record.go | 103 ++++++++- security/s2a/internal/record/record_test.go | 216 +++++++++++++++++- 3 files changed, 313 insertions(+), 8 deletions(-) diff --git a/security/s2a/internal/handshaker/handshaker.go b/security/s2a/internal/handshaker/handshaker.go index 380f87374299..9ead4667ee01 100644 --- a/security/s2a/internal/handshaker/handshaker.go +++ b/security/s2a/internal/handshaker/handshaker.go @@ -239,6 +239,8 @@ func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.Se InSequence: result.GetState().GetInSequence(), OutSequence: result.GetState().GetOutSequence(), HSAddr: h.hsAddr, + ConnectionID: result.GetState().GetConnectionId(), + LocalIdentity: result.GetLocalIdentity(), }) if err != nil { return nil, nil, err diff --git a/security/s2a/internal/record/record.go b/security/s2a/internal/record/record.go index 64a6e667455a..ad42a2d7f805 100644 --- a/security/s2a/internal/record/record.go +++ b/security/s2a/internal/record/record.go @@ -21,13 +21,18 @@ package record import ( + "context" "encoding/binary" "errors" "fmt" "math" "net" "sync" + "time" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/security/s2a/internal/handshaker/service" s2apb "google.golang.org/grpc/security/s2a/internal/proto" "google.golang.org/grpc/security/s2a/internal/record/internal/halfconn" ) @@ -180,8 +185,6 @@ type conn struct { // is computed as overheadSize = header size + record type byte + tag size. // Note that there is no padding by zeros in the overhead calculation. overheadSize int - // hsAddr stores the address of the S2A handshaker service. - hsAddr string // readMutex guards against concurrent calls to Read. This is required since // Close may be called during a Read. readMutex sync.Mutex @@ -196,6 +199,8 @@ type conn struct { // sessionTickets holds the completed session tickets until they are sent to // the handshaker service for processing. sessionTickets [][]byte + // ticketSender sends session tickets to the S2A handshaker service. + ticketSender s2aTicketSender } // ConnParameters holds the parameters used for creating a new conn object. @@ -228,6 +233,12 @@ type ConnParameters struct { // HSAddr stores the address of the S2A handshaker service. This parameter // is optional. If not provided, then TLS resumption is disabled. HSAddr string + // ConnectionId is the connection identifier that was created and sent by + // S2A at the end of a handshake. + ConnectionID uint64 + // LocalIdentity is the local identity that was used by S2A during session + // setup and included in the session result. + LocalIdentity *s2apb.Identity } func NewConn(o *ConnParameters) (net.Conn, error) { @@ -272,7 +283,6 @@ func NewConn(o *ConnParameters) (net.Conn, error) { outRecordsBuf: make([]byte, tlsRecordMaxSize), nextRecord: unusedBuf, overheadSize: overheadSize, - hsAddr: o.HSAddr, ticketState: ticketsNotYetReceived, // Pre-allocate the buffer for one session ticket message and the max // plaintext size. This is the largest size that handshakeBuf will need @@ -283,6 +293,11 @@ func NewConn(o *ConnParameters) (net.Conn, error) { // completed. Therefore, the buffer size below should be large enough to // buffer any handshake messages. handshakeBuf: make([]byte, 0, tlsHandshakePrefixSize+tlsMaxSessionTicketSize+tlsRecordMaxPlaintextSize-1), + ticketSender: &ticketSender{ + hsAddr: o.HSAddr, + connectionID: o.ConnectionID, + localIdentity: o.LocalIdentity, + }, } return s2aConn, nil } @@ -353,7 +368,7 @@ func (p *conn) Read(b []byte) (n int, err error) { } if p.ticketState == receivingTickets { p.ticketState = notReceivingTickets - // TODO: send tickets to handshaker + p.ticketSender.sendTicketsToS2A(p.sessionTickets) } case alert: return 0, p.handleAlertMessage() @@ -656,7 +671,7 @@ func (p *conn) handleHandshakeMessage() error { p.sessionTickets = append(p.sessionTickets, msg) if len(p.sessionTickets) == maxAllowedTickets { p.ticketState = notReceivingTickets - // TODO: send tickets to handshaker + p.ticketSender.sendTicketsToS2A(p.sessionTickets) } default: return errors.New("unknown handshake message type") @@ -703,6 +718,84 @@ func (p *conn) handleKeyUpdateMsg(msg []byte) error { return nil } +// s2aTicketSender sends session tickets to the S2A handshaker service. +type s2aTicketSender interface { + // sendTicketsToS2A sends the given session tickets to the S2A handshaker + // service. + sendTicketsToS2A(sessionTickets [][]byte) +} + +// ticketStream is the stream used to send and receive session information. +type ticketStream interface { + Send(*s2apb.SessionReq) error + Recv() (*s2apb.SessionResp, error) +} + +type ticketSender struct { + // hsAddr stores the address of the S2A handshaker service. + hsAddr string + // connectionID is the connection identifier that was created and sent by + // S2A at the end of a handshake. + connectionID uint64 + // localIdentity is the local identity that was used by S2A during session + // setup and included in the session result. + localIdentity *s2apb.Identity +} + +// sendTicketsToS2A sends the given sessionTickets to the S2A handshaker +// service. This is done asynchronously and writes to the error logs if an error +// occurs. +func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte) { + go func() { + if err := func() error { + hsConn, err := service.Dial(t.hsAddr) + if err != nil { + return err + } + client := s2apb.NewS2AServiceClient(hsConn) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + session, err := client.SetUpSession(ctx) + if err != nil { + return err + } + defer func() { + if err := session.CloseSend(); err != nil { + grpclog.Error(err) + } + }() + return t.writeTicketsToStream(session, sessionTickets) + }(); err != nil { + grpclog.Error(err) + } + }() +} + +// writeTicketsToStream writes the given session tickets to the given stream. +func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error { + if err := stream.Send( + &s2apb.SessionReq{ + ReqOneof: &s2apb.SessionReq_ResumptionTicket{ + ResumptionTicket: &s2apb.ResumptionTicketReq{ + InBytes: sessionTickets, + ConnectionId: t.connectionID, + LocalIdentity: t.localIdentity, + }, + }, + }, + ); err != nil { + return err + } + sessionResp, err := stream.Recv() + if err != nil { + return err + } + if sessionResp.GetStatus().GetCode() != uint32(codes.OK) { + return errors.New("s2a session ticket response was not OK") + } + return nil +} + // bidEndianInt24 converts the given byte buffer of at least size 3 and // outputs the resulting 24 bit integer as a uint32. This is needed because // TLS 1.3 requires 3 byte integers, and the binary.BigEndian package does diff --git a/security/s2a/internal/record/record_test.go b/security/s2a/internal/record/record_test.go index 62f091baa4eb..f425fc0638cd 100644 --- a/security/s2a/internal/record/record_test.go +++ b/security/s2a/internal/record/record_test.go @@ -26,6 +26,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/codes" s2apb "google.golang.org/grpc/security/s2a/internal/proto" "google.golang.org/grpc/security/s2a/internal/record/internal/aeadcrypter/testutil" ) @@ -73,6 +74,50 @@ func (c *fakeConn) Close() error { return nil } +type fakeTicketSender struct { + sessionTickets [][]byte +} + +func (f *fakeTicketSender) sendTicketsToS2A(sessionTickets [][]byte) { + f.sessionTickets = sessionTickets +} + +type fakeStream struct { + // returnInvalid is a flag indicating whether the return status of Recv is + // OK or not. + returnInvalid bool + // returnRecvErr is a flag indicating whether an error should be returned by + // Recv. + returnRecvErr bool +} + +func (fs *fakeStream) Send(req *s2apb.SessionReq) error { + if len(req.GetResumptionTicket().InBytes) == 0 { + return errors.New("fakeStream Send received an empty InBytes") + } + if req.GetResumptionTicket().ConnectionId == 0 { + return errors.New("fakeStream Send received a 0 ConnectionId") + } + if req.GetResumptionTicket().LocalIdentity == nil { + return errors.New("fakeStream Send received an empty LocalIdentity") + } + return nil +} + +func (fs *fakeStream) Recv() (*s2apb.SessionResp, error) { + if fs.returnRecvErr { + return nil, errors.New("fakeStream Recv error") + } + if fs.returnInvalid { + return &s2apb.SessionResp{ + Status: &s2apb.SessionStatus{Code: uint32(codes.InvalidArgument)}, + }, nil + } + return &s2apb.SessionResp{ + Status: &s2apb.SessionStatus{Code: uint32(codes.OK)}, + }, nil +} + func TestNewS2ARecordConn(t *testing.T) { for _, tc := range []struct { desc string @@ -80,6 +125,8 @@ func TestNewS2ARecordConn(t *testing.T) { outUnusedBytesBuf []byte outOverheadSize int outHandshakerServiceAddr string + outConnectionID uint64 + outLocalIdentity *s2apb.Identity outErr bool }{ { @@ -131,11 +178,23 @@ func TestNewS2ARecordConn(t *testing.T) { InTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), OutTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), HSAddr: "test handshaker address", + ConnectionID: 1, + LocalIdentity: &s2apb.Identity{ + IdentityOneof: &s2apb.Identity_SpiffeId{ + SpiffeId: "test_spiffe_id", + }, + }, }, // outOverheadSize = header size (5) + record type byte (1) + // tag size (16). outOverheadSize: 22, outHandshakerServiceAddr: "test handshaker address", + outConnectionID: 1, + outLocalIdentity: &s2apb.Identity{ + IdentityOneof: &s2apb.Identity_SpiffeId{ + SpiffeId: "test_spiffe_id", + }, + }, }, { desc: "basic with AES-256-GCM-SHA384", @@ -146,11 +205,23 @@ func TestNewS2ARecordConn(t *testing.T) { InTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), OutTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), HSAddr: "test handshaker address", + ConnectionID: 1, + LocalIdentity: &s2apb.Identity{ + IdentityOneof: &s2apb.Identity_SpiffeId{ + SpiffeId: "test_spiffe_id", + }, + }, }, // outOverheadSize = header size (5) + record type byte (1) + // tag size (16). outOverheadSize: 22, outHandshakerServiceAddr: "test handshaker address", + outConnectionID: 1, + outLocalIdentity: &s2apb.Identity{ + IdentityOneof: &s2apb.Identity_SpiffeId{ + SpiffeId: "test_spiffe_id", + }, + }, }, { desc: "basic with CHACHA20-POLY1305-SHA256", @@ -161,11 +232,23 @@ func TestNewS2ARecordConn(t *testing.T) { InTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), OutTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), HSAddr: "test handshaker address", + ConnectionID: 1, + LocalIdentity: &s2apb.Identity{ + IdentityOneof: &s2apb.Identity_SpiffeId{ + SpiffeId: "test_spiffe_id", + }, + }, }, // outOverheadSize = header size (5) + record type byte (1) + // tag size (16). outOverheadSize: 22, outHandshakerServiceAddr: "test handshaker address", + outConnectionID: 1, + outLocalIdentity: &s2apb.Identity{ + IdentityOneof: &s2apb.Identity_SpiffeId{ + SpiffeId: "test_spiffe_id", + }, + }, }, { desc: "basic with unusedBytes", @@ -200,8 +283,15 @@ func TestNewS2ARecordConn(t *testing.T) { if got, want := conn.overheadSize, tc.outOverheadSize; got != want { t.Errorf("conn.overheadSize = %v, want %v", got, want) } - if got, want := conn.hsAddr, tc.outHandshakerServiceAddr; got != want { - t.Errorf("conn.HSAddr = %v, want %v", got, want) + ticketSender := conn.ticketSender.(*ticketSender) + if got, want := ticketSender.hsAddr, tc.outHandshakerServiceAddr; got != want { + t.Errorf("ticketSender.hsAddr = %v, want %v", got, want) + } + if got, want := ticketSender.connectionID, tc.outConnectionID; got != want { + t.Errorf("ticketSender.connectionID = %v, want %v", got, want) + } + if got, want := ticketSender.localIdentity, tc.outLocalIdentity; !cmp.Equal(got, want) { + t.Errorf("ticketSender.localIdentity = %v, want %v", got, want) } }) } @@ -1161,6 +1251,7 @@ func TestConnNewSessionTicket(t *testing.T) { outPlaintexts [][]byte finalTicketState sessionTicketState outSessionTickets [][]byte + ticketsSent bool }{ // All the session tickets below are []byte{0}. This is not a valid // ticket, but is sufficient for testing since the the client does not @@ -1226,6 +1317,7 @@ func TestConnNewSessionTicket(t *testing.T) { outSessionTickets: [][]byte{ {0}, }, + ticketsSent: true, }, { desc: "AES-256-GCM-SHA384 new session ticket followed by application data", @@ -1243,6 +1335,7 @@ func TestConnNewSessionTicket(t *testing.T) { outSessionTickets: [][]byte{ {0}, }, + ticketsSent: true, }, { desc: "CHACHA20-POLY1305-SHA256 new session ticket followed by application data", @@ -1260,6 +1353,7 @@ func TestConnNewSessionTicket(t *testing.T) { outSessionTickets: [][]byte{ {0}, }, + ticketsSent: true, }, { desc: "AES-128-GCM-SHA256 ticket, application data, then another ticket", @@ -1279,6 +1373,7 @@ func TestConnNewSessionTicket(t *testing.T) { outSessionTickets: [][]byte{ {0}, }, + ticketsSent: true, }, { desc: "AES-256-GCM-SHA384 ticket, application data, then another ticket", @@ -1298,6 +1393,7 @@ func TestConnNewSessionTicket(t *testing.T) { outSessionTickets: [][]byte{ {0}, }, + ticketsSent: true, }, { desc: "CHACHA20-POLY1305-SHA256 ticket, application data, then another ticket", @@ -1317,6 +1413,7 @@ func TestConnNewSessionTicket(t *testing.T) { outSessionTickets: [][]byte{ {0}, }, + ticketsSent: true, }, // TODO(rnkim): Add test cases for handshake key update messages. // Specifically, fragmented handshake messages and multiple handshake @@ -1334,6 +1431,10 @@ func TestConnNewSessionTicket(t *testing.T) { if err != nil { t.Fatalf("NewConn() failed: %v", err) } + newConn := c.(*conn) + // Replace the ticket sender with a fake. + fakeTicketSender := &fakeTicketSender{} + newConn.ticketSender = fakeTicketSender for _, outPlaintext := range tc.outPlaintexts { plaintext := make([]byte, tlsRecordMaxPlaintextSize) n, err := c.Read(plaintext) @@ -1348,7 +1449,6 @@ func TestConnNewSessionTicket(t *testing.T) { t.Errorf("len(c.(*conn).pendingApplicationData) = %v, want %v", got, want) } } - newConn := c.(*conn) if got, want := newConn.ticketState, tc.finalTicketState; got != want { t.Errorf("newConn.ticketState = %v, want %v", got, want) } @@ -1358,6 +1458,11 @@ func TestConnNewSessionTicket(t *testing.T) { if got, want := newConn.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) { t.Errorf("newConn.sessionTickets = %v, want %v", got, want) } + if tc.ticketsSent { + if got, want := fakeTicketSender.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) { + t.Errorf("fakeTicketSender.sessionTickets = %v, want %v", got, want) + } + } }) } } @@ -1370,6 +1475,7 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) { sessionTickets [][]byte finalTicketState sessionTicketState outSessionTickets [][]byte + ticketsSent bool }{ { desc: "AES-128-GCM-SHA256 consecutive tickets", @@ -1566,6 +1672,72 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) { []byte("abc"), }, }, + { + desc: "AES-128-GCM-SHA256 past max limit", + ciphersuite: s2apb.Ciphersuite_AES_128_GCM_SHA256, + trafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), + sessionTickets: [][]byte{ + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + }, + finalTicketState: notReceivingTickets, + outSessionTickets: [][]byte{ + []byte("abc"), + []byte("abc"), + []byte("abc"), + []byte("abc"), + []byte("abc"), + }, + ticketsSent: true, + }, + { + desc: "AES-256-GCM-SHA384 past max limit", + ciphersuite: s2apb.Ciphersuite_AES_256_GCM_SHA384, + trafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), + sessionTickets: [][]byte{ + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + }, + finalTicketState: notReceivingTickets, + outSessionTickets: [][]byte{ + []byte("abc"), + []byte("abc"), + []byte("abc"), + []byte("abc"), + []byte("abc"), + }, + ticketsSent: true, + }, + { + desc: "CHACHA20-POLY1305-SHA256 past max limit", + ciphersuite: s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256, + trafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"), + sessionTickets: [][]byte{ + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + buildSessionTicket([]byte("abc")), + }, + finalTicketState: notReceivingTickets, + outSessionTickets: [][]byte{ + []byte("abc"), + []byte("abc"), + []byte("abc"), + []byte("abc"), + []byte("abc"), + }, + ticketsSent: true, + }, } { t.Run(tc.desc, func(t *testing.T) { fc := &fakeConn{} @@ -1580,6 +1752,9 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) { t.Fatalf("NewConn() failed: %v", err) } c := netConn.(*conn) + // Replace the ticket sender with a fake. + fakeTicketSender := &fakeTicketSender{} + c.ticketSender = fakeTicketSender for _, sessionTicket := range tc.sessionTickets { _, err = c.writeTLSRecord(sessionTicket, byte(handshake)) if err != nil { @@ -1613,10 +1788,45 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) { if got, want := c.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) { t.Errorf("newConn.sessionTickets = %v, want %v", got, want) } + if tc.ticketsSent { + if got, want := fakeTicketSender.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) { + t.Errorf("fakeTicketSender.sessionTickets = %v, want %v", got, want) + } + } }) } } +func TestWriteTicketsToStream(t *testing.T) { + for _, tc := range []struct { + returnInvalid bool + returnRecvError bool + }{ + { + // Both flags are set to false. + }, + { + returnInvalid: true, + }, + { + returnRecvError: true, + }, + } { + sender := ticketSender{ + connectionID: 1, + localIdentity: &s2apb.Identity{ + IdentityOneof: &s2apb.Identity_SpiffeId{ + SpiffeId: "test_spiffe_id", + }, + }, + } + fs := &fakeStream{returnInvalid: tc.returnInvalid, returnRecvErr: tc.returnRecvError} + if got, want := sender.writeTicketsToStream(fs, make([][]byte, 1)) == nil, !tc.returnRecvError && !tc.returnInvalid; got != want { + t.Errorf("sender.writeTicketsToStream(%v, _) = (err=nil) = %v, want %v", fs, got, want) + } + } +} + func TestWrite(t *testing.T) { for _, tc := range []struct { desc string