@@ -345,8 +345,30 @@ func UnderlyingConn(c net.Conn) net.Conn {
345345 return c
346346}
347347
348+ func tcpConn (c net.Conn ) (t * net.TCPConn , ok bool ) {
349+ if c , ok := UnderlyingConn (c ).(* net.TCPConn ); ok {
350+ return c , ok
351+ }
352+ if c , ok := c .(* net.TCPConn ); ok {
353+ return c , ok
354+ }
355+ return nil , false
356+ }
357+
348358func goCloseConn (c net.Conn ) { go c .Close () }
349359
360+ func closeRead (c net.Conn ) {
361+ if c , ok := tcpConn (c ); ok {
362+ c .CloseRead ()
363+ }
364+ }
365+
366+ func closeWrite (c net.Conn ) {
367+ if c , ok := tcpConn (c ); ok {
368+ c .CloseWrite ()
369+ }
370+ }
371+
350372// HandleConn implements the Target interface.
351373func (dp * DialProxy ) HandleConn (src net.Conn ) {
352374 ctx := context .Background ()
@@ -371,20 +393,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
371393 defer goCloseConn (src )
372394
373395 if ka := dp .keepAlivePeriod (); ka > 0 {
374- if c , ok := UnderlyingConn (src ).(* net.TCPConn ); ok {
375- c .SetKeepAlive (true )
376- c .SetKeepAlivePeriod (ka )
377- }
378- if c , ok := dst .(* net.TCPConn ); ok {
379- c .SetKeepAlive (true )
380- c .SetKeepAlivePeriod (ka )
396+ for _ , c := range []net.Conn {src , dst } {
397+ if c , ok := tcpConn (c ); ok {
398+ c .SetKeepAlive (true )
399+ c .SetKeepAlivePeriod (ka )
400+ }
381401 }
382402 }
383403
384- errc := make (chan error , 1 )
404+ errc := make (chan error , 2 )
385405 go proxyCopy (errc , src , dst )
386406 go proxyCopy (errc , dst , src )
387407 <- errc
408+ <- errc
388409}
389410
390411func (dp * DialProxy ) sendProxyHeader (w io.Writer , src net.Conn ) error {
@@ -420,6 +441,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
420441// It's a named function instead of a func literal so users get
421442// named goroutines in debug goroutine stack dumps.
422443func proxyCopy (errc chan <- error , dst , src net.Conn ) {
444+ defer closeRead (src )
445+ defer closeWrite (dst )
446+
423447 // Before we unwrap src and/or dst, copy any buffered data.
424448 if wc , ok := src .(* Conn ); ok && len (wc .Peeked ) > 0 {
425449 if _ , err := dst .Write (wc .Peeked ); err != nil {
0 commit comments