Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions brontide/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package brontide

import (
"bytes"
"io"
"math"
"math/rand"
"testing"
Expand Down Expand Up @@ -63,3 +64,47 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
}
require.NoError(t, benchErr)
}

// BenchmarkWriteMessage benchmarks the performance of writing a maximum-sized
// message and flushing it to an io.Discard to measure the allocation and CPU
// overhead of the encryption and writing logic.
func BenchmarkWriteMessage(b *testing.B) {
localConn, remoteConn, err := establishTestConnection(b)
require.NoError(b, err, "unable to establish test connection: %v", err)

noiseLocalConn, ok := localConn.(*Conn)
require.True(b, ok, "expected *Conn type for localConn")

// Create the largest possible message we can write (MaxUint16 bytes).
// This is the maximum message size allowed by the protocol.
const maxMsgSize = math.MaxUint16
largeMsg := bytes.Repeat([]byte("a"), maxMsgSize)

// Use io.Discard to simulate writing to a network connection that
// continuously accepts data without needing resets.
discard := io.Discard

b.ReportAllocs()
b.ResetTimer()

for i := 0; i < b.N; i++ {
// Write our massive message, then call flush to actually write
// the encrypted message This simulates a full write operation
// to a network.
err := noiseLocalConn.noise.WriteMessage(largeMsg)
if err != nil {
b.Fatalf("WriteMessage failed: %v", err)
}
_, err = noiseLocalConn.noise.Flush(discard)
if err != nil {
b.Fatalf("Flush failed: %v", err)
}
}

// We'll make sure to clean up the connections at the end of the
// benchmark.
b.Cleanup(func() {
localConn.Close()
remoteConn.Close()
})
}
7 changes: 7 additions & 0 deletions brontide/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,10 @@ func (c *Conn) RemotePub() *btcec.PublicKey {
func (c *Conn) LocalPub() *btcec.PublicKey {
return c.noise.localStatic.PubKey()
}

// ClearPendingSend drops references to the next header and body buffers and
// returns any pooled buffers back to their respective pools so that the memory
// can be reused.
func (c *Conn) ClearPendingSend() {
c.noise.releaseBuffers()
}
98 changes: 84 additions & 14 deletions brontide/noise.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"fmt"
"io"
"math"
"sync"
"time"

"github.com/btcsuite/btcd/btcec/v2"
Expand Down Expand Up @@ -35,6 +36,10 @@
// header and it's MAC.
encHeaderSize = lengthHeaderSize + macSize

// maxMessageSize is the maximum size of an encrypted message including
// the MAC. This is the max payload (65535) plus the MAC size (16).
maxMessageSize = math.MaxUint16 + macSize

// keyRotationInterval is the number of messages sent on a single
// cipher stream before the keys are rotated forwards.
keyRotationInterval = 1000
Expand Down Expand Up @@ -65,9 +70,25 @@
ephemeralGen = func() (*btcec.PrivateKey, error) {
return btcec.NewPrivateKey()
}
)

// TODO(roasbeef): free buffer pool?
// headerBufferPool is a pool for encrypted header buffers.
headerBufferPool = &sync.Pool{
New: func() interface{} {
b := make([]byte, 0, encHeaderSize)
return &b
},
}

// bodyBufferPool is a pool for encrypted message body buffers.
bodyBufferPool = &sync.Pool{
New: func() interface{} {
// Allocate max size to avoid reallocation.
// maxMessageSize already includes the MAC.
b := make([]byte, 0, maxMessageSize)
return &b
},
}
)

// ecdh performs an ECDH operation between pub and priv. The returned value is
// the sha256 of the compressed shared point.
Expand All @@ -87,6 +108,9 @@
// TODO(roasbeef): this should actually be 96 bit
nonce uint64

// nonceBuffer is a reusable buffer for the nonce to avoid allocations.
nonceBuffer [12]byte

// secretKey is the shared symmetric key which will be used to
// instantiate the cipher.
//
Expand All @@ -113,10 +137,12 @@
}
}()

var nonce [12]byte
binary.LittleEndian.PutUint64(nonce[4:], c.nonce)
// Write the nonce counter to the buffer (bytes 4-11).
binary.LittleEndian.PutUint64(c.nonceBuffer[4:], c.nonce)

return c.cipher.Seal(cipherText, nonce[:], plainText, associatedData)
return c.cipher.Seal(
cipherText, c.nonceBuffer[:], plainText, associatedData,
)
}

// Decrypt attempts to decrypt the passed ciphertext observing the specified
Expand All @@ -131,10 +157,12 @@
}
}()

var nonce [12]byte
binary.LittleEndian.PutUint64(nonce[4:], c.nonce)
// Write the nonce counter to the buffer (bytes 4-11).
binary.LittleEndian.PutUint64(c.nonceBuffer[4:], c.nonce)

return c.cipher.Open(plainText, nonce[:], cipherText, associatedData)
return c.cipher.Open(
plainText, c.nonceBuffer[:], cipherText, associatedData,
)
}

// InitializeKey initializes the secret key and AEAD cipher scheme based off of
Expand Down Expand Up @@ -374,6 +402,9 @@
// (of the next ciphertext), followed by a 16 byte MAC.
nextCipherHeader [encHeaderSize]byte

// pktLenBuffer is a reusable buffer for encoding the packet length.
pktLenBuffer [lengthHeaderSize]byte

// nextHeaderSend holds a reference to the remaining header bytes to
// write out for a pending message. This allows us to tolerate timeout
// errors that cause partial writes.
Expand All @@ -383,6 +414,14 @@
// out for a pending message. This allows us to tolerate timeout errors
// that cause partial writes.
nextBodySend []byte

// pooledHeaderBuf is the pooled buffer used for the header, which we
// need to track so we can return it to the pool when done.
pooledHeaderBuf *[]byte

// pooledBodyBuf is the pooled buffer used for the body, which we need
// to track so we can return it to the pool when done.
pooledBodyBuf *[]byte
}

// NewBrontideMachine creates a new instance of the brontide state-machine. If
Expand Down Expand Up @@ -740,14 +779,20 @@
// NOT include the MAC.
fullLength := uint16(len(p))

var pktLen [2]byte
binary.BigEndian.PutUint16(pktLen[:], fullLength)
binary.BigEndian.PutUint16(b.pktLenBuffer[:], fullLength)

b.pooledHeaderBuf = headerBufferPool.Get().(*[]byte)

Check failure on line 784 in brontide/noise.go

View workflow job for this annotation

GitHub Actions / Lint code

type assertion must be checked (forcetypeassert)
b.pooledBodyBuf = bodyBufferPool.Get().(*[]byte)

Check failure on line 785 in brontide/noise.go

View workflow job for this annotation

GitHub Actions / Lint code

type assertion must be checked (forcetypeassert)

// First, generate the encrypted+MAC'd length prefix for the packet.
b.nextHeaderSend = b.sendCipher.Encrypt(nil, nil, pktLen[:])
// First, generate the encrypted+MAC'd length prefix for the packet. We
// pass our pooled buffer as the cipherText (dst) parameter.
b.nextHeaderSend = b.sendCipher.Encrypt(
nil, *b.pooledHeaderBuf, b.pktLenBuffer[:],
)

// Finally, generate the encrypted packet itself.
b.nextBodySend = b.sendCipher.Encrypt(nil, nil, p)
// Finally, generate the encrypted packet itself. We pass our pooled
// buffer as the cipherText (dst) parameter.
b.nextBodySend = b.sendCipher.Encrypt(nil, *b.pooledBodyBuf, p)

return nil
}
Expand Down Expand Up @@ -824,9 +869,34 @@
}
}

// If both header and body have been fully flushed, release the pooled
// buffers back to their pools.
if len(b.nextHeaderSend) == 0 && len(b.nextBodySend) == 0 {
b.releaseBuffers()
}

return nn, nil
}

// releaseBuffers returns the pooled buffers back to their respective pools
// and clears the references.
func (b *Machine) releaseBuffers() {
if b.pooledHeaderBuf != nil {
*b.pooledHeaderBuf = (*b.pooledHeaderBuf)[:0]
headerBufferPool.Put(b.pooledHeaderBuf)
b.pooledHeaderBuf = nil
}

if b.pooledBodyBuf != nil {
*b.pooledBodyBuf = (*b.pooledBodyBuf)[:0]
bodyBufferPool.Put(b.pooledBodyBuf)
b.pooledBodyBuf = nil
}

b.nextHeaderSend = nil
b.nextBodySend = nil
}

// ReadMessage attempts to read the next message from the passed io.Reader. In
// the case of an authentication error, a non-nil error is returned.
func (b *Machine) ReadMessage(r io.Reader) ([]byte, error) {
Expand Down
Loading
Loading