Skip to content

Commit

Permalink
Add locks to Read and Write (matthewstevenson88#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryanfsdf authored Jul 31, 2020
1 parent eb4811d commit 6c462b1
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions security/s2a/internal/record/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"fmt"
"math"
"net"
"sync"

"google.golang.org/grpc/grpclog"
s2apb "google.golang.org/grpc/security/s2a/internal/proto"
Expand Down Expand Up @@ -161,6 +162,13 @@ type conn struct {
overheadSize int
// hsAddr stores the address of the S2A handshaker service.
hsAddr string
// readMutex guards against concurrent calls to Read. This is required since
// Close may be called during a Read.
readMutex sync.Mutex
// writeMutex guards against concurrent calls to Write. This is required
// since Close may be called during a Write, and also because a key update
// message may be written during a Read.
writeMutex sync.Mutex
}

// ConnParameters holds the parameters used for creating a new conn object.
Expand Down Expand Up @@ -248,6 +256,8 @@ func NewConn(o *ConnParameters) (net.Conn, error) {
// that the user should close the connection via Close() if an error is thrown
// by a call to Read.
func (p *conn) Read(b []byte) (n int, err error) {
p.readMutex.Lock()
defer p.readMutex.Unlock()
// Check if p.pendingApplication data has leftover application data from
// the previous call to Read.
if len(p.pendingApplicationData) == 0 {
Expand Down Expand Up @@ -324,7 +334,8 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
// Send a key update message back to the peer if requested.
if keyUpdateRequest == byte(updateRequested) {
// TODO: lock before sending and updating.
p.writeMutex.Lock()
defer p.writeMutex.Unlock()
n, err := p.writeTLSRecord(preConstructedKeyUpdateMsg, byte(handshake))
if err != nil {
return 0, err
Expand Down Expand Up @@ -362,8 +373,9 @@ func (p *conn) Read(b []byte) (n int, err error) {
// the record to the peer. It returns the number of plaintext bytes that were
// successfully sent to the peer.
func (p *conn) Write(b []byte) (n int, err error) {
p.writeMutex.Lock()
defer p.writeMutex.Unlock()
return p.writeTLSRecord(b, tlsApplicationData)

}

// writeTLSRecord divides b into segments of size maxPlaintextBytesPerRecord,
Expand Down Expand Up @@ -456,7 +468,10 @@ func (p *conn) buildRecord(plaintext []byte, recordType byte, recordStartIndex i
}

func (p *conn) Close() error {
// TODO: Implement close with locks.
p.readMutex.Lock()
defer p.readMutex.Unlock()
p.writeMutex.Lock()
defer p.writeMutex.Unlock()
return p.Conn.Close()
}

Expand Down

0 comments on commit 6c462b1

Please sign in to comment.