diff --git a/sha3/allocations_test.go b/sha3/allocations_test.go new file mode 100644 index 0000000000..c925099304 --- /dev/null +++ b/sha3/allocations_test.go @@ -0,0 +1,53 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !noopt + +package sha3_test + +import ( + "testing" + + "golang.org/x/crypto/sha3" +) + +var sink byte + +func TestAllocations(t *testing.T) { + t.Run("New", func(t *testing.T) { + if allocs := testing.AllocsPerRun(10, func() { + h := sha3.New256() + b := []byte("ABC") + h.Write(b) + out := make([]byte, 0, 32) + out = h.Sum(out) + sink ^= out[0] + }); allocs > 0 { + t.Errorf("expected zero allocations, got %0.1f", allocs) + } + }) + t.Run("NewShake", func(t *testing.T) { + if allocs := testing.AllocsPerRun(10, func() { + h := sha3.NewShake128() + b := []byte("ABC") + h.Write(b) + out := make([]byte, 0, 32) + out = h.Sum(out) + sink ^= out[0] + h.Read(out) + sink ^= out[0] + }); allocs > 0 { + t.Errorf("expected zero allocations, got %0.1f", allocs) + } + }) + t.Run("Sum", func(t *testing.T) { + if allocs := testing.AllocsPerRun(10, func() { + b := []byte("ABC") + out := sha3.Sum256(b) + sink ^= out[0] + }); allocs > 0 { + t.Errorf("expected zero allocations, got %0.1f", allocs) + } + }) +} diff --git a/sha3/sha3.go b/sha3/sha3.go index 33bd73b0f6..afedde5abf 100644 --- a/sha3/sha3.go +++ b/sha3/sha3.go @@ -23,7 +23,6 @@ const ( type state struct { // Generic sponge components. a [25]uint64 // main state of the hash - buf []byte // points into storage rate int // the number of bytes of state to use // dsbyte contains the "domain separation" bits and the first bit of @@ -40,6 +39,7 @@ type state struct { // Extendable-Output Functions (May 2014)" dsbyte byte + i, n int // storage[i:n] is the buffer, i is only used while squeezing storage [maxRate]byte // Specific to SHA-3 and SHAKE. @@ -54,24 +54,18 @@ func (d *state) BlockSize() int { return d.rate } func (d *state) Size() int { return d.outputLen } // Reset clears the internal state by zeroing the sponge state and -// the byte buffer, and setting Sponge.state to absorbing. +// the buffer indexes, and setting Sponge.state to absorbing. func (d *state) Reset() { // Zero the permutation's state. for i := range d.a { d.a[i] = 0 } d.state = spongeAbsorbing - d.buf = d.storage[:0] + d.i, d.n = 0, 0 } func (d *state) clone() *state { ret := *d - if ret.state == spongeAbsorbing { - ret.buf = ret.storage[:len(ret.buf)] - } else { - ret.buf = ret.storage[d.rate-cap(d.buf) : d.rate] - } - return &ret } @@ -82,43 +76,40 @@ func (d *state) permute() { case spongeAbsorbing: // If we're absorbing, we need to xor the input into the state // before applying the permutation. - xorIn(d, d.buf) - d.buf = d.storage[:0] + xorIn(d, d.storage[:d.rate]) + d.n = 0 keccakF1600(&d.a) case spongeSqueezing: // If we're squeezing, we need to apply the permutation before // copying more output. keccakF1600(&d.a) - d.buf = d.storage[:d.rate] - copyOut(d, d.buf) + d.i = 0 + copyOut(d, d.storage[:d.rate]) } } // pads appends the domain separation bits in dsbyte, applies // the multi-bitrate 10..1 padding rule, and permutes the state. -func (d *state) padAndPermute(dsbyte byte) { - if d.buf == nil { - d.buf = d.storage[:0] - } +func (d *state) padAndPermute() { // Pad with this instance's domain-separator bits. We know that there's // at least one byte of space in d.buf because, if it were full, // permute would have been called to empty it. dsbyte also contains the // first one bit for the padding. See the comment in the state struct. - d.buf = append(d.buf, dsbyte) - zerosStart := len(d.buf) - d.buf = d.storage[:d.rate] - for i := zerosStart; i < d.rate; i++ { - d.buf[i] = 0 + d.storage[d.n] = d.dsbyte + d.n++ + for d.n < d.rate { + d.storage[d.n] = 0 + d.n++ } // This adds the final one bit for the padding. Because of the way that // bits are numbered from the LSB upwards, the final bit is the MSB of // the last byte. - d.buf[d.rate-1] ^= 0x80 + d.storage[d.rate-1] ^= 0x80 // Apply the permutation d.permute() d.state = spongeSqueezing - d.buf = d.storage[:d.rate] - copyOut(d, d.buf) + d.n = d.rate + copyOut(d, d.storage[:d.rate]) } // Write absorbs more data into the hash's state. It panics if any @@ -127,28 +118,25 @@ func (d *state) Write(p []byte) (written int, err error) { if d.state != spongeAbsorbing { panic("sha3: Write after Read") } - if d.buf == nil { - d.buf = d.storage[:0] - } written = len(p) for len(p) > 0 { - if len(d.buf) == 0 && len(p) >= d.rate { + if d.n == 0 && len(p) >= d.rate { // The fast path; absorb a full "rate" bytes of input and apply the permutation. xorIn(d, p[:d.rate]) p = p[d.rate:] keccakF1600(&d.a) } else { // The slow path; buffer the input until we can fill the sponge, and then xor it in. - todo := d.rate - len(d.buf) + todo := d.rate - d.n if todo > len(p) { todo = len(p) } - d.buf = append(d.buf, p[:todo]...) + d.n += copy(d.storage[d.n:], p[:todo]) p = p[todo:] // If the sponge is full, apply the permutation. - if len(d.buf) == d.rate { + if d.n == d.rate { d.permute() } } @@ -161,19 +149,19 @@ func (d *state) Write(p []byte) (written int, err error) { func (d *state) Read(out []byte) (n int, err error) { // If we're still absorbing, pad and apply the permutation. if d.state == spongeAbsorbing { - d.padAndPermute(d.dsbyte) + d.padAndPermute() } n = len(out) // Now, do the squeezing. for len(out) > 0 { - n := copy(out, d.buf) - d.buf = d.buf[n:] + n := copy(out, d.storage[d.i:d.n]) + d.i += n out = out[n:] // Apply the permutation if we've squeezed the sponge dry. - if len(d.buf) == 0 { + if d.i == d.rate { d.permute() } }