Skip to content

Commit

Permalink
Add support for sending session tickets to S2A
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryanfsdf committed Aug 13, 2020
1 parent fd854f8 commit a0f18c8
Show file tree
Hide file tree
Showing 5 changed files with 374 additions and 8 deletions.
2 changes: 2 additions & 0 deletions security/s2a/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.Se
InSequence: result.GetState().GetInSequence(),
OutSequence: result.GetState().GetOutSequence(),
HSAddr: h.hsAddr,
ConnectionID: result.GetState().GetConnectionId(),
LocalIdentity: result.GetLocalIdentity(),
})
if err != nil {
return nil, nil, err
Expand Down
20 changes: 15 additions & 5 deletions security/s2a/internal/record/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,6 @@ type conn struct {
// is computed as overheadSize = header size + record type byte + tag size.
// Note that there is no padding by zeros in the overhead calculation.
overheadSize int
// hsAddr stores the address of the S2A handshaker service.
hsAddr string
// readMutex guards against concurrent calls to Read. This is required since
// Close may be called during a Read.
readMutex sync.Mutex
Expand All @@ -196,6 +194,8 @@ type conn struct {
// sessionTickets holds the completed session tickets until they are sent to
// the handshaker service for processing.
sessionTickets [][]byte
// ticketSender sends session tickets to the S2A handshaker service.
ticketSender s2aTicketSender
}

// ConnParameters holds the parameters used for creating a new conn object.
Expand Down Expand Up @@ -228,6 +228,12 @@ type ConnParameters struct {
// HSAddr stores the address of the S2A handshaker service. This parameter
// is optional. If not provided, then TLS resumption is disabled.
HSAddr string
// ConnectionId is the connection identifier that was created and sent by
// S2A at the end of a handshake.
ConnectionID uint64
// LocalIdentity is the local identity that was used by S2A during session
// setup and included in the session result.
LocalIdentity *s2apb.Identity
}

func NewConn(o *ConnParameters) (net.Conn, error) {
Expand Down Expand Up @@ -272,7 +278,6 @@ func NewConn(o *ConnParameters) (net.Conn, error) {
outRecordsBuf: make([]byte, tlsRecordMaxSize),
nextRecord: unusedBuf,
overheadSize: overheadSize,
hsAddr: o.HSAddr,
ticketState: ticketsNotYetReceived,
// Pre-allocate the buffer for one session ticket message and the max
// plaintext size. This is the largest size that handshakeBuf will need
Expand All @@ -283,6 +288,11 @@ func NewConn(o *ConnParameters) (net.Conn, error) {
// completed. Therefore, the buffer size below should be large enough to
// buffer any handshake messages.
handshakeBuf: make([]byte, 0, tlsHandshakePrefixSize+tlsMaxSessionTicketSize+tlsRecordMaxPlaintextSize-1),
ticketSender: &ticketSender{
hsAddr: o.HSAddr,
connectionID: o.ConnectionID,
localIdentity: o.LocalIdentity,
},
}
return s2aConn, nil
}
Expand Down Expand Up @@ -353,7 +363,7 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
if p.ticketState == receivingTickets {
p.ticketState = notReceivingTickets
// TODO: send tickets to handshaker
p.ticketSender.sendTicketsToS2A(p.sessionTickets)
}
case alert:
return 0, p.handleAlertMessage()
Expand Down Expand Up @@ -656,7 +666,7 @@ func (p *conn) handleHandshakeMessage() error {
p.sessionTickets = append(p.sessionTickets, msg)
if len(p.sessionTickets) == maxAllowedTickets {
p.ticketState = notReceivingTickets
// TODO: send tickets to handshaker
p.ticketSender.sendTicketsToS2A(p.sessionTickets)
}
default:
return errors.New("unknown handshake message type")
Expand Down
149 changes: 146 additions & 3 deletions security/s2a/internal/record/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,23 @@ func (c *fakeConn) Close() error {
return nil
}

type fakeTicketSender struct {
sessionTickets [][]byte
}

func (f *fakeTicketSender) sendTicketsToS2A(sessionTickets [][]byte) {
f.sessionTickets = sessionTickets
}

func TestNewS2ARecordConn(t *testing.T) {
for _, tc := range []struct {
desc string
options *ConnParameters
outUnusedBytesBuf []byte
outOverheadSize int
outHandshakerServiceAddr string
outConnectionID uint64
outLocalIdentity *s2apb.Identity
outErr bool
}{
{
Expand Down Expand Up @@ -131,11 +141,23 @@ func TestNewS2ARecordConn(t *testing.T) {
InTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
OutTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
HSAddr: "test handshaker address",
ConnectionID: 1,
LocalIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_SpiffeId{
SpiffeId: "test_spiffe_id",
},
},
},
// outOverheadSize = header size (5) + record type byte (1) +
// tag size (16).
outOverheadSize: 22,
outHandshakerServiceAddr: "test handshaker address",
outConnectionID: 1,
outLocalIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_SpiffeId{
SpiffeId: "test_spiffe_id",
},
},
},
{
desc: "basic with AES-256-GCM-SHA384",
Expand All @@ -146,11 +168,23 @@ func TestNewS2ARecordConn(t *testing.T) {
InTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
OutTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
HSAddr: "test handshaker address",
ConnectionID: 1,
LocalIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_SpiffeId{
SpiffeId: "test_spiffe_id",
},
},
},
// outOverheadSize = header size (5) + record type byte (1) +
// tag size (16).
outOverheadSize: 22,
outHandshakerServiceAddr: "test handshaker address",
outConnectionID: 1,
outLocalIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_SpiffeId{
SpiffeId: "test_spiffe_id",
},
},
},
{
desc: "basic with CHACHA20-POLY1305-SHA256",
Expand All @@ -161,11 +195,23 @@ func TestNewS2ARecordConn(t *testing.T) {
InTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
OutTrafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
HSAddr: "test handshaker address",
ConnectionID: 1,
LocalIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_SpiffeId{
SpiffeId: "test_spiffe_id",
},
},
},
// outOverheadSize = header size (5) + record type byte (1) +
// tag size (16).
outOverheadSize: 22,
outHandshakerServiceAddr: "test handshaker address",
outConnectionID: 1,
outLocalIdentity: &s2apb.Identity{
IdentityOneof: &s2apb.Identity_SpiffeId{
SpiffeId: "test_spiffe_id",
},
},
},
{
desc: "basic with unusedBytes",
Expand Down Expand Up @@ -200,8 +246,15 @@ func TestNewS2ARecordConn(t *testing.T) {
if got, want := conn.overheadSize, tc.outOverheadSize; got != want {
t.Errorf("conn.overheadSize = %v, want %v", got, want)
}
if got, want := conn.hsAddr, tc.outHandshakerServiceAddr; got != want {
t.Errorf("conn.HSAddr = %v, want %v", got, want)
ticketSender := conn.ticketSender.(*ticketSender)
if got, want := ticketSender.hsAddr, tc.outHandshakerServiceAddr; got != want {
t.Errorf("ticketSender.hsAddr = %v, want %v", got, want)
}
if got, want := ticketSender.connectionID, tc.outConnectionID; got != want {
t.Errorf("ticketSender.connectionID = %v, want %v", got, want)
}
if got, want := ticketSender.localIdentity, tc.outLocalIdentity; !cmp.Equal(got, want) {
t.Errorf("ticketSender.localIdentity = %v, want %v", got, want)
}
})
}
Expand Down Expand Up @@ -1161,6 +1214,7 @@ func TestConnNewSessionTicket(t *testing.T) {
outPlaintexts [][]byte
finalTicketState sessionTicketState
outSessionTickets [][]byte
ticketsSent bool
}{
// All the session tickets below are []byte{0}. This is not a valid
// ticket, but is sufficient for testing since the the client does not
Expand Down Expand Up @@ -1226,6 +1280,7 @@ func TestConnNewSessionTicket(t *testing.T) {
outSessionTickets: [][]byte{
{0},
},
ticketsSent: true,
},
{
desc: "AES-256-GCM-SHA384 new session ticket followed by application data",
Expand All @@ -1243,6 +1298,7 @@ func TestConnNewSessionTicket(t *testing.T) {
outSessionTickets: [][]byte{
{0},
},
ticketsSent: true,
},
{
desc: "CHACHA20-POLY1305-SHA256 new session ticket followed by application data",
Expand All @@ -1260,6 +1316,7 @@ func TestConnNewSessionTicket(t *testing.T) {
outSessionTickets: [][]byte{
{0},
},
ticketsSent: true,
},
{
desc: "AES-128-GCM-SHA256 ticket, application data, then another ticket",
Expand All @@ -1279,6 +1336,7 @@ func TestConnNewSessionTicket(t *testing.T) {
outSessionTickets: [][]byte{
{0},
},
ticketsSent: true,
},
{
desc: "AES-256-GCM-SHA384 ticket, application data, then another ticket",
Expand All @@ -1298,6 +1356,7 @@ func TestConnNewSessionTicket(t *testing.T) {
outSessionTickets: [][]byte{
{0},
},
ticketsSent: true,
},
{
desc: "CHACHA20-POLY1305-SHA256 ticket, application data, then another ticket",
Expand All @@ -1317,6 +1376,7 @@ func TestConnNewSessionTicket(t *testing.T) {
outSessionTickets: [][]byte{
{0},
},
ticketsSent: true,
},
// TODO(rnkim): Add test cases for handshake key update messages.
// Specifically, fragmented handshake messages and multiple handshake
Expand All @@ -1334,6 +1394,10 @@ func TestConnNewSessionTicket(t *testing.T) {
if err != nil {
t.Fatalf("NewConn() failed: %v", err)
}
newConn := c.(*conn)
// Replace the ticket sender with a fake.
fakeTicketSender := &fakeTicketSender{}
newConn.ticketSender = fakeTicketSender
for _, outPlaintext := range tc.outPlaintexts {
plaintext := make([]byte, tlsRecordMaxPlaintextSize)
n, err := c.Read(plaintext)
Expand All @@ -1348,7 +1412,6 @@ func TestConnNewSessionTicket(t *testing.T) {
t.Errorf("len(c.(*conn).pendingApplicationData) = %v, want %v", got, want)
}
}
newConn := c.(*conn)
if got, want := newConn.ticketState, tc.finalTicketState; got != want {
t.Errorf("newConn.ticketState = %v, want %v", got, want)
}
Expand All @@ -1358,6 +1421,11 @@ func TestConnNewSessionTicket(t *testing.T) {
if got, want := newConn.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) {
t.Errorf("newConn.sessionTickets = %v, want %v", got, want)
}
if tc.ticketsSent {
if got, want := fakeTicketSender.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) {
t.Errorf("fakeTicketSender.sessionTickets = %v, want %v", got, want)
}
}
})
}
}
Expand All @@ -1370,6 +1438,7 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) {
sessionTickets [][]byte
finalTicketState sessionTicketState
outSessionTickets [][]byte
ticketsSent bool
}{
{
desc: "AES-128-GCM-SHA256 consecutive tickets",
Expand Down Expand Up @@ -1566,6 +1635,72 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) {
[]byte("abc"),
},
},
{
desc: "AES-128-GCM-SHA256 past max limit",
ciphersuite: s2apb.Ciphersuite_AES_128_GCM_SHA256,
trafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
sessionTickets: [][]byte{
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
},
finalTicketState: notReceivingTickets,
outSessionTickets: [][]byte{
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
},
ticketsSent: true,
},
{
desc: "AES-256-GCM-SHA384 past max limit",
ciphersuite: s2apb.Ciphersuite_AES_256_GCM_SHA384,
trafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
sessionTickets: [][]byte{
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
},
finalTicketState: notReceivingTickets,
outSessionTickets: [][]byte{
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
},
ticketsSent: true,
},
{
desc: "CHACHA20-POLY1305-SHA256 past max limit",
ciphersuite: s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256,
trafficSecret: testutil.Dehex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"),
sessionTickets: [][]byte{
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
buildSessionTicket([]byte("abc")),
},
finalTicketState: notReceivingTickets,
outSessionTickets: [][]byte{
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
[]byte("abc"),
},
ticketsSent: true,
},
} {
t.Run(tc.desc, func(t *testing.T) {
fc := &fakeConn{}
Expand All @@ -1580,6 +1715,9 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) {
t.Fatalf("NewConn() failed: %v", err)
}
c := netConn.(*conn)
// Replace the ticket sender with a fake.
fakeTicketSender := &fakeTicketSender{}
c.ticketSender = fakeTicketSender
for _, sessionTicket := range tc.sessionTickets {
_, err = c.writeTLSRecord(sessionTicket, byte(handshake))
if err != nil {
Expand Down Expand Up @@ -1613,6 +1751,11 @@ func TestConnNewSessionTicketWithTicketBuilder(t *testing.T) {
if got, want := c.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) {
t.Errorf("newConn.sessionTickets = %v, want %v", got, want)
}
if tc.ticketsSent {
if got, want := fakeTicketSender.sessionTickets, tc.outSessionTickets; !cmp.Equal(got, want) {
t.Errorf("fakeTicketSender.sessionTickets = %v, want %v", got, want)
}
}
})
}
}
Expand Down
Loading

0 comments on commit a0f18c8

Please sign in to comment.