Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent replays of server data #78

Merged
merged 15 commits into from
Aug 27, 2020
134 changes: 105 additions & 29 deletions shadowsocks/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,99 @@ import (
// payloadSizeMask is the maximum size of payload in bytes.
const payloadSizeMask = 0x3FFF // 16*1024 - 1

// SaltGenerator generates unique salts to use in Shadowsocks connections.
type SaltGenerator interface {
// Returns a new salt
GetSalt() ([]byte, error)
}

// randomSaltGenerator generates a new random salt.
type randomSaltGenerator struct {
saltSize int
}

// ServerSaltGenerator generates unique salts that are secretly marked.
type ServerSaltGenerator struct {
bemasc marked this conversation as resolved.
Show resolved Hide resolved
saltSize int
encrypter cipher.AEAD
}

// GetSalt outputs a random salt.
func (sg randomSaltGenerator) GetSalt() ([]byte, error) {
bemasc marked this conversation as resolved.
Show resolved Hide resolved
salt := make([]byte, sg.saltSize)
_, err := io.ReadFull(rand.Reader, salt)
return salt, err
}

// Number of bytes of salt to use as a marker.
const markLen = 4

// Constant to identify this marking scheme.
var serverIndication = []byte("outline-salt-mark")

func NewServerSaltGenerator(cipher shadowaead.Cipher) (ServerSaltGenerator, error) {
saltSize := cipher.SaltSize()
zerosalt := make([]byte, saltSize)
encrypter, err := cipher.Encrypter(zerosalt)
bemasc marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return ServerSaltGenerator{}, err
}
return ServerSaltGenerator{saltSize, encrypter}, nil
}

func (sg ServerSaltGenerator) splitSalt(salt []byte) (prefix, mark []byte) {
prefixLen := sg.saltSize - markLen
prefix = salt[:prefixLen]
mark = salt[prefixLen:]
return
}

// getTag takes in a salt prefix and writes out the tag.
// len(prefix) must be saltSize - markLen
func (sg ServerSaltGenerator) getTag(prefix []byte) []byte {
nonce := make([]byte, sg.encrypter.NonceSize())
n := copy(nonce, prefix)
plaintext := prefix[n:]
encrypted := sg.encrypter.Seal(nil, nonce, plaintext, serverIndication)
return encrypted[len(plaintext):]
}

// GetSalt returns an apparently random salt that can be identified
// as server-originated by anyone who knows the Shadowsocks key.
func (sg ServerSaltGenerator) GetSalt() ([]byte, error) {
salt := make([]byte, sg.saltSize)
prefix, mark := sg.splitSalt(salt)
_, err := io.ReadFull(rand.Reader, prefix)
if err != nil {
return nil, err
}
tag := sg.getTag(prefix)
copy(mark, tag)
return salt, nil
}

// IsMarked returns true if the salt is marked as server-originated.
func (sg ServerSaltGenerator) IsMarked(salt []byte) bool {
bemasc marked this conversation as resolved.
Show resolved Hide resolved
prefix, mark := sg.splitSalt(salt)
tag := sg.getTag(prefix)
return bytes.Equal(tag[:markLen], mark)
}

// Writer is an io.Writer that also implements io.ReaderFrom to
// allow for piping the data without extra allocations and copies.
// The LazyWrite and Flush methods allow a header to be
// added but delayed until the first write, for concatenation.
// All methods except Flush must be called from a single thread.
type Writer interface {
io.Writer
io.ReaderFrom
// LazyWrite queues p to be written, but doesn't send it until
// Flush() is called, a non-lazy write is made, or the buffer
// is filled.
LazyWrite(p []byte) (int, error)
// Flush sends the pending data, if any. This method is
// thread-safe.
Flush() error
}

type shadowsocksWriter struct {
type Writer struct {
// This type is single-threaded except when needFlush is true.
// mu protects needFlush, and also protects everything
// else while needFlush could be true.
mu sync.Mutex
// Indicates that a concurrent flush is currently allowed.
needFlush bool
writer io.Writer
ssCipher shadowaead.Cipher
needFlush bool
writer io.Writer
ssCipher shadowaead.Cipher
saltGenerator SaltGenerator
// Wrapper for input that arrives as a slice.
byteWrapper bytes.Reader
// Number of plaintext bytes that are currently buffered.
Expand All @@ -68,16 +135,21 @@ type shadowsocksWriter struct {

// NewShadowsocksWriter creates a Writer that encrypts the given Writer using
// the shadowsocks protocol with the given shadowsocks cipher.
func NewShadowsocksWriter(writer io.Writer, ssCipher shadowaead.Cipher) Writer {
return &shadowsocksWriter{writer: writer, ssCipher: ssCipher}
func NewShadowsocksWriter(writer io.Writer, ssCipher shadowaead.Cipher) *Writer {
return &Writer{writer: writer, ssCipher: ssCipher, saltGenerator: randomSaltGenerator{ssCipher.SaltSize()}}
}

// SetSaltGenerator sets the salt generator to be used. Must be called before the first write.
func (sw *Writer) SetSaltGenerator(saltGenerator SaltGenerator) {
sw.saltGenerator = saltGenerator
bemasc marked this conversation as resolved.
Show resolved Hide resolved
}

// init generates a random salt, sets up the AEAD object and writes
// the salt to the inner Writer.
func (sw *shadowsocksWriter) init() (err error) {
func (sw *Writer) init() (err error) {
if sw.aead == nil {
salt := make([]byte, sw.ssCipher.SaltSize())
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
salt, err := sw.saltGenerator.GetSalt()
if err != nil {
return fmt.Errorf("failed to generate salt: %v", err)
}
sw.aead, err = sw.ssCipher.Encrypter(salt)
Expand All @@ -98,19 +170,21 @@ func (sw *shadowsocksWriter) init() (err error) {

// encryptBlock encrypts `plaintext` in-place. The slice must have enough capacity
// for the tag. Returns the total ciphertext length.
func (sw *shadowsocksWriter) encryptBlock(plaintext []byte) int {
func (sw *Writer) encryptBlock(plaintext []byte) int {
out := sw.aead.Seal(plaintext[:0], sw.counter, plaintext, nil)
increment(sw.counter)
return len(out)
}

func (sw *shadowsocksWriter) Write(p []byte) (int, error) {
func (sw *Writer) Write(p []byte) (int, error) {
sw.byteWrapper.Reset(p)
n, err := sw.ReadFrom(&sw.byteWrapper)
return int(n), err
}

func (sw *shadowsocksWriter) LazyWrite(p []byte) (int, error) {
// LazyWrite queues p to be written, but doesn't send it until Flush() is
// called, a non-lazy write is made, or the buffer is filled.
func (sw *Writer) LazyWrite(p []byte) (int, error) {
if err := sw.init(); err != nil {
return 0, err
}
Expand All @@ -137,7 +211,8 @@ func (sw *shadowsocksWriter) LazyWrite(p []byte) (int, error) {
}
}

func (sw *shadowsocksWriter) Flush() error {
// Flush sends the pending data, if any. This method is thread-safe.
func (sw *Writer) Flush() error {
sw.mu.Lock()
defer sw.mu.Unlock()
if !sw.needFlush {
Expand All @@ -156,7 +231,7 @@ func isZero(b []byte) bool {
}

// Returns the slices of sw.buf in which to place plaintext for encryption.
func (sw *shadowsocksWriter) buffers() (sizeBuf, payloadBuf []byte) {
func (sw *Writer) buffers() (sizeBuf, payloadBuf []byte) {
// sw.buf starts with the salt.
saltSize := sw.ssCipher.SaltSize()

Expand All @@ -168,7 +243,8 @@ func (sw *shadowsocksWriter) buffers() (sizeBuf, payloadBuf []byte) {
return
}

func (sw *shadowsocksWriter) ReadFrom(r io.Reader) (int64, error) {
// ReadFrom implements the io.ReaderFrom interface.
func (sw *Writer) ReadFrom(r io.Reader) (int64, error) {
if err := sw.init(); err != nil {
return 0, err
}
Expand Down Expand Up @@ -218,15 +294,15 @@ func (sw *shadowsocksWriter) ReadFrom(r io.Reader) (int64, error) {

// Adds as much of `plaintext` into the buffer as will fit, and increases
// sw.pending accordingly. Returns the number of bytes consumed.
func (sw *shadowsocksWriter) enqueue(plaintext []byte) int {
func (sw *Writer) enqueue(plaintext []byte) int {
_, payloadBuf := sw.buffers()
n := copy(payloadBuf[sw.pending:], plaintext)
sw.pending += n
return n
}

// Encrypts all pending data and writes it to the output.
func (sw *shadowsocksWriter) flush() error {
func (sw *Writer) flush() error {
if sw.pending == 0 {
return nil
}
Expand Down
42 changes: 30 additions & 12 deletions shadowsocks/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
onet "github.com/Jigsaw-Code/outline-ss-server/net"
logging "github.com/op/go-logging"

"github.com/shadowsocks/go-shadowsocks2/shadowaead"
"github.com/shadowsocks/go-shadowsocks2/socks"
)

Expand Down Expand Up @@ -56,30 +57,27 @@ func debugTCP(cipherID, template string, val interface{}) {
}
}

func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, onet.DuplexConn, []byte, time.Duration, error) {
clientIP := remoteIP(clientConn)
func findAccessKey(clientReader io.Reader, clientIP net.IP, cipherList CipherList) (string, shadowaead.Cipher, io.Reader, []byte, time.Duration, error) {
// We snapshot the list because it may be modified while we use it.
tcpTrialSize, ciphers := cipherList.SnapshotForClientIP(clientIP)
firstBytes := make([]byte, tcpTrialSize)
if n, err := io.ReadFull(clientConn, firstBytes); err != nil {
return "", clientConn, nil, 0, fmt.Errorf("Reading header failed after %d bytes: %v", n, err)
if n, err := io.ReadFull(clientReader, firstBytes); err != nil {
return "", nil, clientReader, nil, 0, fmt.Errorf("Reading header failed after %d bytes: %v", n, err)
}

findStartTime := time.Now()
entry, elt := findEntry(firstBytes, ciphers)
timeToCipher := time.Now().Sub(findStartTime)
if entry == nil {
// TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
return "", clientConn, nil, timeToCipher, fmt.Errorf("Could not find valid TCP cipher")
return "", nil, clientReader, nil, timeToCipher, fmt.Errorf("Could not find valid TCP cipher")
}

// Move the active cipher to the front, so that the search is quicker next time.
cipherList.MarkUsedByClientIP(elt, clientIP)
id, cipher := entry.ID, entry.Cipher
ssr := NewShadowsocksReader(io.MultiReader(bytes.NewReader(firstBytes), clientConn), cipher)
ssw := NewShadowsocksWriter(clientConn, cipher)
salt := firstBytes[:cipher.SaltSize()]
return id, onet.WrapConn(clientConn, ssr, ssw).(onet.DuplexConn), salt, timeToCipher, nil
return id, cipher, io.MultiReader(bytes.NewReader(firstBytes), clientReader), salt, timeToCipher, nil
}

// Implements a trial decryption search. This assumes that all ciphers are AEAD.
Expand Down Expand Up @@ -234,22 +232,42 @@ func (s *tcpService) handleConnection(listenerPort int, clientConn onet.DuplexCo
clientConn.SetReadDeadline(connStart.Add(s.readTimeout))
var proxyMetrics metrics.ProxyMetrics
clientConn = metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
keyID, clientConn, salt, timeToCipher, keyErr := findAccessKey(clientConn, s.ciphers)
keyID, cipher, clientReader, salt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), s.ciphers)

connError := func() *onet.ConnectionError {
if keyErr != nil {
logger.Debugf("Failed to find a valid cipher after reading %v bytes: %v", proxyMetrics.ClientProxy, keyErr)
const status = "ERR_CIPHER"
s.absorbProbe(listenerPort, clientConn, clientLocation, status, &proxyMetrics)
return onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
} else if !s.replayCache.Add(keyID, salt) { // Only check the cache if findAccessKey succeeded.
}

saltGenerator, err := NewServerSaltGenerator(cipher)
if err != nil {
return onet.NewConnectionError("ERR_SALTGEN", "Failed to construct salt generator", err)
}

isMarked := saltGenerator.IsMarked(salt)
bemasc marked this conversation as resolved.
Show resolved Hide resolved
// Only check the cache if findAccessKey succeeded and the salt is unmarked.
if isMarked || !s.replayCache.Add(keyID, salt) {
const status = "ERR_REPLAY"
fortuna marked this conversation as resolved.
Show resolved Hide resolved
s.absorbProbe(listenerPort, clientConn, clientLocation, status, &proxyMetrics)
logger.Debugf("Replay: %v in %s sent %d bytes", clientConn.RemoteAddr(), clientLocation, proxyMetrics.ClientProxy)
return onet.NewConnectionError(status, "Replay detected", nil)
var msg string
if isMarked {
msg = "Server replay detected"
} else {
msg = "Client replay detected"
}
logger.Debugf(msg+": %v in %s sent %d bytes", clientConn.RemoteAddr(), clientLocation, proxyMetrics.ClientProxy)
bemasc marked this conversation as resolved.
Show resolved Hide resolved
return onet.NewConnectionError(status, msg, nil)
}
// Clear the authentication deadline
clientConn.SetReadDeadline(time.Time{})

ssr := NewShadowsocksReader(clientReader, cipher)
ssw := NewShadowsocksWriter(clientConn, cipher)
ssw.SetSaltGenerator(saltGenerator)
clientConn = onet.WrapConn(clientConn, ssr, ssw)
return proxyConnection(clientConn, &proxyMetrics, s.checkAllowedIP)
}()

Expand Down
8 changes: 5 additions & 3 deletions shadowsocks/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
if err != nil {
b.Fatalf("AcceptTCP failed: %v", err)
}
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).IP
b.StartTimer()
findAccessKey(clientConn, cipherList)
findAccessKey(clientConn, clientIP, cipherList)
b.StopTimer()
}
}
Expand Down Expand Up @@ -139,12 +140,13 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
for n := 0; n < b.N; n++ {
cipherNumber := byte(n % numCiphers)
reader, writer := io.Pipe()
addr := &net.TCPAddr{IP: net.IPv4(192, 0, 2, cipherNumber), Port: 54321}
clientIP := net.IPv4(192, 0, 2, cipherNumber)
addr := &net.TCPAddr{IP: clientIP, Port: 54321}
c := conn{clientAddr: addr, reader: reader, writer: writer}
cipher := cipherEntries[cipherNumber].Cipher
go NewShadowsocksWriter(writer, cipher).Write(MakeTestPayload(50))
b.StartTimer()
_, _, _, _, err := findAccessKey(&c, cipherList)
_, _, _, _, _, err := findAccessKey(&c, clientIP, cipherList)
b.StopTimer()
if err != nil {
b.Error(err)
Expand Down