diff --git a/main.go b/main.go index ca7dbf66..36db9283 100644 --- a/main.go +++ b/main.go @@ -101,27 +101,27 @@ func main() { if flags.UDPTun != "" { for _, tun := range strings.Split(flags.UDPTun, ",") { p := strings.Split(tun, "=") - go udpLocal(p[0], addr, p[1], ciph) + go udpLocal(p[0], addr, p[1], ciph.PacketConn) } } if flags.TCPTun != "" { for _, tun := range strings.Split(flags.TCPTun, ",") { p := strings.Split(tun, "=") - go tcpTun(p[0], addr, p[1], ciph) + go tcpTun(p[0], addr, p[1], ciph.StreamConn) } } if flags.Socks != "" { - go socksLocal(flags.Socks, addr, ciph) + go socksLocal(flags.Socks, addr, ciph.StreamConn) } if flags.RedirTCP != "" { - go redirLocal(flags.RedirTCP, addr, ciph) + go redirLocal(flags.RedirTCP, addr, ciph.StreamConn) } if flags.RedirTCP6 != "" { - go redir6Local(flags.RedirTCP6, addr, ciph) + go redir6Local(flags.RedirTCP6, addr, ciph.StreamConn) } } @@ -143,8 +143,8 @@ func main() { log.Fatal(err) } - go udpRemote(addr, ciph) - go tcpRemote(addr, ciph) + go udpRemote(addr, ciph.PacketConn) + go tcpRemote(addr, ciph.StreamConn) } sigCh := make(chan os.Signal, 1) diff --git a/shadowaead/stream.go b/shadowaead/stream.go index a8dce31a..5f499a21 100644 --- a/shadowaead/stream.go +++ b/shadowaead/stream.go @@ -148,6 +148,16 @@ func (r *reader) Read(b []byte) (int, error) { // there's no more data to write or when an error occurs. Return number of // bytes written to w and any error encountered. func (r *reader) WriteTo(w io.Writer) (n int64, err error) { + // write decrypted bytes left over from previous record + for len(r.leftover) > 0 { + nw, ew := w.Write(r.leftover) + r.leftover = r.leftover[nw:] + n += int64(nw) + if ew != nil { + return n, ew + } + } + for { nr, er := r.read() if nr > 0 { diff --git a/tcp.go b/tcp.go index 3b7c7c39..d88bcf2a 100644 --- a/tcp.go +++ b/tcp.go @@ -5,29 +5,28 @@ import ( "net" "time" - "github.com/shadowsocks/go-shadowsocks2/core" "github.com/shadowsocks/go-shadowsocks2/socks" ) // Create a SOCKS server listening on addr and proxy to server. -func socksLocal(addr, server string, ciph core.StreamConnCipher) { +func socksLocal(addr, server string, shadow func(net.Conn) net.Conn) { logf("SOCKS proxy %s <-> %s", addr, server) - tcpLocal(addr, server, ciph, func(c net.Conn) (socks.Addr, error) { return socks.Handshake(c) }) + tcpLocal(addr, server, shadow, func(c net.Conn) (socks.Addr, error) { return socks.Handshake(c) }) } // Create a TCP tunnel from addr to target via server. -func tcpTun(addr, server, target string, ciph core.StreamConnCipher) { +func tcpTun(addr, server, target string, shadow func(net.Conn) net.Conn) { tgt := socks.ParseAddr(target) if tgt == nil { logf("invalid target address %q", target) return } logf("TCP tunnel %s <-> %s <-> %s", addr, server, target) - tcpLocal(addr, server, ciph, func(net.Conn) (socks.Addr, error) { return tgt, nil }) + tcpLocal(addr, server, shadow, func(net.Conn) (socks.Addr, error) { return tgt, nil }) } // Listen on addr and proxy to server to reach target from getAddr. -func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net.Conn) (socks.Addr, error)) { +func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func(net.Conn) (socks.Addr, error)) { l, err := net.Listen("tcp", addr) if err != nil { logf("failed to listen on %s: %v", addr, err) @@ -58,7 +57,7 @@ func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net. } defer rc.Close() rc.(*net.TCPConn).SetKeepAlive(true) - rc = ciph.StreamConn(rc) + rc = shadow(rc) if _, err = rc.Write(tgt); err != nil { logf("failed to send target address: %v", err) @@ -78,7 +77,7 @@ func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net. } // Listen on addr for incoming connections. -func tcpRemote(addr string, ciph core.StreamConnCipher) { +func tcpRemote(addr string, shadow func(net.Conn) net.Conn) { l, err := net.Listen("tcp", addr) if err != nil { logf("failed to listen on %s: %v", addr, err) @@ -96,7 +95,7 @@ func tcpRemote(addr string, ciph core.StreamConnCipher) { go func() { defer c.Close() c.(*net.TCPConn).SetKeepAlive(true) - c = ciph.StreamConn(c) + c = shadow(c) tgt, err := socks.ReadAddr(c) if err != nil { diff --git a/tcp_linux.go b/tcp_linux.go index fbf32b72..84d10e78 100644 --- a/tcp_linux.go +++ b/tcp_linux.go @@ -6,7 +6,6 @@ import ( "syscall" "unsafe" - "github.com/shadowsocks/go-shadowsocks2/core" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -16,15 +15,15 @@ const ( ) // Listen on addr for netfilter redirected TCP connections -func redirLocal(addr, server string, ciph core.StreamConnCipher) { +func redirLocal(addr, server string, shadow func(net.Conn) net.Conn) { logf("TCP redirect %s <-> %s", addr, server) - tcpLocal(addr, server, ciph, func(c net.Conn) (socks.Addr, error) { return getOrigDst(c, false) }) + tcpLocal(addr, server, shadow, func(c net.Conn) (socks.Addr, error) { return getOrigDst(c, false) }) } // Listen on addr for netfilter redirected TCP IPv6 connections. -func redir6Local(addr, server string, ciph core.StreamConnCipher) { +func redir6Local(addr, server string, shadow func(net.Conn) net.Conn) { logf("TCP6 redirect %s <-> %s", addr, server) - tcpLocal(addr, server, ciph, func(c net.Conn) (socks.Addr, error) { return getOrigDst(c, true) }) + tcpLocal(addr, server, shadow, func(c net.Conn) (socks.Addr, error) { return getOrigDst(c, true) }) } // Get the original destination of a TCP connection. diff --git a/tcp_other.go b/tcp_other.go index 9176a82c..a7c41d63 100644 --- a/tcp_other.go +++ b/tcp_other.go @@ -2,12 +2,12 @@ package main -import "github.com/shadowsocks/go-shadowsocks2/core" +import "net" -func redirLocal(addr, server string, ciph core.StreamConnCipher) { +func redirLocal(addr, server string, shadow func(net.Conn) net.Conn) { logf("TCP redirect not supported") } -func redir6Local(addr, server string, ciph core.StreamConnCipher) { +func redir6Local(addr, server string, shadow func(net.Conn) net.Conn) { logf("TCP6 redirect not supported") } diff --git a/udp.go b/udp.go index 83870dc5..f8c787b4 100644 --- a/udp.go +++ b/udp.go @@ -7,14 +7,13 @@ import ( "sync" - "github.com/shadowsocks/go-shadowsocks2/core" "github.com/shadowsocks/go-shadowsocks2/socks" ) const udpBufSize = 64 * 1024 // Listen on laddr for UDP packets, encrypt and send to server to reach target. -func udpLocal(laddr, server, target string, ciph core.PacketConnCipher) { +func udpLocal(laddr, server, target string, shadow func(net.PacketConn) net.PacketConn) { srvAddr, err := net.ResolveUDPAddr("udp", server) if err != nil { logf("UDP server address error: %v", err) @@ -55,7 +54,7 @@ func udpLocal(laddr, server, target string, ciph core.PacketConnCipher) { continue } - pc = ciph.PacketConn(pc) + pc = shadow(pc) nm.Add(raddr, c, pc, false) } @@ -68,13 +67,14 @@ func udpLocal(laddr, server, target string, ciph core.PacketConnCipher) { } // Listen on addr for encrypted packets and basically do UDP NAT. -func udpRemote(addr string, ciph core.PacketConnCipher) { - c, err := core.ListenPacket("udp", addr, ciph) +func udpRemote(addr string, shadow func(net.PacketConn) net.PacketConn) { + c, err := net.ListenPacket("udp", addr) if err != nil { logf("UDP remote listen error: %v", err) return } defer c.Close() + c = shadow(c) nm := newNATmap(config.UDPTimeout) buf := make([]byte, udpBufSize)