From 477a5b4c327a4fea3cab2fe127f89940289b65e5 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Fri, 24 Nov 2023 01:25:35 +0100 Subject: [PATCH] sha3: make APIs usable with zero allocations The "buf points into storage" pattern is nice, but causes the whole state struct to escape, since escape analysis can't track the pointer once it's assigned to buf. Change-Id: I31c0e83f946d66bedb5a180e96ab5d5e936eb322 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/544817 Reviewed-by: Cherry Mui LUCI-TryBot-Result: Go LUCI Reviewed-by: Roland Shoemaker Reviewed-by: Mauri de Souza Meneguzzo Auto-Submit: Filippo Valsorda --- sha3/allocations_test.go | 53 +++++++++++++++++++++++++++++++++++ sha3/sha3.go | 60 ++++++++++++++++------------------------ 2 files changed, 77 insertions(+), 36 deletions(-) create mode 100644 sha3/allocations_test.go diff --git a/sha3/allocations_test.go b/sha3/allocations_test.go new file mode 100644 index 0000000000..c925099304 --- /dev/null +++ b/sha3/allocations_test.go @@ -0,0 +1,53 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !noopt + +package sha3_test + +import ( + "testing" + + "golang.org/x/crypto/sha3" +) + +var sink byte + +func TestAllocations(t *testing.T) { + t.Run("New", func(t *testing.T) { + if allocs := testing.AllocsPerRun(10, func() { + h := sha3.New256() + b := []byte("ABC") + h.Write(b) + out := make([]byte, 0, 32) + out = h.Sum(out) + sink ^= out[0] + }); allocs > 0 { + t.Errorf("expected zero allocations, got %0.1f", allocs) + } + }) + t.Run("NewShake", func(t *testing.T) { + if allocs := testing.AllocsPerRun(10, func() { + h := sha3.NewShake128() + b := []byte("ABC") + h.Write(b) + out := make([]byte, 0, 32) + out = h.Sum(out) + sink ^= out[0] + h.Read(out) + sink ^= out[0] + }); allocs > 0 { + t.Errorf("expected zero allocations, got %0.1f", allocs) + } + }) + t.Run("Sum", func(t *testing.T) { + if allocs := testing.AllocsPerRun(10, func() { + b := []byte("ABC") + out := sha3.Sum256(b) + sink ^= out[0] + }); allocs > 0 { + t.Errorf("expected zero allocations, got %0.1f", allocs) + } + }) +} diff --git a/sha3/sha3.go b/sha3/sha3.go index 33bd73b0f6..afedde5abf 100644 --- a/sha3/sha3.go +++ b/sha3/sha3.go @@ -23,7 +23,6 @@ const ( type state struct { // Generic sponge components. a [25]uint64 // main state of the hash - buf []byte // points into storage rate int // the number of bytes of state to use // dsbyte contains the "domain separation" bits and the first bit of @@ -40,6 +39,7 @@ type state struct { // Extendable-Output Functions (May 2014)" dsbyte byte + i, n int // storage[i:n] is the buffer, i is only used while squeezing storage [maxRate]byte // Specific to SHA-3 and SHAKE. @@ -54,24 +54,18 @@ func (d *state) BlockSize() int { return d.rate } func (d *state) Size() int { return d.outputLen } // Reset clears the internal state by zeroing the sponge state and -// the byte buffer, and setting Sponge.state to absorbing. +// the buffer indexes, and setting Sponge.state to absorbing. func (d *state) Reset() { // Zero the permutation's state. for i := range d.a { d.a[i] = 0 } d.state = spongeAbsorbing - d.buf = d.storage[:0] + d.i, d.n = 0, 0 } func (d *state) clone() *state { ret := *d - if ret.state == spongeAbsorbing { - ret.buf = ret.storage[:len(ret.buf)] - } else { - ret.buf = ret.storage[d.rate-cap(d.buf) : d.rate] - } - return &ret } @@ -82,43 +76,40 @@ func (d *state) permute() { case spongeAbsorbing: // If we're absorbing, we need to xor the input into the state // before applying the permutation. - xorIn(d, d.buf) - d.buf = d.storage[:0] + xorIn(d, d.storage[:d.rate]) + d.n = 0 keccakF1600(&d.a) case spongeSqueezing: // If we're squeezing, we need to apply the permutation before // copying more output. keccakF1600(&d.a) - d.buf = d.storage[:d.rate] - copyOut(d, d.buf) + d.i = 0 + copyOut(d, d.storage[:d.rate]) } } // pads appends the domain separation bits in dsbyte, applies // the multi-bitrate 10..1 padding rule, and permutes the state. -func (d *state) padAndPermute(dsbyte byte) { - if d.buf == nil { - d.buf = d.storage[:0] - } +func (d *state) padAndPermute() { // Pad with this instance's domain-separator bits. We know that there's // at least one byte of space in d.buf because, if it were full, // permute would have been called to empty it. dsbyte also contains the // first one bit for the padding. See the comment in the state struct. - d.buf = append(d.buf, dsbyte) - zerosStart := len(d.buf) - d.buf = d.storage[:d.rate] - for i := zerosStart; i < d.rate; i++ { - d.buf[i] = 0 + d.storage[d.n] = d.dsbyte + d.n++ + for d.n < d.rate { + d.storage[d.n] = 0 + d.n++ } // This adds the final one bit for the padding. Because of the way that // bits are numbered from the LSB upwards, the final bit is the MSB of // the last byte. - d.buf[d.rate-1] ^= 0x80 + d.storage[d.rate-1] ^= 0x80 // Apply the permutation d.permute() d.state = spongeSqueezing - d.buf = d.storage[:d.rate] - copyOut(d, d.buf) + d.n = d.rate + copyOut(d, d.storage[:d.rate]) } // Write absorbs more data into the hash's state. It panics if any @@ -127,28 +118,25 @@ func (d *state) Write(p []byte) (written int, err error) { if d.state != spongeAbsorbing { panic("sha3: Write after Read") } - if d.buf == nil { - d.buf = d.storage[:0] - } written = len(p) for len(p) > 0 { - if len(d.buf) == 0 && len(p) >= d.rate { + if d.n == 0 && len(p) >= d.rate { // The fast path; absorb a full "rate" bytes of input and apply the permutation. xorIn(d, p[:d.rate]) p = p[d.rate:] keccakF1600(&d.a) } else { // The slow path; buffer the input until we can fill the sponge, and then xor it in. - todo := d.rate - len(d.buf) + todo := d.rate - d.n if todo > len(p) { todo = len(p) } - d.buf = append(d.buf, p[:todo]...) + d.n += copy(d.storage[d.n:], p[:todo]) p = p[todo:] // If the sponge is full, apply the permutation. - if len(d.buf) == d.rate { + if d.n == d.rate { d.permute() } } @@ -161,19 +149,19 @@ func (d *state) Write(p []byte) (written int, err error) { func (d *state) Read(out []byte) (n int, err error) { // If we're still absorbing, pad and apply the permutation. if d.state == spongeAbsorbing { - d.padAndPermute(d.dsbyte) + d.padAndPermute() } n = len(out) // Now, do the squeezing. for len(out) > 0 { - n := copy(out, d.buf) - d.buf = d.buf[n:] + n := copy(out, d.storage[d.i:d.n]) + d.i += n out = out[n:] // Apply the permutation if we've squeezed the sponge dry. - if len(d.buf) == 0 { + if d.i == d.rate { d.permute() } }