Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent replays of server data #78

Merged
merged 15 commits into from
Aug 27, 2020
Next Next commit
Record server salts
  • Loading branch information
fortuna committed Aug 18, 2020
commit 3938bf1065f5389e8757b2c7c7a6d76175850627
2 changes: 1 addition & 1 deletion shadowsocks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (c *ssClient) DialTCP(laddr *net.TCPAddr, raddr string) (onet.DuplexConn, e
if err != nil {
return nil, err
}
ssw := NewShadowsocksWriter(proxyConn, c.cipher)
ssw := NewShadowsocksWriter(proxyConn, c.cipher, RandomSaltGenerator)
_, err = ssw.LazyWrite(socksTargetAddr)
if err != nil {
proxyConn.Close()
Expand Down
2 changes: 1 addition & 1 deletion shadowsocks/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func startShadowsocksTCPEchoProxy(expectedTgtAddr string, t testing.TB) (net.Lis
defer running.Done()
defer clientConn.Close()
ssr := NewShadowsocksReader(clientConn, cipher)
ssw := NewShadowsocksWriter(clientConn, cipher)
ssw := NewShadowsocksWriter(clientConn, cipher, RandomSaltGenerator)
ssClientConn := onet.WrapConn(clientConn, ssr, ssw)

tgtAddr, err := socks.ReadAddr(ssClientConn)
Expand Down
5 changes: 5 additions & 0 deletions shadowsocks/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ func TestTCPEcho(t *testing.T) {
t.Fatal("Echo mismatch")
}

// Check for client and server salts.
if len(replayCache.active) != 2 {
t.Fatalf("Replay cache has wrong number of salts: %d", len(replayCache.active))
}

conn.Close()
proxy.Stop()
echoListener.Close()
Expand Down
34 changes: 28 additions & 6 deletions shadowsocks/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ import (
// payloadSizeMask is the maximum size of payload in bytes.
const payloadSizeMask = 0x3FFF // 16*1024 - 1

// SaltGenerator generates unique salts to use in Shadowsocks connections.
type SaltGenerator interface {
// Writes a new salt into the input slice
GetSalt(salt []byte) error
}

// randomSaltGenerator generates a new random salt.
type randomSaltGenerator struct{}

// GetSalt outputs a random salt.
func (*randomSaltGenerator) GetSalt(salt []byte) error {
_, err := io.ReadFull(rand.Reader, salt)
return err
}

// RandomSaltGenerator is a SaltGenerator that generates a new random salt.
var RandomSaltGenerator SaltGenerator = &randomSaltGenerator{}

// Writer is an io.Writer that also implements io.ReaderFrom to
// allow for piping the data without extra allocations and copies.
// The LazyWrite and Flush methods allow a header to be
Expand All @@ -52,9 +70,10 @@ type shadowsocksWriter struct {
// else while needFlush could be true.
mu sync.Mutex
// Indicates that a concurrent flush is currently allowed.
needFlush bool
writer io.Writer
ssCipher shadowaead.Cipher
needFlush bool
writer io.Writer
ssCipher shadowaead.Cipher
saltGenerator SaltGenerator
// Wrapper for input that arrives as a slice.
byteWrapper bytes.Reader
// Number of plaintext bytes that are currently buffered.
Expand All @@ -68,16 +87,19 @@ type shadowsocksWriter struct {

// NewShadowsocksWriter creates a Writer that encrypts the given Writer using
// the shadowsocks protocol with the given shadowsocks cipher.
func NewShadowsocksWriter(writer io.Writer, ssCipher shadowaead.Cipher) Writer {
return &shadowsocksWriter{writer: writer, ssCipher: ssCipher}
func NewShadowsocksWriter(writer io.Writer, ssCipher shadowaead.Cipher, saltGenerator SaltGenerator) Writer {
if saltGenerator == nil {
saltGenerator = RandomSaltGenerator
}
return &shadowsocksWriter{writer: writer, ssCipher: ssCipher, saltGenerator: saltGenerator}
}

// init generates a random salt, sets up the AEAD object and writes
// the salt to the inner Writer.
func (sw *shadowsocksWriter) init() (err error) {
if sw.aead == nil {
salt := make([]byte, sw.ssCipher.SaltSize())
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
if err := sw.saltGenerator.GetSalt(salt); err != nil {
return fmt.Errorf("failed to generate salt: %v", err)
}
sw.aead, err = sw.ssCipher.Encrypter(salt)
Expand Down
10 changes: 5 additions & 5 deletions shadowsocks/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func TestEndToEnd(t *testing.T) {
cipher := newTestCipher(t)

connReader, connWriter := io.Pipe()
writer := NewShadowsocksWriter(connWriter, cipher)
writer := NewShadowsocksWriter(connWriter, cipher, RandomSaltGenerator)
reader := NewShadowsocksReader(connReader, cipher)
expected := "Test"
go func() {
Expand All @@ -180,7 +180,7 @@ func TestEndToEnd(t *testing.T) {
func TestLazyWriteFlush(t *testing.T) {
cipher := newTestCipher(t)
buf := new(bytes.Buffer)
writer := NewShadowsocksWriter(buf, cipher)
writer := NewShadowsocksWriter(buf, cipher, RandomSaltGenerator)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
if n != len(header) {
Expand Down Expand Up @@ -241,7 +241,7 @@ func TestLazyWriteFlush(t *testing.T) {
func TestLazyWriteConcat(t *testing.T) {
cipher := newTestCipher(t)
buf := new(bytes.Buffer)
writer := NewShadowsocksWriter(buf, cipher)
writer := NewShadowsocksWriter(buf, cipher, RandomSaltGenerator)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
if n != len(header) {
Expand Down Expand Up @@ -295,7 +295,7 @@ func TestLazyWriteConcat(t *testing.T) {
func TestLazyWriteOversize(t *testing.T) {
cipher := newTestCipher(t)
buf := new(bytes.Buffer)
writer := NewShadowsocksWriter(buf, cipher)
writer := NewShadowsocksWriter(buf, cipher, RandomSaltGenerator)
N := 25000 // More than one block, less than two.
data := make([]byte, N)
for i := range data {
Expand Down Expand Up @@ -335,7 +335,7 @@ func TestLazyWriteOversize(t *testing.T) {
func TestLazyWriteConcurrentFlush(t *testing.T) {
cipher := newTestCipher(t)
buf := new(bytes.Buffer)
writer := NewShadowsocksWriter(buf, cipher)
writer := NewShadowsocksWriter(buf, cipher, RandomSaltGenerator)
header := []byte{1, 2, 3, 4}
n, err := writer.LazyWrite(header)
if n != len(header) {
Expand Down
35 changes: 26 additions & 9 deletions shadowsocks/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
onet "github.com/Jigsaw-Code/outline-ss-server/net"
logging "github.com/op/go-logging"

"github.com/shadowsocks/go-shadowsocks2/shadowaead"
"github.com/shadowsocks/go-shadowsocks2/socks"
)

Expand Down Expand Up @@ -56,30 +57,42 @@ func debugTCP(cipherID, template string, val interface{}) {
}
}

func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, onet.DuplexConn, []byte, time.Duration, error) {
clientIP := remoteIP(clientConn)
type recordingSaltGenerator struct {
saltGenerator SaltGenerator
replayCache *ReplayCache
keyID string
}

func (sg *recordingSaltGenerator) GetSalt(salt []byte) error {
err := sg.saltGenerator.GetSalt(salt)
if err != nil {
return err
}
_ = sg.replayCache.Add(sg.keyID, salt)
return nil
}

func findAccessKey(clientReader io.Reader, clientIP net.IP, cipherList CipherList) (string, shadowaead.Cipher, io.Reader, []byte, time.Duration, error) {
// We snapshot the list because it may be modified while we use it.
tcpTrialSize, ciphers := cipherList.SnapshotForClientIP(clientIP)
firstBytes := make([]byte, tcpTrialSize)
if n, err := io.ReadFull(clientConn, firstBytes); err != nil {
return "", clientConn, nil, 0, fmt.Errorf("Reading header failed after %d bytes: %v", n, err)
if n, err := io.ReadFull(clientReader, firstBytes); err != nil {
return "", nil, clientReader, nil, 0, fmt.Errorf("Reading header failed after %d bytes: %v", n, err)
}

findStartTime := time.Now()
entry, elt := findEntry(firstBytes, ciphers)
timeToCipher := time.Now().Sub(findStartTime)
if entry == nil {
// TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
return "", clientConn, nil, timeToCipher, fmt.Errorf("Could not find valid TCP cipher")
return "", nil, clientReader, nil, timeToCipher, fmt.Errorf("Could not find valid TCP cipher")
}

// Move the active cipher to the front, so that the search is quicker next time.
cipherList.MarkUsedByClientIP(elt, clientIP)
id, cipher := entry.ID, entry.Cipher
ssr := NewShadowsocksReader(io.MultiReader(bytes.NewReader(firstBytes), clientConn), cipher)
ssw := NewShadowsocksWriter(clientConn, cipher)
salt := firstBytes[:cipher.SaltSize()]
return id, onet.WrapConn(clientConn, ssr, ssw).(onet.DuplexConn), salt, timeToCipher, nil
return id, cipher, io.MultiReader(bytes.NewReader(firstBytes), clientReader), salt, timeToCipher, nil
}

// Implements a trial decryption search. This assumes that all ciphers are AEAD.
Expand Down Expand Up @@ -234,7 +247,7 @@ func (s *tcpService) handleConnection(listenerPort int, clientConn onet.DuplexCo
clientConn.SetReadDeadline(connStart.Add(s.readTimeout))
var proxyMetrics metrics.ProxyMetrics
clientConn = metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
keyID, clientConn, salt, timeToCipher, keyErr := findAccessKey(clientConn, s.ciphers)
keyID, cipher, clientReader, salt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), s.ciphers)

connError := func() *onet.ConnectionError {
if keyErr != nil {
Expand All @@ -250,6 +263,10 @@ func (s *tcpService) handleConnection(listenerPort int, clientConn onet.DuplexCo
}
// Clear the authentication deadline
clientConn.SetReadDeadline(time.Time{})

ssr := NewShadowsocksReader(clientReader, cipher)
ssw := NewShadowsocksWriter(clientConn, cipher, &recordingSaltGenerator{saltGenerator: RandomSaltGenerator, replayCache: s.replayCache, keyID: keyID})
clientConn = onet.WrapConn(clientConn, ssr, ssw)
return proxyConnection(clientConn, &proxyMetrics, s.checkAllowedIP)
}()

Expand Down
12 changes: 7 additions & 5 deletions shadowsocks/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
if err != nil {
b.Fatalf("AcceptTCP failed: %v", err)
}
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).IP
b.StartTimer()
findAccessKey(clientConn, cipherList)
findAccessKey(clientConn, clientIP, cipherList)
b.StopTimer()
}
}
Expand Down Expand Up @@ -139,12 +140,13 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
for n := 0; n < b.N; n++ {
cipherNumber := byte(n % numCiphers)
reader, writer := io.Pipe()
addr := &net.TCPAddr{IP: net.IPv4(192, 0, 2, cipherNumber), Port: 54321}
clientIP := net.IPv4(192, 0, 2, cipherNumber)
addr := &net.TCPAddr{IP: clientIP, Port: 54321}
c := conn{clientAddr: addr, reader: reader, writer: writer}
cipher := cipherEntries[cipherNumber].Cipher
go NewShadowsocksWriter(writer, cipher).Write(MakeTestPayload(50))
go NewShadowsocksWriter(writer, cipher, RandomSaltGenerator).Write(MakeTestPayload(50))
b.StartTimer()
_, _, _, _, err := findAccessKey(&c, cipherList)
_, _, _, _, _, err := findAccessKey(&c, clientIP, cipherList)
b.StopTimer()
if err != nil {
b.Error(err)
Expand Down Expand Up @@ -205,7 +207,7 @@ func TestReplayDefense(t *testing.T) {
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.Cipher
reader, writer := io.Pipe()
go NewShadowsocksWriter(writer, cipher).Write([]byte{0})
go NewShadowsocksWriter(writer, cipher, RandomSaltGenerator).Write([]byte{0})
preamble := make([]byte, 32+2+16)
if _, err := io.ReadFull(reader, preamble); err != nil {
t.Fatal(err)
Expand Down