diff --git a/main.go b/main.go index d1d55dc1..ca7dbf66 100644 --- a/main.go +++ b/main.go @@ -84,18 +84,13 @@ func main() { addr := flags.Client cipher := flags.Cipher password := flags.Password + var err error if strings.HasPrefix(addr, "ss://") { - u, err := url.Parse(addr) + addr, cipher, password, err = parseURL(addr) if err != nil { log.Fatal(err) } - - addr = u.Host - if u.User != nil { - cipher = u.User.Username() - password, _ = u.User.Password() - } } ciph, err := core.PickCipher(cipher, key, password) @@ -134,18 +129,13 @@ func main() { addr := flags.Server cipher := flags.Cipher password := flags.Password + var err error if strings.HasPrefix(addr, "ss://") { - u, err := url.Parse(addr) + addr, cipher, password, err = parseURL(addr) if err != nil { log.Fatal(err) } - - addr = u.Host - if u.User != nil { - cipher = u.User.Username() - password, _ = u.User.Password() - } } ciph, err := core.PickCipher(cipher, key, password) @@ -161,3 +151,17 @@ func main() { signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh } + +func parseURL(s string) (addr, cipher, password string, err error) { + u, err := url.Parse(s) + if err != nil { + return + } + + addr = u.Host + if u.User != nil { + cipher = u.User.Username() + password, _ = u.User.Password() + } + return +} diff --git a/tcp.go b/tcp.go index 7c4607d2..d196328a 100644 --- a/tcp.go +++ b/tcp.go @@ -28,58 +28,55 @@ func tcpTun(addr, server, target string, ciph core.StreamConnCipher) { // Listen on addr and proxy to server to reach target from getAddr. func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net.Conn) (socks.Addr, error)) { - ln, err := net.Listen("tcp", addr) + l, err := net.Listen("tcp", addr) if err != nil { logf("failed to listen on %s: %v", addr, err) return } for { - conn, err := ln.Accept() + c, err := l.Accept() if err != nil { logf("failed to accept: %s", err) continue } - tgt, err := getAddr(conn) - if err != nil { - logf("failed to get target address: %v", err) - continue - } - - go tcpLocalHandle(conn, server, tgt, ciph) - } -} - -func tcpLocalHandle(c net.Conn, server string, target socks.Addr, ciph core.StreamConnCipher) { - logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, target) - - defer c.Close() - - sc, err := core.Dial("tcp", server, ciph) - if err != nil { - logf("failed to connect to server %v: %v", server, err) - return - } - defer sc.Close() - - if _, err = sc.Write(target); err != nil { - logf("failed to send target address: %v", err) - return - } - - _, _, err = relay(sc, c) - if err != nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - return // ignore i/o timeout - } - logf("relay error: %v", err) + go func() { + defer c.Close() + + tgt, err := getAddr(c) + if err != nil { + logf("failed to get target address: %v", err) + return + } + + rc, err := core.Dial("tcp", server, ciph) + if err != nil { + logf("failed to connect to server %v: %v", server, err) + return + } + defer rc.Close() + + if _, err = rc.Write(tgt); err != nil { + logf("failed to send target address: %v", err) + return + } + + logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, tgt) + _, _, err = relay(rc, c) + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + return // ignore i/o timeout + } + logf("relay error: %v", err) + } + }() } } // Listen on addr for incoming connections. func tcpRemote(addr string, ciph core.StreamConnCipher) { - ln, err := core.Listen("tcp", addr, ciph) + l, err := core.Listen("tcp", addr, ciph) if err != nil { logf("failed to listen on %s: %v", addr, err) return @@ -87,38 +84,37 @@ func tcpRemote(addr string, ciph core.StreamConnCipher) { logf("listening TCP on %s", addr) for { - conn, err := ln.Accept() + c, err := l.Accept() if err != nil { - logf("failed to accept: %s", err) + logf("failed to accept: %v", err) continue } - go tcpRemoteHandle(conn) - } -} - -func tcpRemoteHandle(c net.Conn) { - defer c.Close() - - addr, err := socks.ReadAddr(c) - if err != nil { - logf("failed to read address: %v", err) - return - } - logf("proxy %s <-> %s", c.RemoteAddr(), addr) - - conn, err := net.Dial("tcp", addr.String()) - if err != nil { - logf("failed to connect to target: %s", err) - return - } - defer conn.Close() - _, _, err = relay(c, conn) - if err != nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - return // ignore i/o timeout - } - logf("relay error: %v", err) + go func() { + defer c.Close() + + tgt, err := socks.ReadAddr(c) + if err != nil { + logf("failed to get target address: %v", err) + return + } + + rc, err := net.Dial("tcp", tgt.String()) + if err != nil { + logf("failed to connect to target: %v", err) + return + } + defer rc.Close() + + logf("proxy %s <-> %s", c.RemoteAddr(), tgt) + _, _, err = relay(c, rc) + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + return // ignore i/o timeout + } + logf("relay error: %v", err) + } + }() } }