diff --git a/cipher.go b/cipher.go new file mode 100644 index 00000000..079841d9 --- /dev/null +++ b/cipher.go @@ -0,0 +1,132 @@ +package main + +import ( + "crypto/aes" + "crypto/cipher" + "errors" + "fmt" + "net" + "strings" + + "golang.org/x/crypto/chacha20poly1305" + + "github.com/Yawning/chacha20" + "github.com/riobard/go-shadowsocks2/core" + "github.com/riobard/go-shadowsocks2/shadowaead" + "github.com/riobard/go-shadowsocks2/shadowstream" +) + +// ErrKeySize means the supplied key size does not meet the requirement of cipher choosed. +var ErrKeySize = errors.New("key size error") + +func pickCipher(name string, key []byte) (core.StreamConnCipher, core.PacketConnCipher, error) { + + switch strings.ToLower(name) { + case "aes-128-gcm", "aes-192-gcm", "aes-256-gcm": + aead, err := aesGCM(key, 0) // 0 for standard 12-byte nonce + return aeadStream(aead), aeadPacket(aead), err + + case "aes-128-gcm-16", "aes-192-gcm-16", "aes-256-gcm-16": + aead, err := aesGCM(key, 16) // 16-byte nonce for better collision avoidance + return aeadStream(aead), aeadPacket(aead), err + + case "chacha20-ietf-poly1305": + aead, err := chacha20poly1305.New(key) + return aeadStream(aead), aeadPacket(aead), err + + case "aes-128-ctr", "aes-192-ctr", "aes-256-ctr": + ciph, err := aesCTR(key) + return streamStream(ciph), streamPacket(ciph), err + + case "aes-128-cfb", "aes-192-cfb", "aes-256-cfb": + ciph, err := aesCFB(key) + return streamStream(ciph), streamPacket(ciph), err + + case "chacha20-ietf": + if len(key) != chacha20.KeySize { + return nil, nil, ErrKeySize + } + k := chacha20ietfkey(key) + return streamStream(k), streamPacket(k), nil + + case "dummy": // only for benchmarking and debugging + return dummyStream(), dummyPacket(), nil + + default: + err := fmt.Errorf("cipher not supported: %s", name) + return nil, nil, err + } +} + +func dummyStream() core.StreamConnCipher { + return func(c net.Conn) net.Conn { return c } +} +func dummyPacket() core.PacketConnCipher { + return func(c net.PacketConn) net.PacketConn { return c } +} + +func aeadStream(aead cipher.AEAD) core.StreamConnCipher { + return func(c net.Conn) net.Conn { return shadowaead.NewConn(c, aead) } +} +func aeadPacket(aead cipher.AEAD) core.PacketConnCipher { + return func(c net.PacketConn) net.PacketConn { return shadowaead.NewPacketConn(c, aead) } +} + +func aesGCM(key []byte, nonceSize int) (cipher.AEAD, error) { + blk, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if nonceSize > 0 { + return cipher.NewGCMWithNonceSize(blk, nonceSize) + } + return cipher.NewGCM(blk) // standard 12-byte nonce +} + +func streamStream(ciph shadowstream.Cipher) core.StreamConnCipher { + return func(c net.Conn) net.Conn { return shadowstream.NewConn(c, ciph) } +} + +func streamPacket(ciph shadowstream.Cipher) core.PacketConnCipher { + return func(c net.PacketConn) net.PacketConn { return shadowstream.NewPacketConn(c, ciph) } +} + +type ctrStream struct{ cipher.Block } + +func (b *ctrStream) IVSize() int { return b.BlockSize() } +func (b *ctrStream) Encrypter(iv []byte) cipher.Stream { return cipher.NewCTR(b, iv) } +func (b *ctrStream) Decrypter(iv []byte) cipher.Stream { return b.Encrypter(iv) } + +func aesCTR(key []byte) (shadowstream.Cipher, error) { + blk, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return &ctrStream{blk}, nil +} + +type cfbStream struct{ cipher.Block } + +func (b *cfbStream) IVSize() int { return b.BlockSize() } +func (b *cfbStream) Encrypter(iv []byte) cipher.Stream { return cipher.NewCFBEncrypter(b, iv) } +func (b *cfbStream) Decrypter(iv []byte) cipher.Stream { return cipher.NewCFBDecrypter(b, iv) } + +func aesCFB(key []byte) (shadowstream.Cipher, error) { + blk, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return &ctrStream{blk}, nil +} + +type chacha20ietfkey []byte + +func (k chacha20ietfkey) IVSize() int { return chacha20.INonceSize } +func (k chacha20ietfkey) Encrypter(iv []byte) cipher.Stream { + ciph, err := chacha20.NewCipher(k, iv) + if err != nil { + panic(err) + } + return ciph +} +func (k chacha20ietfkey) Decrypter(iv []byte) cipher.Stream { return k.Encrypter(iv) } diff --git a/core/doc.go b/core/doc.go new file mode 100644 index 00000000..553729a3 --- /dev/null +++ b/core/doc.go @@ -0,0 +1,2 @@ +// Package core provides essential interfaces for Shadowsocks +package core diff --git a/core/packet.go b/core/packet.go new file mode 100644 index 00000000..67cf2f8b --- /dev/null +++ b/core/packet.go @@ -0,0 +1,10 @@ +package core + +import "net" + +type PacketConnCipher func(net.PacketConn) net.PacketConn + +func ListenPacket(network, address string, ciph PacketConnCipher) (net.PacketConn, error) { + c, err := net.ListenPacket(network, address) + return ciph(c), err +} diff --git a/core/stream.go b/core/stream.go new file mode 100644 index 00000000..2868db05 --- /dev/null +++ b/core/stream.go @@ -0,0 +1,25 @@ +package core + +import "net" + +type StreamConnCipher func(net.Conn) net.Conn + +type listener struct { + net.Listener + StreamConnCipher +} + +func Listen(network, address string, ciph StreamConnCipher) (net.Listener, error) { + l, err := net.Listen(network, address) + return &listener{l, ciph}, err +} + +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + return l.StreamConnCipher(c), err +} + +func Dial(network, address string, ciph StreamConnCipher) (net.Conn, error) { + c, err := net.Dial(network, address) + return ciph(c), err +} diff --git a/main.go b/main.go new file mode 100644 index 00000000..90cc7c39 --- /dev/null +++ b/main.go @@ -0,0 +1,99 @@ +package main + +import ( + "encoding/hex" + "flag" + "log" + "os" + "os/signal" + "strings" + "syscall" + "time" +) + +var config struct { + Verbose bool + UDPTimeout time.Duration +} + +func logf(f string, v ...interface{}) { + if config.Verbose { + log.Printf(f, v...) + } +} + +func main() { + + var flags struct { + Client string + Server string + Cipher string + Key string + Socks string + RedirTCP string + RedirTCP6 string + TCPTun string + UDPTun string + } + + flag.BoolVar(&config.Verbose, "verbose", false, "verbose mode") + flag.StringVar(&flags.Cipher, "cipher", "aes-128-gcm", "cipher to encrypt/decrypt") + flag.StringVar(&flags.Key, "key", "", "secret key in hexadecimal") + flag.StringVar(&flags.Server, "s", "", "server listen address") + flag.StringVar(&flags.Client, "c", "", "client connect address") + flag.StringVar(&flags.Socks, "socks", ":1080", "(client-only) SOCKS listen address") + flag.StringVar(&flags.RedirTCP, "redir", "", "(client-only) redirect TCP from this address") + flag.StringVar(&flags.RedirTCP6, "redir6", "", "(client-only) redirect TCP IPv6 from this address") + flag.StringVar(&flags.TCPTun, "tcptun", "", "(client-only) TCP tunnel (laddr1=raddr1,laddr2=raddr2,...)") + flag.StringVar(&flags.UDPTun, "udptun", "", "(client-only) UDP tunnel (laddr1=raddr1,laddr2=raddr2,...)") + flag.DurationVar(&config.UDPTimeout, "udptimeout", 5*time.Minute, "UDP tunnel timeout") + flag.Parse() + + key, err := hex.DecodeString(flags.Key) + if err != nil { + log.Fatalf("failed to parse key: %v", err) + } + + streamCipher, packetCipher, err := pickCipher(flags.Cipher, key) + if err != nil { + log.Fatalf("failed to create cipher %s: %v", flags.Cipher, err) + } + + if flags.Client != "" { // client mode + if flags.UDPTun != "" { + for _, tun := range strings.Split(flags.UDPTun, ",") { + p := strings.Split(tun, "=") + go udpLocal(p[0], flags.Client, p[1], packetCipher) + } + } + + if flags.TCPTun != "" { + for _, tun := range strings.Split(flags.TCPTun, ",") { + p := strings.Split(tun, "=") + go tcpTun(p[0], flags.Client, p[1], streamCipher) + } + } + + if flags.Socks != "" { + go socksLocal(flags.Socks, flags.Client, streamCipher) + } + + if flags.RedirTCP != "" { + go redirLocal(flags.RedirTCP, flags.Client, streamCipher) + } + + if flags.RedirTCP6 != "" { + go redir6Local(flags.RedirTCP6, flags.Client, streamCipher) + } + } else if flags.Server != "" { // server mode + go udpRemote(flags.Server, packetCipher) + go tcpRemote(flags.Server, streamCipher) + } else { + flag.Usage() + return + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + <-sigCh +} diff --git a/shadowaead/doc.go b/shadowaead/doc.go new file mode 100644 index 00000000..9cb1427f --- /dev/null +++ b/shadowaead/doc.go @@ -0,0 +1,35 @@ +/* +Package shadowaead implements a simple AEAD-protected secure protocol. + +In general, there are two types of connections: stream-oriented and packet-oriented. +Stream-oriented connections (e.g. TCP) assume reliable and orderly delivery of bytes. +Packet-oriented connections (e.g. UDP) assume unreliable and out-of-order delivery of packets, +where each packet is either delivered intact or lost. + +An encrypted stream starts with a nonce, followed by any number of encrypted records. +Each encrypted record has the following structure: + + [encrypted payload length] + [payload length tag] + [encrypted payload] + [payload tag] + +Payload length is 2-byte unsigned big-endian integer capped at 0x3FFF (16383). +The higher 2 bits are reserved and must be set to zero. The first AEAD encrypt/decrypt +operation uses the nonce at the beginning of the stream. After each encrypt/decrypt operation, +the nonce is incremented by one as if it were an unsigned little-endian integer. + + +Each encrypted packet transmitted on a packet-oriented connection has the following structure: + + [nonce] + [encrypted payload] + [payload tag] + +Packets are encrypted/decrypted independently. + +In both stream-oriented and packet-oriented connections, length of nonce and tag varies +depending on which AEAD is used. Nonces are assumed to be randomly generated and +of sufficient length (at least 12 bytes). +*/ +package shadowaead diff --git a/shadowaead/packet.go b/shadowaead/packet.go new file mode 100644 index 00000000..09e851b4 --- /dev/null +++ b/shadowaead/packet.go @@ -0,0 +1,79 @@ +package shadowaead + +import ( + "crypto/cipher" + "crypto/rand" + "errors" + "io" + "net" +) + +// ErrShortPacket means that the packet is too short for a valid encrypted packet. +var ErrShortPacket = errors.New("shadow: short packet") + +// Pack encrypts plaintext using aead with a randomly generated nonce and +// returns a slice of dst containing the encrypted packet and any error occurred. +// Ensure len(dst) >= aead.NonceSize() + len(plaintext) + aead.Overhead(). +func Pack(dst, plaintext []byte, aead cipher.AEAD) ([]byte, error) { + nsiz := aead.NonceSize() + if len(dst) < nsiz+len(plaintext)+aead.Overhead() { + return nil, io.ErrShortBuffer + } + + nonce := dst[:nsiz] + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + b := aead.Seal(dst[nsiz:nsiz], nonce, plaintext, nil) + return dst[:nsiz+len(b)], nil +} + +// Unpack decrypts pkt using aead and returns a slice of dst containing the decrypted payload and any error occurred. +// Ensure len(dst) >= len(pkt) - aead.NonceSize() - aead.Overhead(). +func Unpack(dst, pkt []byte, aead cipher.AEAD) ([]byte, error) { + nsiz := aead.NonceSize() + + if len(pkt) < nsiz+aead.Overhead() { + return nil, ErrShortPacket + } + + if len(dst) < len(pkt)-nsiz-aead.Overhead() { + return nil, io.ErrShortBuffer + } + + b, err := aead.Open(dst[:0], pkt[:nsiz], pkt[nsiz:], nil) + return b, err +} + +// packetConn encrypts net.packetConn with cipher.AEAD +type packetConn struct { + net.PacketConn + cipher.AEAD +} + +// NewPacketConn wraps a net.PacketConn with AEAD protection. +func NewPacketConn(c net.PacketConn, aead cipher.AEAD) net.PacketConn { + return &packetConn{PacketConn: c, AEAD: aead} +} + +// WriteTo encrypts b and write to addr using the embedded PacketConn. +func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { + buf := make([]byte, c.AEAD.NonceSize()+len(b)+c.AEAD.Overhead()) + buf, err := Pack(buf, b, c.AEAD) + if err != nil { + return 0, err + } + _, err = c.PacketConn.WriteTo(buf, addr) + return len(b), err +} + +// ReadFrom reads from the embedded PacketConn and decrypts into b. +func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, addr, err := c.PacketConn.ReadFrom(b) + if err != nil { + return n, addr, err + } + b, err = Unpack(b, b[:n], c.AEAD) + return len(b), addr, err +} diff --git a/shadowaead/stream.go b/shadowaead/stream.go new file mode 100644 index 00000000..603ebe51 --- /dev/null +++ b/shadowaead/stream.go @@ -0,0 +1,249 @@ +package shadowaead + +import ( + "bytes" + "crypto/cipher" + "crypto/rand" + "io" + "net" +) + +// payloadSizeMask is the maximum size of payload in bytes. +const payloadSizeMask = 0x3FFF // 16*1024 - 1 + +type writer struct { + io.Writer + cipher.AEAD + nonce []byte + buf []byte +} + +// NewWriter wraps an io.Writer with AEAD encryption. +func NewWriter(w io.Writer, aead cipher.AEAD) io.Writer { + return &writer{Writer: w, AEAD: aead} +} + +func (w *writer) init() error { + w.buf = make([]byte, 2+w.Overhead()+payloadSizeMask+w.Overhead()) + w.nonce = make([]byte, w.NonceSize()) + _, err := io.ReadFull(rand.Reader, w.nonce) + if err != nil { + return err + } + _, err = w.Writer.Write(w.nonce) + return err +} + +// Write encrypts b and writes to the embedded io.Writer. +func (w *writer) Write(b []byte) (int, error) { + n, err := w.ReadFrom(bytes.NewBuffer(b)) + return int(n), err +} + +// ReadFrom reads from the given io.Reader until EOF or error, encrypts and +// writes to the embedded io.Writer. Returns number of bytes read from r and +// any error encountered. +func (w *writer) ReadFrom(r io.Reader) (n int64, err error) { + if w.nonce == nil { + if err := w.init(); err != nil { + return 0, err + } + } + + for { + buf := w.buf + payloadBuf := buf[2+w.Overhead() : 2+w.Overhead()+payloadSizeMask] + nr, er := r.Read(payloadBuf) + + if nr > 0 { + n += int64(nr) + buf = buf[:2+w.Overhead()+nr+w.Overhead()] + payloadBuf = payloadBuf[:nr] + buf[0], buf[1] = byte(nr>>8), byte(nr) // big-endian payload size + w.Seal(buf[:0], w.nonce, buf[:2], nil) + increment(w.nonce) + + w.Seal(payloadBuf[:0], w.nonce, payloadBuf, nil) + increment(w.nonce) + + _, ew := w.Writer.Write(buf) + if ew != nil { + err = ew + break + } + } + + if er != nil { + if er != io.EOF { // ignore EOF as per io.ReaderFrom contract + err = er + } + break + } + } + + return n, err +} + +type reader struct { + io.Reader + cipher.AEAD + nonce []byte + buf []byte + leftover []byte +} + +// NewReader wraps an io.Reader with AEAD decryption. +func NewReader(r io.Reader, aead cipher.AEAD) io.Reader { + return &reader{Reader: r, AEAD: aead} +} + +func (r *reader) init() error { + r.buf = make([]byte, payloadSizeMask+r.Overhead()) + r.nonce = make([]byte, r.NonceSize()) + _, err := io.ReadFull(r.Reader, r.nonce) + return err +} + +// read and decrypt a record into the internal buffer. Return decrypted payload length and any error encountered. +func (r *reader) read() (int, error) { + if r.nonce == nil { + if err := r.init(); err != nil { + return 0, err + } + } + + // decrypt payload size + buf := r.buf[:2+r.Overhead()] + _, err := io.ReadFull(r.Reader, buf) + if err != nil { + return 0, err + } + + _, err = r.Open(buf[:0], r.nonce, buf, nil) + increment(r.nonce) + if err != nil { + return 0, err + } + + size := (int(buf[0])<<8 + int(buf[1])) & payloadSizeMask + + // decrypt payload + buf = r.buf[:size+r.Overhead()] + _, err = io.ReadFull(r.Reader, buf) + if err != nil { + return 0, err + } + + _, err = r.Open(buf[:0], r.nonce, buf, nil) + increment(r.nonce) + if err != nil { + return 0, err + } + + return size, nil +} + +// Read reads from the embedded io.Reader, decrypts and writes to b. +func (r *reader) Read(b []byte) (int, error) { + // copy decrypted bytes (if any) from previous record first + if len(r.leftover) > 0 { + n := copy(b, r.leftover) + r.leftover = r.leftover[n:] + return n, nil + } + + n, err := r.read() + m := copy(b, r.buf[:n]) + if m < n { // insufficient len(b), keep leftover for next read + r.leftover = r.buf[m:n] + } + return m, err +} + +// WriteTo reads from the embedded io.Reader, decrypts and writes to w until +// there's no more data to write or when an error occurs. Return number of +// bytes written to w and any error encountered. +func (r *reader) WriteTo(w io.Writer) (n int64, err error) { + for { + nr, er := r.read() + if nr > 0 { + nw, ew := w.Write(r.buf[:nr]) + n += int64(nw) + + if ew != nil { + err = ew + break + } + } + + if er != nil { + if er != io.EOF { // ignore EOF as per io.Copy contract (using src.WriteTo shortcut) + err = er + } + break + } + } + + return n, err +} + +type streamConn struct { + net.Conn + r *reader + w *writer +} + +type closeWriter interface { + CloseWrite() error +} + +type closeReader interface { + CloseRead() error +} + +func (c *streamConn) Read(b []byte) (int, error) { + return c.r.Read(b) +} + +func (c *streamConn) WriteTo(w io.Writer) (int64, error) { + return c.r.WriteTo(w) +} + +func (c *streamConn) Write(b []byte) (int, error) { + return c.w.Write(b) +} + +func (c *streamConn) ReadFrom(r io.Reader) (int64, error) { + return c.w.ReadFrom(r) +} + +func (c *streamConn) CloseRead() error { + if c, ok := c.Conn.(closeReader); ok { + return c.CloseRead() + } + return nil +} + +func (c *streamConn) CloseWrite() error { + if c, ok := c.Conn.(closeWriter); ok { + return c.CloseWrite() + } + return nil +} + +// NewConn wraps a stream-oriented net.Conn with AEAD protection. +func NewConn(c net.Conn, aead cipher.AEAD) net.Conn { + r := &reader{Reader: c, AEAD: aead} + w := &writer{Writer: c, AEAD: aead} + return &streamConn{Conn: c, r: r, w: w} +} + +// increment little-endian encoded unsigned integer b. Wrap around on overflow. +func increment(b []byte) { + for i := range b { + b[i]++ + if b[i] != 0 { + return + } + } +} diff --git a/shadowstream/doc.go b/shadowstream/doc.go new file mode 100644 index 00000000..f25f5b09 --- /dev/null +++ b/shadowstream/doc.go @@ -0,0 +1,13 @@ +/* +Package shadowstream implements the original Shadowsocks protocol protected by stream cipher. +*/ +package shadowstream + +import "crypto/cipher" + +// Cipher generates a pair of stream ciphers for encryption and decryption. +type Cipher interface { + IVSize() int + Encrypter(iv []byte) cipher.Stream + Decrypter(iv []byte) cipher.Stream +} diff --git a/shadowstream/packet.go b/shadowstream/packet.go new file mode 100644 index 00000000..31d2bc80 --- /dev/null +++ b/shadowstream/packet.go @@ -0,0 +1,72 @@ +package shadowstream + +import ( + "crypto/rand" + "errors" + "io" + "net" +) + +// ErrShortPacket means the packet is too short to be a valid encrypted packet. +var ErrShortPacket = errors.New("short packet") + +// Pack encrypts plaintext using stream cipher s and a random IV. +// Returns a slice of dst containing random IV and ciphertext. +// Ensure len(dst) >= s.IVSize() + len(plaintext). +func Pack(dst, plaintext []byte, s Cipher) ([]byte, error) { + if len(dst) < s.IVSize()+len(plaintext) { + return nil, io.ErrShortBuffer + } + iv := dst[:s.IVSize()] + _, err := io.ReadFull(rand.Reader, iv) + if err != nil { + return nil, err + } + + s.Encrypter(iv).XORKeyStream(dst[len(iv):], plaintext) + return dst[:len(iv)+len(plaintext)], nil +} + +// Unpack decrypts pkt using stream cipher s. +// Returns a slice of dst containing decrypted plaintext. +func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) { + if len(pkt) < s.IVSize() { + return nil, ErrShortPacket + } + + if len(dst) < len(pkt)-s.IVSize() { + return nil, io.ErrShortBuffer + } + iv := pkt[:s.IVSize()] + s.Decrypter(iv).XORKeyStream(dst, pkt[len(iv):]) + return dst[:len(pkt)-len(iv)], nil +} + +type packetConn struct { + net.PacketConn + Cipher +} + +// NewPacketConn wraps a net.PacketConn with stream cipher encryption/decryption. +func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn { + return &packetConn{PacketConn: c, Cipher: ciph} +} + +func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { + buf := make([]byte, c.IVSize()+len(b)) + _, err := Pack(buf, b, c.Cipher) + if err != nil { + return 0, err + } + _, err = c.PacketConn.WriteTo(buf, addr) + return len(b), err +} + +func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, addr, err := c.PacketConn.ReadFrom(b) + if err != nil { + return n, addr, err + } + b, err = Unpack(b, b[:n], c.Cipher) + return len(b), addr, err +} diff --git a/shadowstream/stream.go b/shadowstream/stream.go new file mode 100644 index 00000000..f7584ec6 --- /dev/null +++ b/shadowstream/stream.go @@ -0,0 +1,174 @@ +package shadowstream + +import ( + "bytes" + "crypto/cipher" + "crypto/rand" + "io" + "net" +) + +const bufSize = 32 * 1024 + +type writer struct { + io.Writer + Cipher + cipher.Stream + buf []byte +} + +// NewWriter wraps an io.Writer with stream cipher encryption. +func NewWriter(w io.Writer, s Cipher) io.Writer { + return &writer{ + Writer: w, + Cipher: s, + } +} + +func (w *writer) ReadFrom(r io.Reader) (n int64, err error) { + if w.Stream == nil { + w.buf = make([]byte, bufSize) + iv := w.buf[:w.IVSize()] + if _, err = io.ReadFull(rand.Reader, iv); err != nil { + return + } + if _, err = w.Writer.Write(iv); err != nil { + return + } + + w.Stream = w.Encrypter(iv) + } + + for { + buf := w.buf + nr, er := r.Read(buf) + if nr > 0 { + n += int64(nr) + buf = buf[:nr] + w.XORKeyStream(buf, buf) + _, ew := w.Writer.Write(buf) + if ew != nil { + err = ew + return + } + } + + if er != nil { + if er != io.EOF { // ignore EOF as per io.ReaderFrom contract + err = er + } + return + } + } +} + +func (w *writer) Write(b []byte) (int, error) { + n, err := w.ReadFrom(bytes.NewBuffer(b)) + return int(n), err +} + +type reader struct { + io.Reader + Cipher + cipher.Stream + buf []byte +} + +// NewReader wraps an io.Reader with stream cipher decryption. +func NewReader(r io.Reader, s Cipher) io.Reader { + return &reader{Reader: r, Cipher: s} +} + +func (r *reader) Read(b []byte) (int, error) { + if r.Stream == nil { + r.buf = make([]byte, bufSize) + iv := make([]byte, r.IVSize()) + if _, err := io.ReadFull(r.Reader, iv); err != nil { + return 0, err + } + + r.Stream = r.Decrypter(iv) + } + + n, err := r.Reader.Read(b) + if err != nil { + return 0, err + } + b = b[:n] + r.XORKeyStream(b, b) + return n, nil +} + +func (r *reader) WriteTo(w io.Writer) (n int64, err error) { + for { + buf := r.buf + nr, er := r.Read(buf) + if nr > 0 { + nw, ew := w.Write(buf[:nr]) + n += int64(nw) + + if ew != nil { + err = ew + return + } + } + + if er != nil { + if er != io.EOF { // ignore EOF as per io.Copy contract (using src.WriteTo shortcut) + err = er + } + return + } + } +} + +type conn struct { + net.Conn + r *reader + w *writer +} + +// NewConn wraps a stream-oriented net.Conn with stream cipher encryption/decryption. +func NewConn(c net.Conn, ciph Cipher) net.Conn { + r := &reader{Reader: c, Cipher: ciph} + w := &writer{Writer: c, Cipher: ciph} + return &conn{Conn: c, r: r, w: w} +} + +func (c *conn) Read(b []byte) (int, error) { + return c.r.Read(b) +} + +func (c *conn) WriteTo(w io.Writer) (int64, error) { + return c.r.WriteTo(w) +} + +func (c *conn) Write(b []byte) (int, error) { + return c.w.Write(b) +} + +func (c *conn) ReadFrom(r io.Reader) (int64, error) { + return c.w.ReadFrom(r) +} + +type closeWriter interface { + CloseWrite() error +} + +type closeReader interface { + CloseRead() error +} + +func (c *conn) CloseRead() error { + if c, ok := c.Conn.(closeReader); ok { + return c.CloseRead() + } + return nil +} + +func (c *conn) CloseWrite() error { + if c, ok := c.Conn.(closeWriter); ok { + return c.CloseWrite() + } + return nil +} diff --git a/socks/socks.go b/socks/socks.go new file mode 100644 index 00000000..2ee6d4c9 --- /dev/null +++ b/socks/socks.go @@ -0,0 +1,191 @@ +// Package socks implements essential parts of SOCKS protocol. +package socks + +import ( + "io" + "net" + "strconv" +) + +// SOCKS request commands as defined in RFC 1928 section 4. +const ( + CmdConnect = 1 + CmdBind = 2 + CmdUDPAssociate = 3 +) + +// SOCKS address types as defined in RFC 1928 section 5. +const ( + AtypIPv4 = 1 + AtypDomainName = 3 + AtypIPv6 = 4 +) + +// Error represents a SOCKS error +type Error byte + +func (err Error) Error() string { + return "SOCKS error: " + strconv.Itoa(int(err)) +} + +// SOCKS errors as defined in RFC 1928 section 6. +const ( + ErrGeneralFailure = Error(1) + ErrConnectionNotAllowed = Error(2) + ErrNetworkUnreachable = Error(3) + ErrHostUnreachable = Error(4) + ErrConnectionRefused = Error(5) + ErrTTLExpired = Error(6) + ErrCommandNotSupported = Error(7) + ErrAddressNotSupported = Error(8) +) + +// MaxAddrLen is the maximum size of SOCKS address in bytes. +const MaxAddrLen = 1 + 1 + 255 + 2 + +// MaxReqLen is the maximum size of SOCKS request in bytes. +const MaxReqLen = 1 + 1 + 1 + MaxAddrLen + +// Addr represents a SOCKS address as defined in RFC 1928 section 5. +type Addr []byte + +// String serializes SOCKS address a to string form. +func (a Addr) String() string { + var host, port string + + switch a[0] { // address type + case AtypDomainName: + host = string(a[2 : 2+a[1]]) + port = strconv.Itoa((int(a[2+a[1]]) << 8) | int(a[2+a[1]+1])) + case AtypIPv4: + host = net.IP(a[1 : 1+net.IPv4len]).String() + port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) + case AtypIPv6: + host = net.IP(a[1 : 1+net.IPv6len]).String() + port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) + } + + return net.JoinHostPort(host, port) +} + +// ReadAddr reads just enough bytes from r to get a valid Addr. +func ReadAddr(r io.Reader) (Addr, error) { + b := make([]byte, MaxAddrLen) + _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type + if err != nil { + return nil, err + } + + switch b[0] { + case AtypDomainName: + _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length + if err != nil { + return nil, err + } + _, err = io.ReadFull(r, b[2:2+b[1]+2]) + return b[:1+1+b[1]+2], err + case AtypIPv4: + _, err = io.ReadFull(r, b[1:1+net.IPv4len+2]) + return b[:1+net.IPv4len+2], err + case AtypIPv6: + _, err = io.ReadFull(r, b[1:1+net.IPv6len+2]) + return b[:1+net.IPv6len+2], err + } + + return nil, ErrAddressNotSupported +} + +// SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed. +func SplitAddr(b []byte) Addr { + addrLen := 1 + if len(b) < addrLen { + return nil + } + + switch b[0] { + case AtypDomainName: + if len(b) < 2 { + return nil + } + addrLen = 1 + 1 + int(b[1]) + 2 + case AtypIPv4: + addrLen = 1 + net.IPv4len + 2 + case AtypIPv6: + addrLen = 1 + net.IPv6len + 2 + default: + return nil + + } + + if len(b) < addrLen { + return nil + } + + return b[:addrLen] +} + +// ParseAddr parses the address in string s. Returns nil if failed. +func ParseAddr(s string) Addr { + var addr Addr + host, port, err := net.SplitHostPort(s) + if err != nil { + return nil + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + addr = make([]byte, 1+net.IPv4len+2) + addr[0] = AtypIPv4 + copy(addr[1:], ip4) + } else { + addr = make([]byte, 1+net.IPv6len+2) + addr[0] = AtypIPv6 + copy(addr[1:], ip) + } + } else { + if len(host) > 255 { + return nil + } + addr = make([]byte, 1+1+len(host)+2) + addr[0] = AtypDomainName + addr[1] = byte(len(host)) + copy(addr[2:], host) + } + + portnum, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil + } + + addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum) + + return addr +} + +// Handshake fast-tracks SOCKS initialization to get target address to connect. +func Handshake(rw io.ReadWriter) (Addr, error) { + // Read RFC 1928 section 4 for request and reply structure and sizes + buf := make([]byte, MaxReqLen) + + _, err := rw.Read(buf) // SOCKS version and auth methods + if err != nil { + return nil, err + } + + _, err = rw.Write([]byte{5, 0}) // SOCKS v5, no auth required + if err != nil { + return nil, err + } + + n, err := rw.Read(buf) // SOCKS request: VER, CMD, RSV, Addr + if err != nil { + return nil, err + } + buf = buf[:n] + + if buf[1] != CmdConnect { + return nil, ErrCommandNotSupported + } + + _, err = rw.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) // SOCKS v5, reply succeeded + return buf[3:], err // skip VER, CMD, RSV fields +} diff --git a/tcp.go b/tcp.go new file mode 100644 index 00000000..50654da9 --- /dev/null +++ b/tcp.go @@ -0,0 +1,166 @@ +package main + +import ( + "io" + "net" + + "github.com/riobard/go-shadowsocks2/core" + "github.com/riobard/go-shadowsocks2/socks" +) + +// Create a SOCKS server listening on addr and proxy to server. +func socksLocal(addr, server string, ciph core.StreamConnCipher) { + logf("SOCKS proxy %s <-> %s", addr, server) + tcpLocal(addr, server, ciph, func(c net.Conn) (socks.Addr, error) { return socks.Handshake(c) }) +} + +// Create a TCP tunnel from addr to target via server. +func tcpTun(addr, server, target string, ciph core.StreamConnCipher) { + tgt := socks.ParseAddr(target) + if tgt == nil { + logf("invalid target address %q", target) + return + } + logf("TCP tunnel %s <-> %s <-> %s", addr, server, target) + tcpLocal(addr, server, ciph, func(net.Conn) (socks.Addr, error) { return tgt, nil }) +} + +// Listen on addr and proxy to server to reach target from getAddr. +func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net.Conn) (socks.Addr, error)) { + ln, err := net.Listen("tcp", addr) + if err != nil { + logf("failed to listen on %s: %v", addr, err) + return + } + + for { + conn, err := ln.Accept() + if err != nil { + logf("failed to accept: %s", err) + continue + } + + tgt, err := getAddr(conn) + if err != nil { + logf("failed to get target address: %v", err) + continue + } + + go tcpLocalHandle(conn, server, tgt, ciph) + } +} + +func tcpLocalHandle(c net.Conn, server string, target socks.Addr, ciph core.StreamConnCipher) { + logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, target) + + defer c.Close() + + sc, err := core.Dial("tcp", server, ciph) + if err != nil { + logf("failed to connect to server %v: %v", server, err) + return + } + defer sc.Close() + + if _, err = sc.Write(target); err != nil { + logf("failed to send target address: %v", err) + return + } + + _, _, err = relay(sc, c) + if err != nil { + logf("relay error: %v", err) + return + } +} + +// Listen on addr for incoming connections. +func tcpRemote(addr string, ciph core.StreamConnCipher) { + ln, err := core.Listen("tcp", addr, ciph) + if err != nil { + logf("failed to listen on %s: %v", addr, err) + return + } + + logf("listening TCP on %s", addr) + for { + conn, err := ln.Accept() + if err != nil { + logf("failed to accept: %s", err) + continue + } + go tcpRemoteHandle(conn) + } +} + +func tcpRemoteHandle(c net.Conn) { + defer c.Close() + + addr, err := socks.ReadAddr(c) + if err != nil { + logf("failed to read address: %v", err) + return + } + logf("proxy %s <-> %s", c.RemoteAddr(), addr) + + conn, err := net.Dial("tcp", addr.String()) + if err != nil { + logf("failed to connect to target: %s", err) + return + } + defer conn.Close() + + _, _, err = relay(c, conn) + if err != nil { + logf("relay error: %v", err) + return + } +} + +// relay copies between left and right bidirectionally. Returns number of +// bytes copied from right to left, from left to right, and any error occurred. +func relay(left, right io.ReadWriter) (int64, int64, error) { + type res struct { + N int64 + Err error + } + ch := make(chan res) + + go func() { + n, err := copyHalfClose(right, left) + ch <- res{n, err} + }() + + n, err := copyHalfClose(left, right) + rs := <-ch + + if err == nil { + err = rs.Err + } + return n, rs.N, err +} + +type closeWriter interface { + CloseWrite() error +} + +type closeReader interface { + CloseRead() error +} + +// copyHalfClose copies to dst from src and optionally closes dst for writing and src for reading. +func copyHalfClose(dst io.Writer, src io.Reader) (int64, error) { + defer func() { + // half-close to wake up other goroutines blocking on dst and src + + if c, ok := dst.(closeWriter); ok { + c.CloseWrite() + } + + if c, ok := src.(closeReader); ok { + c.CloseRead() + } + }() + + return io.Copy(dst, src) // will use io.ReaderFrom or io.WriterTo shortcut if possible +} diff --git a/tcp_linux.go b/tcp_linux.go new file mode 100644 index 00000000..9cc263fa --- /dev/null +++ b/tcp_linux.go @@ -0,0 +1,86 @@ +package main + +import ( + "errors" + "net" + "syscall" + "unsafe" + + "github.com/riobard/go-shadowsocks2/core" + "github.com/riobard/go-shadowsocks2/socks" +) + +// Listen on addr for netfilter redirected TCP connections +func redirLocal(addr, server string, ciph core.StreamConnCipher) { + logf("TCP redirect %s <-> %s", addr, server) + tcpLocal(addr, server, ciph, func(c net.Conn) (socks.Addr, error) { return getOrigDst(c, false) }) +} + +// Listen on addr for netfilter redirected TCP IPv6 connections. +func redir6Local(addr, server string, ciph core.StreamConnCipher) { + logf("TCP6 redirect %s <-> %s", addr, server) + tcpLocal(addr, server, ciph, func(c net.Conn) (socks.Addr, error) { return getOrigDst(c, true) }) +} + +// Get the original destination of a TCP connection. +func getOrigDst(conn net.Conn, ipv6 bool) (socks.Addr, error) { + c, ok := conn.(*net.TCPConn) + if !ok { + return nil, errors.New("only work with TCP connection") + } + f, err := c.File() + if err != nil { + return nil, err + } + defer f.Close() + + fd := f.Fd() + + // The File() call above puts both the original socket fd and the file fd in blocking mode. + // Set the file fd back to non-blocking mode and the original socket fd will become non-blocking as well. + // Otherwise blocking I/O will waste OS threads. + if err := syscall.SetNonblock(int(fd), true); err != nil { + return nil, err + } + + if ipv6 { + return ipv6_getorigdst(fd) + } + + return getorigdst(fd) +} + +// Call getorigdst() from linux/net/ipv4/netfilter/nf_conntrack_l3proto_ipv4.c +func getorigdst(fd uintptr) (socks.Addr, error) { + const SO_ORIGINAL_DST = 80 // from linux/include/uapi/linux/netfilter_ipv4.h + raw := syscall.RawSockaddrInet4{} + siz := unsafe.Sizeof(raw) + if _, _, errno := syscall.Syscall6(syscall.SYS_GETSOCKOPT, fd, syscall.IPPROTO_IP, SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); errno != 0 { + return nil, errno + } + + addr := make([]byte, 1+net.IPv4len+2) + addr[0] = socks.AtypIPv4 + copy(addr[1:1+net.IPv4len], raw.Addr[:]) + port := (*[2]byte)(unsafe.Pointer(&raw.Port)) // big-endian + addr[1+net.IPv4len], addr[1+net.IPv4len+1] = port[0], port[1] + return addr, nil +} + +// Call ipv6_getorigdst() from linux/net/ipv6/netfilter/nf_conntrack_l3proto_ipv6.c +// NOTE: I haven't tried yet but it should work since Linux 3.8. +func ipv6_getorigdst(fd uintptr) (socks.Addr, error) { + const IP6T_SO_ORIGINAL_DST = 80 // from linux/include/uapi/linux/netfilter_ipv6/ip6_tables.h + raw := syscall.RawSockaddrInet6{} + siz := unsafe.Sizeof(raw) + if _, _, errno := syscall.Syscall6(syscall.SYS_GETSOCKOPT, fd, syscall.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); errno != 0 { + return nil, errno + } + + addr := make([]byte, 1+net.IPv6len+2) + addr[0] = socks.AtypIPv6 + copy(addr[1:1+net.IPv6len], raw.Addr[:]) + port := (*[2]byte)(unsafe.Pointer(&raw.Port)) // big-endian + addr[1+net.IPv6len], addr[1+net.IPv6len+1] = port[0], port[1] + return addr, nil +} diff --git a/tcp_other.go b/tcp_other.go new file mode 100644 index 00000000..f5873e69 --- /dev/null +++ b/tcp_other.go @@ -0,0 +1,13 @@ +// +build !linux + +package main + +import "github.com/riobard/go-shadowsocks2/core" + +func redirLocal(addr, server string, ciph core.StreamConnCipher) { + logf("TCP redirect not supported") +} + +func redir6Local(addr, server string, ciph core.StreamConnCipher) { + logf("TCP6 redirect not supported") +} diff --git a/udp.go b/udp.go new file mode 100644 index 00000000..ac472b36 --- /dev/null +++ b/udp.go @@ -0,0 +1,189 @@ +package main + +import ( + "fmt" + "net" + "time" + + "sync" + + "github.com/riobard/go-shadowsocks2/core" + "github.com/riobard/go-shadowsocks2/socks" +) + +const udpBufSize = 64 * 1024 + +// Listen on laddr for UDP packets, encrypt and send to server to reach target. +func udpLocal(laddr, server, target string, ciph core.PacketConnCipher) { + srvAddr, err := net.ResolveUDPAddr("udp", server) + if err != nil { + logf("UDP server address error: %v", err) + return + } + + tgt := socks.ParseAddr(target) + if tgt == nil { + err = fmt.Errorf("invalid target address: %q", target) + logf("UDP target address error: %v", err) + return + } + + c, err := net.ListenPacket("udp", laddr) + if err != nil { + logf("UDP local listen error: %v", err) + return + } + defer c.Close() + + nm := newNATmap(config.UDPTimeout) + buf := make([]byte, udpBufSize) + copy(buf, tgt) + + logf("UDP tunnel %s <-> %s <-> %s", laddr, server, target) + for { + n, raddr, err := c.ReadFrom(buf[len(tgt):]) + if err != nil { + logf("UDP local read error: %v", err) + continue + } + + pc := nm.Get(raddr.String()) + if pc == nil { + pc, err = net.ListenPacket("udp", "") + if err != nil { + logf("UDP local listen error: %v", err) + continue + } + + pc = ciph(pc) + nm.Add(raddr, c, pc) + } + + _, err = pc.WriteTo(buf[:len(tgt)+n], srvAddr) + if err != nil { + logf("UDP local write error: %v", err) + continue + } + } +} + +// Listen on addr for encrypted packets and basically do UDP NAT. +func udpRemote(addr string, ciph core.PacketConnCipher) { + c, err := core.ListenPacket("udp", addr, ciph) + if err != nil { + logf("UDP remote listen error: %v", err) + return + } + defer c.Close() + + nm := newNATmap(config.UDPTimeout) + buf := make([]byte, udpBufSize) + + logf("listening UDP on %s", addr) + for { + n, raddr, err := c.ReadFrom(buf) + if err != nil { + logf("UDP remote read error: %v", err) + continue + } + + tgtAddr := socks.SplitAddr(buf[:n]) + if tgtAddr == nil { + logf("failed to split target address from packet: %q", buf[:n]) + continue + } + + tgtUDPAddr, err := net.ResolveUDPAddr("udp", tgtAddr.String()) + if err != nil { + logf("failed to resolve target UDP address: %v", err) + continue + } + + payload := buf[len(tgtAddr):n] + + pc := nm.Get(raddr.String()) + if pc == nil { + pc, err = net.ListenPacket("udp", "") + if err != nil { + logf("UDP remote listen error: %v", err) + continue + } + + nm.Add(raddr, c, pc) + } + + _, err = pc.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + if err != nil { + logf("UDP remote write error: %v", err) + continue + } + } +} + +// Packet NAT table +type natmap struct { + sync.RWMutex + m map[string]net.PacketConn + timeout time.Duration +} + +func newNATmap(timeout time.Duration) *natmap { + m := &natmap{} + m.m = make(map[string]net.PacketConn) + m.timeout = timeout + return m +} + +func (m *natmap) Get(key string) net.PacketConn { + m.RLock() + defer m.RUnlock() + return m.m[key] +} + +func (m *natmap) Set(key string, pc net.PacketConn) { + m.Lock() + defer m.Unlock() + + m.m[key] = pc +} + +func (m *natmap) Del(key string) net.PacketConn { + m.Lock() + defer m.Unlock() + + pc, ok := m.m[key] + if ok { + delete(m.m, key) + return pc + } + return nil +} + +func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn) { + m.Set(peer.String(), src) + + go func() { + timedCopy(dst, peer, src, m.timeout) + if pc := m.Del(peer.String()); pc != nil { + pc.Close() + } + }() +} + +// copy from src to dst with addr with read timeout +func timedCopy(dst net.PacketConn, addr net.Addr, src net.PacketConn, timeout time.Duration) error { + buf := make([]byte, udpBufSize) + + for { + src.SetReadDeadline(time.Now().Add(timeout)) + n, _, err := src.ReadFrom(buf) + if err != nil { + return err + } + + _, err = dst.WriteTo(buf[:n], addr) + if err != nil { + return err + } + } +}