From a9297ed029e7bff4b4451244b04c13115c31ed41 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 30 Sep 2020 12:56:54 -0400 Subject: [PATCH] Initialize saltfilter on demand --- internal/bloomring.go | 6 +++ internal/bloomring_test.go | 17 +++++++ internal/saltfilter.go | 99 +++++++++++++++++--------------------- shadowaead/packet.go | 5 +- shadowaead/stream.go | 5 +- 5 files changed, 74 insertions(+), 58 deletions(-) diff --git a/internal/bloomring.go b/internal/bloomring.go index 7a603a40..c89a4e92 100644 --- a/internal/bloomring.go +++ b/internal/bloomring.go @@ -41,6 +41,9 @@ func NewBloomRing(slot, capacity int, falsePositiveRate float64) *BloomRing { } func (r *BloomRing) Add(b []byte) { + if r == nil { + return + } r.mutex.Lock() defer r.mutex.Unlock() slot := r.slots[r.slotPosition] @@ -56,6 +59,9 @@ func (r *BloomRing) Add(b []byte) { } func (r *BloomRing) Test(b []byte) bool { + if r == nil { + return false + } r.mutex.RLock() defer r.mutex.RUnlock() for _, s := range r.slots { diff --git a/internal/bloomring_test.go b/internal/bloomring_test.go index 7f230215..f2d4319c 100644 --- a/internal/bloomring_test.go +++ b/internal/bloomring_test.go @@ -27,6 +27,16 @@ func TestBloomRing_Add(t *testing.T) { bloomRingInstance.Add(make([]byte, 16)) } +func TestBloomRing_NilAdd(t *testing.T) { + defer func() { + if any := recover(); any != nil { + t.Fatalf("Should not got panic while adding item: %v", any) + } + }() + var nilRing *internal.BloomRing + nilRing.Add(make([]byte, 16)) +} + func TestBloomRing_Test(t *testing.T) { buf := []byte("shadowsocks") bloomRingInstance.Add(buf) @@ -35,6 +45,13 @@ func TestBloomRing_Test(t *testing.T) { } } +func TestBloomRing_NilTestIsFalse(t *testing.T) { + var nilRing *internal.BloomRing + if nilRing.Test([]byte("shadowsocks")) { + t.Fatal("Test should return false for nil BloomRing") + } +} + func BenchmarkBloomRing(b *testing.B) { // Generate test samples with different length samples := make([][]byte, internal.DefaultSFCapacity-internal.DefaultSFSlot) diff --git a/internal/saltfilter.go b/internal/saltfilter.go index aed9b6e0..8913da54 100644 --- a/internal/saltfilter.go +++ b/internal/saltfilter.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strconv" + "sync" ) // Those suggest value are all set according to @@ -21,60 +22,50 @@ const EnvironmentPrefix = "SHADOWSOCKS_" // A shared instance used for checking salt repeat var saltfilter *BloomRing -func init() { - var ( - finalCapacity = DefaultSFCapacity - finalFPR = DefaultSFFPR - finalSlot = float64(DefaultSFSlot) - ) - for _, opt := range []struct { - ENVName string - Target *float64 - }{ - { - ENVName: "CAPACITY", - Target: &finalCapacity, - }, - { - ENVName: "FPR", - Target: &finalFPR, - }, - { - ENVName: "SLOT", - Target: &finalSlot, - }, - } { - envKey := EnvironmentPrefix + "SF_" + opt.ENVName - env := os.Getenv(envKey) - if env != "" { - p, err := strconv.ParseFloat(env, 64) - if err != nil { - panic(fmt.Sprintf("Invalid envrionment `%s` setting in saltfilter: %s", envKey, env)) +// Used to initialize the saltfilter singleton only once. +var initSaltfilterOnce sync.Once + +// GetSaltFilterSingleton returns the BloomRing singleton, +// initializing it on first call. +func GetSaltFilterSingleton() *BloomRing { + initSaltfilterOnce.Do(func() { + var ( + finalCapacity = DefaultSFCapacity + finalFPR = DefaultSFFPR + finalSlot = float64(DefaultSFSlot) + ) + for _, opt := range []struct { + ENVName string + Target *float64 + }{ + { + ENVName: "CAPACITY", + Target: &finalCapacity, + }, + { + ENVName: "FPR", + Target: &finalFPR, + }, + { + ENVName: "SLOT", + Target: &finalSlot, + }, + } { + envKey := EnvironmentPrefix + "SF_" + opt.ENVName + env := os.Getenv(envKey) + if env != "" { + p, err := strconv.ParseFloat(env, 64) + if err != nil { + panic(fmt.Sprintf("Invalid envrionment `%s` setting in saltfilter: %s", envKey, env)) + } + *opt.Target = p } - *opt.Target = p } - } - // Support disable saltfilter by given a negative capacity - if finalCapacity <= 0 { - return - } - saltfilter = NewBloomRing(int(finalSlot), int(finalCapacity), finalFPR) -} - -// TestSalt returns true if salt is repeated -func TestSalt(b []byte) bool { - // If nil means feature disabled, return false to bypass salt repeat detection - if saltfilter == nil { - return false - } - return saltfilter.Test(b) -} - -// AddSalt salt to filter -func AddSalt(b []byte) { - // If nil means feature disabled - if saltfilter == nil { - return - } - saltfilter.Add(b) + // Support disable saltfilter by given a negative capacity + if finalCapacity <= 0 { + return + } + saltfilter = NewBloomRing(int(finalSlot), int(finalCapacity), finalFPR) + }) + return saltfilter } diff --git a/shadowaead/packet.go b/shadowaead/packet.go index 6f48f14c..8893fd4f 100644 --- a/shadowaead/packet.go +++ b/shadowaead/packet.go @@ -45,15 +45,16 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) { if len(pkt) < saltSize { return nil, ErrShortPacket } + saltfilter := internal.GetSaltFilterSingleton() salt := pkt[:saltSize] - if internal.TestSalt(salt) { + if saltfilter.Test(salt) { return nil, ErrRepeatedSalt } aead, err := ciph.Decrypter(salt) if err != nil { return nil, err } - internal.AddSalt(salt) + saltfilter.Add(salt) if len(pkt) < saltSize+aead.Overhead() { return nil, ErrShortPacket } diff --git a/shadowaead/stream.go b/shadowaead/stream.go index a41e14ea..82b84ac2 100644 --- a/shadowaead/stream.go +++ b/shadowaead/stream.go @@ -205,14 +205,15 @@ func (c *streamConn) initReader() error { if _, err := io.ReadFull(c.Conn, salt); err != nil { return err } - if internal.TestSalt(salt) { + saltfilter := internal.GetSaltFilterSingleton() + if saltfilter.Test(salt) { return ErrRepeatedSalt } aead, err := c.Decrypter(salt) if err != nil { return err } - internal.AddSalt(salt) + saltfilter.Add(salt) c.r = newReader(c.Conn, aead) return nil