Skip to content

Commit 06215b5

Browse files
committed
tcpproxy: implement half-close dance in proxy
1 parent 91f8614 commit 06215b5

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

tcpproxy.go

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,27 @@ func UnderlyingConn(c net.Conn) net.Conn {
344344
}
345345
return c
346346
}
347+
func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) {
348+
if c, ok := UnderlyingConn(c).(*net.TCPConn); ok {
349+
return c, ok
350+
}
351+
if c, ok := c.(*net.TCPConn); ok {
352+
return c, ok
353+
}
354+
return nil, false
355+
}
347356

348357
func goCloseConn(c net.Conn) { go c.Close() }
358+
func closeRead(c net.Conn) {
359+
if c, ok := tcpConn(c); ok {
360+
c.CloseRead()
361+
}
362+
}
363+
func closeWrite(c net.Conn) {
364+
if c, ok := tcpConn(c); ok {
365+
c.CloseWrite()
366+
}
367+
}
349368

350369
// HandleConn implements the Target interface.
351370
func (dp *DialProxy) HandleConn(src net.Conn) {
@@ -371,20 +390,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
371390
defer goCloseConn(src)
372391

373392
if ka := dp.keepAlivePeriod(); ka > 0 {
374-
if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
375-
c.SetKeepAlive(true)
376-
c.SetKeepAlivePeriod(ka)
377-
}
378-
if c, ok := dst.(*net.TCPConn); ok {
379-
c.SetKeepAlive(true)
380-
c.SetKeepAlivePeriod(ka)
393+
for _, c := range []net.Conn{src, dst} {
394+
if c, ok := tcpConn(c); ok {
395+
c.SetKeepAlive(true)
396+
c.SetKeepAlivePeriod(ka)
397+
}
381398
}
382399
}
383400

384-
errc := make(chan error, 1)
401+
errc := make(chan error, 2)
385402
go proxyCopy(errc, src, dst)
386403
go proxyCopy(errc, dst, src)
387404
<-errc
405+
<-errc
388406
}
389407

390408
func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
@@ -420,6 +438,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
420438
// It's a named function instead of a func literal so users get
421439
// named goroutines in debug goroutine stack dumps.
422440
func proxyCopy(errc chan<- error, dst, src net.Conn) {
441+
defer closeRead(src)
442+
defer closeWrite(dst)
443+
423444
// Before we unwrap src and/or dst, copy any buffered data.
424445
if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 {
425446
if _, err := dst.Write(wc.Peeked); err != nil {

tcpproxy_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,45 @@ func testProxy(t *testing.T, front net.Listener) *Proxy {
174174
}
175175
}
176176

177+
func TestBufferedClose(t *testing.T) {
178+
front := newLocalListener(t)
179+
defer front.Close()
180+
back := newLocalListener(t)
181+
defer back.Close()
182+
183+
p := testProxy(t, front)
184+
p.AddRoute(testFrontAddr, To(back.Addr().String()))
185+
if err := p.Start(); err != nil {
186+
t.Fatal(err)
187+
}
188+
189+
toFront, err := net.Dial("tcp", front.Addr().String())
190+
if err != nil {
191+
t.Fatal(err)
192+
}
193+
defer toFront.Close()
194+
195+
fromProxy, err := back.Accept()
196+
if err != nil {
197+
t.Fatal(err)
198+
}
199+
defer fromProxy.Close()
200+
const msg = "message"
201+
if _, err := io.WriteString(toFront, msg); err != nil {
202+
t.Fatal(err)
203+
}
204+
// actively close toFront, the write should still make to the back.
205+
toFront.Close()
206+
207+
buf := make([]byte, len(msg))
208+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
209+
t.Fatal(err)
210+
}
211+
if string(buf) != msg {
212+
t.Fatalf("got %q; want %q", buf, msg)
213+
}
214+
}
215+
177216
func TestProxyAlwaysMatch(t *testing.T) {
178217
front := newLocalListener(t)
179218
defer front.Close()
@@ -196,6 +235,7 @@ func TestProxyAlwaysMatch(t *testing.T) {
196235
if err != nil {
197236
t.Fatal(err)
198237
}
238+
defer fromProxy.Close()
199239
const msg = "message"
200240
io.WriteString(toFront, msg)
201241

0 commit comments

Comments
 (0)