diff --git a/go.mod b/go.mod index 7c066773..ec09b99e 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da + github.com/riobard/go-bloom v0.0.0-20200213042214-218e1707c495 golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 ) diff --git a/go.sum b/go.sum index 9bf41e35..5fea12d9 100644 --- a/go.sum +++ b/go.sum @@ -8,3 +8,7 @@ github.com/golang/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:5JyrLPvD/ZdaY github.com/golang/sys v0.0.0-20190412213103-97732733099d h1:blRtD+FQOxZ6P7jigy+HS0R8zyGOMOv8TET4wCpzVwM= github.com/golang/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= github.com/golang/text v0.3.0/go.mod h1:GUiq9pdJKRKKAZXiVgWFEvocYuREvC14NhI4OPgEjeE= +github.com/riobard/go-bloom v0.0.0-20170218180955-2b113c64a69b h1:H9yjH/g5w8MOPjQR2zMSP/Md1kKtj/33fIht9ChC2OU= +github.com/riobard/go-bloom v0.0.0-20170218180955-2b113c64a69b/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s= +github.com/riobard/go-bloom v0.0.0-20200213042214-218e1707c495 h1:p7xbxYTzzfXghR1kpsJDeoVVRRWAotKc8u7FP/N48rU= +github.com/riobard/go-bloom v0.0.0-20200213042214-218e1707c495/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s= diff --git a/internal/bloomring.go b/internal/bloomring.go new file mode 100644 index 00000000..7a603a40 --- /dev/null +++ b/internal/bloomring.go @@ -0,0 +1,67 @@ +package internal + +import ( + "hash/fnv" + "sync" + + "github.com/riobard/go-bloom" +) + +// simply use Double FNV here as our Bloom Filter hash +func doubleFNV(b []byte) (uint64, uint64) { + hx := fnv.New64() + hx.Write(b) + x := hx.Sum64() + hy := fnv.New64a() + hy.Write(b) + y := hy.Sum64() + return x, y +} + +type BloomRing struct { + slotCapacity int + slotPosition int + slotCount int + entryCounter int + slots []bloom.Filter + mutex sync.RWMutex +} + +func NewBloomRing(slot, capacity int, falsePositiveRate float64) *BloomRing { + // Calculate entries for each slot + r := &BloomRing{ + slotCapacity: capacity / slot, + slotCount: slot, + slots: make([]bloom.Filter, slot), + } + for i := 0; i < slot; i++ { + r.slots[i] = bloom.New(r.slotCapacity, falsePositiveRate, doubleFNV) + } + return r +} + +func (r *BloomRing) Add(b []byte) { + r.mutex.Lock() + defer r.mutex.Unlock() + slot := r.slots[r.slotPosition] + if r.entryCounter > r.slotCapacity { + // Move to next slot and reset + r.slotPosition = (r.slotPosition + 1) % r.slotCount + slot = r.slots[r.slotPosition] + slot.Reset() + r.entryCounter = 0 + } + r.entryCounter++ + slot.Add(b) +} + +func (r *BloomRing) Test(b []byte) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + for _, s := range r.slots { + if s.Test(b) { + return true + } + } + return false +} diff --git a/internal/bloomring_test.go b/internal/bloomring_test.go new file mode 100644 index 00000000..7f230215 --- /dev/null +++ b/internal/bloomring_test.go @@ -0,0 +1,68 @@ +package internal_test + +import ( + "fmt" + "os" + "testing" + + "github.com/shadowsocks/go-shadowsocks2/internal" +) + +var ( + bloomRingInstance *internal.BloomRing +) + +func TestMain(m *testing.M) { + bloomRingInstance = internal.NewBloomRing(internal.DefaultSFSlot, int(internal.DefaultSFCapacity), + internal.DefaultSFFPR) + os.Exit(m.Run()) +} + +func TestBloomRing_Add(t *testing.T) { + defer func() { + if any := recover(); any != nil { + t.Fatalf("Should not got panic while adding item: %v", any) + } + }() + bloomRingInstance.Add(make([]byte, 16)) +} + +func TestBloomRing_Test(t *testing.T) { + buf := []byte("shadowsocks") + bloomRingInstance.Add(buf) + if !bloomRingInstance.Test(buf) { + t.Fatal("Test on filter missing") + } +} + +func BenchmarkBloomRing(b *testing.B) { + // Generate test samples with different length + samples := make([][]byte, internal.DefaultSFCapacity-internal.DefaultSFSlot) + var checkPoints [][]byte + for i := 0; i < len(samples); i++ { + samples[i] = []byte(fmt.Sprint(i)) + if i%1000 == 0 { + checkPoints = append(checkPoints, samples[i]) + } + } + b.Logf("Generated %d samples and %d check points", len(samples), len(checkPoints)) + for i := 1; i < 16; i++ { + b.Run(fmt.Sprintf("Slot%d", i), benchmarkBloomRing(samples, checkPoints, i)) + } +} + +func benchmarkBloomRing(samples, checkPoints [][]byte, slot int) func(*testing.B) { + filter := internal.NewBloomRing(slot, int(internal.DefaultSFCapacity), internal.DefaultSFFPR) + for _, sample := range samples { + filter.Add(sample) + } + return func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, cp := range checkPoints { + filter.Test(cp) + } + } + } +} diff --git a/internal/saltfilter.go b/internal/saltfilter.go new file mode 100644 index 00000000..aed9b6e0 --- /dev/null +++ b/internal/saltfilter.go @@ -0,0 +1,80 @@ +package internal + +import ( + "fmt" + "os" + "strconv" +) + +// Those suggest value are all set according to +// https://github.com/shadowsocks/shadowsocks-org/issues/44#issuecomment-281021054 +// Due to this package contains various internal implementation so const named with DefaultBR prefix +const ( + DefaultSFCapacity = 1e6 + // FalsePositiveRate + DefaultSFFPR = 1e-6 + DefaultSFSlot = 10 +) + +const EnvironmentPrefix = "SHADOWSOCKS_" + +// A shared instance used for checking salt repeat +var saltfilter *BloomRing + +func init() { + var ( + finalCapacity = DefaultSFCapacity + finalFPR = DefaultSFFPR + finalSlot = float64(DefaultSFSlot) + ) + for _, opt := range []struct { + ENVName string + Target *float64 + }{ + { + ENVName: "CAPACITY", + Target: &finalCapacity, + }, + { + ENVName: "FPR", + Target: &finalFPR, + }, + { + ENVName: "SLOT", + Target: &finalSlot, + }, + } { + envKey := EnvironmentPrefix + "SF_" + opt.ENVName + env := os.Getenv(envKey) + if env != "" { + p, err := strconv.ParseFloat(env, 64) + if err != nil { + panic(fmt.Sprintf("Invalid envrionment `%s` setting in saltfilter: %s", envKey, env)) + } + *opt.Target = p + } + } + // Support disable saltfilter by given a negative capacity + if finalCapacity <= 0 { + return + } + saltfilter = NewBloomRing(int(finalSlot), int(finalCapacity), finalFPR) +} + +// TestSalt returns true if salt is repeated +func TestSalt(b []byte) bool { + // If nil means feature disabled, return false to bypass salt repeat detection + if saltfilter == nil { + return false + } + return saltfilter.Test(b) +} + +// AddSalt salt to filter +func AddSalt(b []byte) { + // If nil means feature disabled + if saltfilter == nil { + return + } + saltfilter.Add(b) +} diff --git a/shadowaead/cipher.go b/shadowaead/cipher.go index 19410df9..e60fc54d 100644 --- a/shadowaead/cipher.go +++ b/shadowaead/cipher.go @@ -4,6 +4,7 @@ import ( "crypto/aes" "crypto/cipher" "crypto/sha1" + "errors" "io" "strconv" @@ -11,6 +12,9 @@ import ( "golang.org/x/crypto/hkdf" ) +// ErrRepeatedSalt means detected a reused salt +var ErrRepeatedSalt = errors.New("repeated salt detected") + type Cipher interface { KeySize() int SaltSize() int diff --git a/shadowaead/packet.go b/shadowaead/packet.go index ae5f84d4..6f48f14c 100644 --- a/shadowaead/packet.go +++ b/shadowaead/packet.go @@ -6,6 +6,8 @@ import ( "io" "net" "sync" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) // ErrShortPacket means that the packet is too short for a valid encrypted packet. @@ -27,6 +29,7 @@ func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) { if err != nil { return nil, err } + internal.AddSalt(salt) if len(dst) < saltSize+len(plaintext)+aead.Overhead() { return nil, io.ErrShortBuffer @@ -43,10 +46,14 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) { return nil, ErrShortPacket } salt := pkt[:saltSize] + if internal.TestSalt(salt) { + return nil, ErrRepeatedSalt + } aead, err := ciph.Decrypter(salt) if err != nil { return nil, err } + internal.AddSalt(salt) if len(pkt) < saltSize+aead.Overhead() { return nil, ErrShortPacket } diff --git a/shadowaead/stream.go b/shadowaead/stream.go index 5f499a21..a41e14ea 100644 --- a/shadowaead/stream.go +++ b/shadowaead/stream.go @@ -6,6 +6,8 @@ import ( "crypto/rand" "io" "net" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) // payloadSizeMask is the maximum size of payload in bytes. @@ -203,11 +205,14 @@ func (c *streamConn) initReader() error { if _, err := io.ReadFull(c.Conn, salt); err != nil { return err } - + if internal.TestSalt(salt) { + return ErrRepeatedSalt + } aead, err := c.Decrypter(salt) if err != nil { return err } + internal.AddSalt(salt) c.r = newReader(c.Conn, aead) return nil @@ -244,6 +249,7 @@ func (c *streamConn) initWriter() error { if err != nil { return err } + internal.AddSalt(salt) c.w = newWriter(c.Conn, aead) return nil } diff --git a/shadowstream/cipher.go b/shadowstream/cipher.go index fa916c9c..a0aedba7 100644 --- a/shadowstream/cipher.go +++ b/shadowstream/cipher.go @@ -3,12 +3,16 @@ package shadowstream import ( "crypto/aes" "crypto/cipher" + "errors" "strconv" "github.com/aead/chacha20" "github.com/aead/chacha20/chacha" ) +// ErrRepeatedSalt means detected a reused salt +var ErrRepeatedSalt = errors.New("repeated salt detected") + // Cipher generates a pair of stream ciphers for encryption and decryption. type Cipher interface { IVSize() int diff --git a/shadowstream/packet.go b/shadowstream/packet.go index 0defa110..4ae7ee43 100644 --- a/shadowstream/packet.go +++ b/shadowstream/packet.go @@ -6,6 +6,8 @@ import ( "io" "net" "sync" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) // ErrShortPacket means the packet is too short to be a valid encrypted packet. @@ -23,7 +25,7 @@ func Pack(dst, plaintext []byte, s Cipher) ([]byte, error) { if err != nil { return nil, err } - + internal.AddSalt(iv) s.Encrypter(iv).XORKeyStream(dst[len(iv):], plaintext) return dst[:len(iv)+len(plaintext)], nil } @@ -39,6 +41,10 @@ func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) { return nil, io.ErrShortBuffer } iv := pkt[:s.IVSize()] + if internal.TestSalt(iv) { + return nil, ErrRepeatedSalt + } + internal.AddSalt(iv) s.Decrypter(iv).XORKeyStream(dst, pkt[len(iv):]) return dst[:len(pkt)-len(iv)], nil } diff --git a/shadowstream/stream.go b/shadowstream/stream.go index eb4d9679..0cbf8fb2 100644 --- a/shadowstream/stream.go +++ b/shadowstream/stream.go @@ -6,6 +6,8 @@ import ( "crypto/rand" "io" "net" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) const bufSize = 32 * 1024 @@ -114,6 +116,10 @@ func (c *conn) initReader() error { if _, err := io.ReadFull(c.Conn, iv); err != nil { return err } + if internal.TestSalt(iv) { + return ErrRepeatedSalt + } + internal.AddSalt(iv) c.r = &reader{Reader: c.Conn, Stream: c.Decrypter(iv), buf: buf} } return nil @@ -147,6 +153,7 @@ func (c *conn) initWriter() error { if _, err := c.Conn.Write(iv); err != nil { return err } + internal.AddSalt(iv) c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf} } return nil