Skip to content

Commit

Permalink
Merge pull request shadowsocks#12 from riobard/dev
Browse files Browse the repository at this point in the history
UDP relay bugfix and optimization
  • Loading branch information
riobard authored Feb 19, 2017
2 parents 3ea7b01 + 22b4ba7 commit c2beadb
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
21 changes: 12 additions & 9 deletions shadowaead/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ import (
"errors"
"io"
"net"
"sync"
)

// ErrShortPacket means that the packet is too short for a valid encrypted packet.
var ErrShortPacket = errors.New("short packet")

var _zerononce [128]byte // read-only. 128 bytes is more than enough.

// Pack encrypts plaintext using Cipher with a randomly generated salt and
// returns a slice of dst containing the encrypted packet and any error occurred.
// Ensure len(dst) >= ciph.SaltSize() + len(plaintext) + aead.Overhead().
Expand All @@ -28,8 +31,7 @@ func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) {
if len(dst) < saltSize+len(plaintext)+aead.Overhead() {
return nil, io.ErrShortBuffer
}
nonce := make([]byte, aead.NonceSize())
b := aead.Seal(dst[saltSize:saltSize], nonce, plaintext, nil)
b := aead.Seal(dst[saltSize:saltSize], _zerononce[:aead.NonceSize()], plaintext, nil)
return dst[:saltSize+len(b)], nil
}

Expand All @@ -51,27 +53,28 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
if saltSize+len(dst)+aead.Overhead() < len(pkt) {
return nil, io.ErrShortBuffer
}

nonce := make([]byte, aead.NonceSize())
b, err := aead.Open(dst[:0], nonce, pkt[saltSize:], nil)
b, err := aead.Open(dst[:0], _zerononce[:aead.NonceSize()], pkt[saltSize:], nil)
return b, err
}

type packetConn struct {
net.PacketConn
Cipher
sync.Mutex
buf []byte // write lock
}

// NewPacketConn wraps a net.PacketConn with cipher
func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn {
return &packetConn{PacketConn: c, Cipher: ciph}
const maxPacketSize = 64 * 1024
return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, maxPacketSize)}
}

// WriteTo encrypts b and write to addr using the embedded PacketConn.
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
const overhead = 16
buf := make([]byte, c.Cipher.SaltSize()+len(b)+overhead)
buf, err := Pack(buf, b, c)
c.Lock()
defer c.Unlock()
buf, err := Pack(c.buf, b, c)
if err != nil {
return 0, err
}
Expand Down
10 changes: 7 additions & 3 deletions shadowstream/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"io"
"net"
"sync"
)

// ErrShortPacket means the packet is too short to be a valid encrypted packet.
Expand Down Expand Up @@ -45,16 +46,19 @@ func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) {
type packetConn struct {
net.PacketConn
Cipher
buf []byte
sync.Mutex // write lock
}

// NewPacketConn wraps a net.PacketConn with stream cipher encryption/decryption.
func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn {
return &packetConn{PacketConn: c, Cipher: ciph}
return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, 64*1024)}
}

func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
buf := make([]byte, c.IVSize()+len(b))
_, err := Pack(buf, b, c.Cipher)
c.Lock()
defer c.Unlock()
buf, err := Pack(c.buf, b, c.Cipher)
if err != nil {
return 0, err
}
Expand Down
25 changes: 17 additions & 8 deletions udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func udpLocal(laddr, server, target string, ciph core.PacketConnCipher) {
}

pc = ciph.PacketConn(pc)
nm.Add(raddr, c, pc)
nm.Add(raddr, c, pc, false)
}

_, err = pc.WriteTo(buf[:len(tgt)+n], srvAddr)
Expand Down Expand Up @@ -109,7 +109,7 @@ func udpRemote(addr string, ciph core.PacketConnCipher) {
continue
}

nm.Add(raddr, c, pc)
nm.Add(raddr, c, pc, true)
}

_, err = pc.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
Expand Down Expand Up @@ -159,29 +159,38 @@ func (m *natmap) Del(key string) net.PacketConn {
return nil
}

func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn) {
func (m *natmap) Add(peer net.Addr, dst, src net.PacketConn, srcIncluded bool) {
m.Set(peer.String(), src)

go func() {
timedCopy(dst, peer, src, m.timeout)
timedCopy(dst, peer, src, m.timeout, srcIncluded)
if pc := m.Del(peer.String()); pc != nil {
pc.Close()
}
}()
}

// copy from src to dst with addr with read timeout
func timedCopy(dst net.PacketConn, addr net.Addr, src net.PacketConn, timeout time.Duration) error {
// copy from src to dst at target with read timeout
func timedCopy(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration, srcIncluded bool) error {
buf := make([]byte, udpBufSize)

for {
src.SetReadDeadline(time.Now().Add(timeout))
n, _, err := src.ReadFrom(buf)
n, raddr, err := src.ReadFrom(buf)
if err != nil {
return err
}

_, err = dst.WriteTo(buf[:n], addr)
if srcIncluded { // server -> client: add original packet source
srcAddr := socks.ParseAddr(raddr.String())
copy(buf[len(srcAddr):], buf[:n])
copy(buf, srcAddr)
_, err = dst.WriteTo(buf[:len(srcAddr)+n], target)
} else { // client -> user: strip original packet source
srcAddr := socks.SplitAddr(buf[:n])
_, err = dst.WriteTo(buf[len(srcAddr):n], target)
}

if err != nil {
return err
}
Expand Down

0 comments on commit c2beadb

Please sign in to comment.