diff --git a/tcp.go b/tcp.go index d196328a..3b7c7c39 100644 --- a/tcp.go +++ b/tcp.go @@ -43,6 +43,7 @@ func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net. go func() { defer c.Close() + c.(*net.TCPConn).SetKeepAlive(true) tgt, err := getAddr(c) if err != nil { @@ -50,12 +51,14 @@ func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net. return } - rc, err := core.Dial("tcp", server, ciph) + rc, err := net.Dial("tcp", server) if err != nil { logf("failed to connect to server %v: %v", server, err) return } defer rc.Close() + rc.(*net.TCPConn).SetKeepAlive(true) + rc = ciph.StreamConn(rc) if _, err = rc.Write(tgt); err != nil { logf("failed to send target address: %v", err) @@ -76,7 +79,7 @@ func tcpLocal(addr, server string, ciph core.StreamConnCipher, getAddr func(net. // Listen on addr for incoming connections. func tcpRemote(addr string, ciph core.StreamConnCipher) { - l, err := core.Listen("tcp", addr, ciph) + l, err := net.Listen("tcp", addr) if err != nil { logf("failed to listen on %s: %v", addr, err) return @@ -92,6 +95,8 @@ func tcpRemote(addr string, ciph core.StreamConnCipher) { go func() { defer c.Close() + c.(*net.TCPConn).SetKeepAlive(true) + c = ciph.StreamConn(c) tgt, err := socks.ReadAddr(c) if err != nil { @@ -105,6 +110,7 @@ func tcpRemote(addr string, ciph core.StreamConnCipher) { return } defer rc.Close() + rc.(*net.TCPConn).SetKeepAlive(true) logf("proxy %s <-> %s", c.RemoteAddr(), tgt) _, _, err = relay(c, rc)