diff --git a/tcp.go b/tcp.go index ff239ae7..09461a06 100644 --- a/tcp.go +++ b/tcp.go @@ -2,9 +2,11 @@ package main import ( "bufio" + "errors" "io" "io/ioutil" "net" + "os" "sync" "time" @@ -83,11 +85,7 @@ func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func( } logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, tgt) - err = relay(rc, c) - if err != nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - return // ignore i/o timeout - } + if err = relay(rc, c); err != nil { logf("relay error: %v", err) } }() @@ -137,18 +135,14 @@ func tcpRemote(addr string, shadow func(net.Conn) net.Conn) { defer rc.Close() logf("proxy %s <-> %s", c.RemoteAddr(), tgt) - err = relay(sc, rc) - if err != nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - return // ignore i/o timeout - } + if err = relay(sc, rc); err != nil { logf("relay error: %v", err) } }() } } -// relay copies between left and right bidirectionally. Returns any error occurred. +// relay copies between left and right bidirectionally func relay(left, right net.Conn) error { var err, err1 error var wg sync.WaitGroup @@ -159,15 +153,16 @@ func relay(left, right net.Conn) error { _, err1 = io.Copy(right, left) right.SetReadDeadline(time.Now().Add(wait)) // unblock read on right }() - _, err = io.Copy(left, right) left.SetReadDeadline(time.Now().Add(wait)) // unblock read on left wg.Wait() - - if err1 != nil { - err = err1 + if err1 != nil && !errors.Is(err1, os.ErrDeadlineExceeded) { // requires Go 1.15+ + return err1 + } + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { + return err } - return err + return nil } type corkedConn struct {