diff --git a/shadowaead/cipher.go b/shadowaead/cipher.go index 19410df9..e60fc54d 100644 --- a/shadowaead/cipher.go +++ b/shadowaead/cipher.go @@ -4,6 +4,7 @@ import ( "crypto/aes" "crypto/cipher" "crypto/sha1" + "errors" "io" "strconv" @@ -11,6 +12,9 @@ import ( "golang.org/x/crypto/hkdf" ) +// ErrRepeatedSalt means detected a reused salt +var ErrRepeatedSalt = errors.New("repeated salt detected") + type Cipher interface { KeySize() int SaltSize() int diff --git a/shadowaead/packet.go b/shadowaead/packet.go index ae5f84d4..6f48f14c 100644 --- a/shadowaead/packet.go +++ b/shadowaead/packet.go @@ -6,6 +6,8 @@ import ( "io" "net" "sync" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) // ErrShortPacket means that the packet is too short for a valid encrypted packet. @@ -27,6 +29,7 @@ func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) { if err != nil { return nil, err } + internal.AddSalt(salt) if len(dst) < saltSize+len(plaintext)+aead.Overhead() { return nil, io.ErrShortBuffer @@ -43,10 +46,14 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) { return nil, ErrShortPacket } salt := pkt[:saltSize] + if internal.TestSalt(salt) { + return nil, ErrRepeatedSalt + } aead, err := ciph.Decrypter(salt) if err != nil { return nil, err } + internal.AddSalt(salt) if len(pkt) < saltSize+aead.Overhead() { return nil, ErrShortPacket } diff --git a/shadowaead/stream.go b/shadowaead/stream.go index 5f499a21..a41e14ea 100644 --- a/shadowaead/stream.go +++ b/shadowaead/stream.go @@ -6,6 +6,8 @@ import ( "crypto/rand" "io" "net" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) // payloadSizeMask is the maximum size of payload in bytes. @@ -203,11 +205,14 @@ func (c *streamConn) initReader() error { if _, err := io.ReadFull(c.Conn, salt); err != nil { return err } - + if internal.TestSalt(salt) { + return ErrRepeatedSalt + } aead, err := c.Decrypter(salt) if err != nil { return err } + internal.AddSalt(salt) c.r = newReader(c.Conn, aead) return nil @@ -244,6 +249,7 @@ func (c *streamConn) initWriter() error { if err != nil { return err } + internal.AddSalt(salt) c.w = newWriter(c.Conn, aead) return nil } diff --git a/shadowstream/cipher.go b/shadowstream/cipher.go index fa916c9c..a0aedba7 100644 --- a/shadowstream/cipher.go +++ b/shadowstream/cipher.go @@ -3,12 +3,16 @@ package shadowstream import ( "crypto/aes" "crypto/cipher" + "errors" "strconv" "github.com/aead/chacha20" "github.com/aead/chacha20/chacha" ) +// ErrRepeatedSalt means detected a reused salt +var ErrRepeatedSalt = errors.New("repeated salt detected") + // Cipher generates a pair of stream ciphers for encryption and decryption. type Cipher interface { IVSize() int diff --git a/shadowstream/packet.go b/shadowstream/packet.go index 0defa110..4ae7ee43 100644 --- a/shadowstream/packet.go +++ b/shadowstream/packet.go @@ -6,6 +6,8 @@ import ( "io" "net" "sync" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) // ErrShortPacket means the packet is too short to be a valid encrypted packet. @@ -23,7 +25,7 @@ func Pack(dst, plaintext []byte, s Cipher) ([]byte, error) { if err != nil { return nil, err } - + internal.AddSalt(iv) s.Encrypter(iv).XORKeyStream(dst[len(iv):], plaintext) return dst[:len(iv)+len(plaintext)], nil } @@ -39,6 +41,10 @@ func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) { return nil, io.ErrShortBuffer } iv := pkt[:s.IVSize()] + if internal.TestSalt(iv) { + return nil, ErrRepeatedSalt + } + internal.AddSalt(iv) s.Decrypter(iv).XORKeyStream(dst, pkt[len(iv):]) return dst[:len(pkt)-len(iv)], nil } diff --git a/shadowstream/stream.go b/shadowstream/stream.go index eb4d9679..0cbf8fb2 100644 --- a/shadowstream/stream.go +++ b/shadowstream/stream.go @@ -6,6 +6,8 @@ import ( "crypto/rand" "io" "net" + + "github.com/shadowsocks/go-shadowsocks2/internal" ) const bufSize = 32 * 1024 @@ -114,6 +116,10 @@ func (c *conn) initReader() error { if _, err := io.ReadFull(c.Conn, iv); err != nil { return err } + if internal.TestSalt(iv) { + return ErrRepeatedSalt + } + internal.AddSalt(iv) c.r = &reader{Reader: c.Conn, Stream: c.Decrypter(iv), buf: buf} } return nil @@ -147,6 +153,7 @@ func (c *conn) initWriter() error { if _, err := c.Conn.Write(iv); err != nil { return err } + internal.AddSalt(iv) c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf} } return nil