diff --git a/internal/saltfilter.go b/internal/saltfilter.go index 8913da54..edb7e2d1 100644 --- a/internal/saltfilter.go +++ b/internal/saltfilter.go @@ -27,7 +27,7 @@ var initSaltfilterOnce sync.Once // GetSaltFilterSingleton returns the BloomRing singleton, // initializing it on first call. -func GetSaltFilterSingleton() *BloomRing { +func getSaltFilterSingleton() *BloomRing { initSaltfilterOnce.Do(func() { var ( finalCapacity = DefaultSFCapacity @@ -69,3 +69,13 @@ func GetSaltFilterSingleton() *BloomRing { }) return saltfilter } + +// TestSalt returns true if salt is repeated +func TestSalt(b []byte) bool { + return getSaltFilterSingleton().Test(b) +} + +// AddSalt salt to filter +func AddSalt(b []byte) { + getSaltFilterSingleton().Add(b) +} diff --git a/shadowaead/packet.go b/shadowaead/packet.go index 7329f2b0..6f48f14c 100644 --- a/shadowaead/packet.go +++ b/shadowaead/packet.go @@ -29,7 +29,7 @@ func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) { if err != nil { return nil, err } - internal.GetSaltFilterSingleton().Add(salt) + internal.AddSalt(salt) if len(dst) < saltSize+len(plaintext)+aead.Overhead() { return nil, io.ErrShortBuffer @@ -45,16 +45,15 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) { if len(pkt) < saltSize { return nil, ErrShortPacket } - saltfilter := internal.GetSaltFilterSingleton() salt := pkt[:saltSize] - if saltfilter.Test(salt) { + if internal.TestSalt(salt) { return nil, ErrRepeatedSalt } aead, err := ciph.Decrypter(salt) if err != nil { return nil, err } - saltfilter.Add(salt) + internal.AddSalt(salt) if len(pkt) < saltSize+aead.Overhead() { return nil, ErrShortPacket } diff --git a/shadowaead/stream.go b/shadowaead/stream.go index 1af2ea3f..a41e14ea 100644 --- a/shadowaead/stream.go +++ b/shadowaead/stream.go @@ -205,15 +205,14 @@ func (c *streamConn) initReader() error { if _, err := io.ReadFull(c.Conn, salt); err != nil { return err } - saltfilter := internal.GetSaltFilterSingleton() - if saltfilter.Test(salt) { + if internal.TestSalt(salt) { return ErrRepeatedSalt } aead, err := c.Decrypter(salt) if err != nil { return err } - saltfilter.Add(salt) + internal.AddSalt(salt) c.r = newReader(c.Conn, aead) return nil @@ -250,7 +249,7 @@ func (c *streamConn) initWriter() error { if err != nil { return err } - internal.GetSaltFilterSingleton().Add(salt) + internal.AddSalt(salt) c.w = newWriter(c.Conn, aead) return nil }