@@ -344,8 +344,27 @@ func UnderlyingConn(c net.Conn) net.Conn {
344344 }
345345 return c
346346}
347+ func tcpConn (c net.Conn ) (t * net.TCPConn , ok bool ) {
348+ if c , ok := UnderlyingConn (c ).(* net.TCPConn ); ok {
349+ return c , ok
350+ }
351+ if c , ok := c .(* net.TCPConn ); ok {
352+ return c , ok
353+ }
354+ return nil , false
355+ }
347356
348357func goCloseConn (c net.Conn ) { go c .Close () }
358+ func closeRead (c net.Conn ) {
359+ if c , ok := tcpConn (c ); ok {
360+ c .CloseRead ()
361+ }
362+ }
363+ func closeWrite (c net.Conn ) {
364+ if c , ok := tcpConn (c ); ok {
365+ c .CloseWrite ()
366+ }
367+ }
349368
350369// HandleConn implements the Target interface.
351370func (dp * DialProxy ) HandleConn (src net.Conn ) {
@@ -371,20 +390,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
371390 defer goCloseConn (src )
372391
373392 if ka := dp .keepAlivePeriod (); ka > 0 {
374- if c , ok := UnderlyingConn (src ).(* net.TCPConn ); ok {
375- c .SetKeepAlive (true )
376- c .SetKeepAlivePeriod (ka )
377- }
378- if c , ok := dst .(* net.TCPConn ); ok {
379- c .SetKeepAlive (true )
380- c .SetKeepAlivePeriod (ka )
393+ for _ , c := range []net.Conn {src , dst } {
394+ if c , ok := tcpConn (c ); ok {
395+ c .SetKeepAlive (true )
396+ c .SetKeepAlivePeriod (ka )
397+ }
381398 }
382399 }
383400
384- errc := make (chan error , 1 )
401+ errc := make (chan error , 2 )
385402 go proxyCopy (errc , src , dst )
386403 go proxyCopy (errc , dst , src )
387404 <- errc
405+ <- errc
388406}
389407
390408func (dp * DialProxy ) sendProxyHeader (w io.Writer , src net.Conn ) error {
@@ -420,6 +438,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
420438// It's a named function instead of a func literal so users get
421439// named goroutines in debug goroutine stack dumps.
422440func proxyCopy (errc chan <- error , dst , src net.Conn ) {
441+ defer closeRead (src )
442+ defer closeWrite (dst )
443+
423444 // Before we unwrap src and/or dst, copy any buffered data.
424445 if wc , ok := src .(* Conn ); ok && len (wc .Peeked ) > 0 {
425446 if _ , err := dst .Write (wc .Peeked ); err != nil {
0 commit comments