Skip to content

Commit

Permalink
Add support for sending session tickets to S2A
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryanfsdf committed Aug 13, 2020
1 parent fd854f8 commit 725135d
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 8 deletions.
2 changes: 2 additions & 0 deletions security/s2a/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.Se
InSequence: result.GetState().GetInSequence(),
OutSequence: result.GetState().GetOutSequence(),
HSAddr: h.hsAddr,
ConnectionID: result.GetState().GetConnectionId(),
LocalIdentity: result.GetLocalIdentity(),
})
if err != nil {
return nil, nil, err
Expand Down
103 changes: 98 additions & 5 deletions security/s2a/internal/record/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,18 @@
package record

import (
"context"
"encoding/binary"
"errors"
"fmt"
"math"
"net"
"sync"
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/security/s2a/internal/handshaker/service"
s2apb "google.golang.org/grpc/security/s2a/internal/proto"
"google.golang.org/grpc/security/s2a/internal/record/internal/halfconn"
)
Expand Down Expand Up @@ -180,8 +185,6 @@ type conn struct {
// is computed as overheadSize = header size + record type byte + tag size.
// Note that there is no padding by zeros in the overhead calculation.
overheadSize int
// hsAddr stores the address of the S2A handshaker service.
hsAddr string
// readMutex guards against concurrent calls to Read. This is required since
// Close may be called during a Read.
readMutex sync.Mutex
Expand All @@ -196,6 +199,8 @@ type conn struct {
// sessionTickets holds the completed session tickets until they are sent to
// the handshaker service for processing.
sessionTickets [][]byte
// ticketSender sends session tickets to the S2A handshaker service.
ticketSender s2aTicketSender
}

// ConnParameters holds the parameters used for creating a new conn object.
Expand Down Expand Up @@ -228,6 +233,12 @@ type ConnParameters struct {
// HSAddr stores the address of the S2A handshaker service. This parameter
// is optional. If not provided, then TLS resumption is disabled.
HSAddr string
// ConnectionId is the connection identifier that was created and sent by
// S2A at the end of a handshake.
ConnectionID uint64
// LocalIdentity is the local identity that was used by S2A during session
// setup and included in the session result.
LocalIdentity *s2apb.Identity
}

func NewConn(o *ConnParameters) (net.Conn, error) {
Expand Down Expand Up @@ -272,7 +283,6 @@ func NewConn(o *ConnParameters) (net.Conn, error) {
outRecordsBuf: make([]byte, tlsRecordMaxSize),
nextRecord: unusedBuf,
overheadSize: overheadSize,
hsAddr: o.HSAddr,
ticketState: ticketsNotYetReceived,
// Pre-allocate the buffer for one session ticket message and the max
// plaintext size. This is the largest size that handshakeBuf will need
Expand All @@ -283,6 +293,11 @@ func NewConn(o *ConnParameters) (net.Conn, error) {
// completed. Therefore, the buffer size below should be large enough to
// buffer any handshake messages.
handshakeBuf: make([]byte, 0, tlsHandshakePrefixSize+tlsMaxSessionTicketSize+tlsRecordMaxPlaintextSize-1),
ticketSender: &ticketSender{
hsAddr: o.HSAddr,
connectionID: o.ConnectionID,
localIdentity: o.LocalIdentity,
},
}
return s2aConn, nil
}
Expand Down Expand Up @@ -353,7 +368,7 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
if p.ticketState == receivingTickets {
p.ticketState = notReceivingTickets
// TODO: send tickets to handshaker
p.ticketSender.sendTicketsToS2A(p.sessionTickets)
}
case alert:
return 0, p.handleAlertMessage()
Expand Down Expand Up @@ -656,7 +671,7 @@ func (p *conn) handleHandshakeMessage() error {
p.sessionTickets = append(p.sessionTickets, msg)
if len(p.sessionTickets) == maxAllowedTickets {
p.ticketState = notReceivingTickets
// TODO: send tickets to handshaker
p.ticketSender.sendTicketsToS2A(p.sessionTickets)
}
default:
return errors.New("unknown handshake message type")
Expand Down Expand Up @@ -703,6 +718,84 @@ func (p *conn) handleKeyUpdateMsg(msg []byte) error {
return nil
}

// s2aTicketSender sends session tickets to the S2A handshaker service.
type s2aTicketSender interface {
// sendTicketsToS2A sends the given session tickets to the S2A handshaker
// service.
sendTicketsToS2A(sessionTickets [][]byte)
}

// ticketStream is the stream used to send and receive session information.
type ticketStream interface {
Send(*s2apb.SessionReq) error
Recv() (*s2apb.SessionResp, error)
}

type ticketSender struct {
// hsAddr stores the address of the S2A handshaker service.
hsAddr string
// connectionID is the connection identifier that was created and sent by
// S2A at the end of a handshake.
connectionID uint64
// localIdentity is the local identity that was used by S2A during session
// setup and included in the session result.
localIdentity *s2apb.Identity
}

// sendTicketsToS2A sends the given sessionTickets to the S2A handshaker
// service. This is done asynchronously and writes to the error logs if an error
// occurs.
func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte) {
go func() {
if err := func() error {
hsConn, err := service.Dial(t.hsAddr)
if err != nil {
return err
}
client := s2apb.NewS2AServiceClient(hsConn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
session, err := client.SetUpSession(ctx)
if err != nil {
return err
}
defer func() {
if err := session.CloseSend(); err != nil {
grpclog.Error(err)
}
}()
return t.writeTicketsToStream(session, sessionTickets)
}(); err != nil {
grpclog.Error(err)
}
}()
}

// writeTicketsToStream writes the given session tickets to the given stream.
func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error {
if err := stream.Send(
&s2apb.SessionReq{
ReqOneof: &s2apb.SessionReq_ResumptionTicket{
ResumptionTicket: &s2apb.ResumptionTicketReq{
InBytes: sessionTickets,
ConnectionId: t.connectionID,
LocalIdentity: t.localIdentity,
},
},
},
); err != nil {
return err
}
sessionResp, err := stream.Recv()
if err != nil {
return err
}
if sessionResp.GetStatus().GetCode() != uint32(codes.OK) {
return errors.New("s2a session ticket response was not OK")
}
return nil
}

// bidEndianInt24 converts the given byte buffer of at least size 3 and
// outputs the resulting 24 bit integer as a uint32. This is needed because
// TLS 1.3 requires 3 byte integers, and the binary.BigEndian package does
Expand Down
Loading

0 comments on commit 725135d

Please sign in to comment.