Skip to content

Commit 2862066

Browse files
authored
tcpproxy: implement half-close dance in proxy (#38)
Signed-off-by: James Tucker <jftucker@gmail.com>
1 parent 91f8614 commit 2862066

File tree

2 files changed

+72
-8
lines changed

2 files changed

+72
-8
lines changed

tcpproxy.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,30 @@ func UnderlyingConn(c net.Conn) net.Conn {
345345
return c
346346
}
347347

348+
func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) {
349+
if c, ok := UnderlyingConn(c).(*net.TCPConn); ok {
350+
return c, ok
351+
}
352+
if c, ok := c.(*net.TCPConn); ok {
353+
return c, ok
354+
}
355+
return nil, false
356+
}
357+
348358
func goCloseConn(c net.Conn) { go c.Close() }
349359

360+
func closeRead(c net.Conn) {
361+
if c, ok := tcpConn(c); ok {
362+
c.CloseRead()
363+
}
364+
}
365+
366+
func closeWrite(c net.Conn) {
367+
if c, ok := tcpConn(c); ok {
368+
c.CloseWrite()
369+
}
370+
}
371+
350372
// HandleConn implements the Target interface.
351373
func (dp *DialProxy) HandleConn(src net.Conn) {
352374
ctx := context.Background()
@@ -371,20 +393,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
371393
defer goCloseConn(src)
372394

373395
if ka := dp.keepAlivePeriod(); ka > 0 {
374-
if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
375-
c.SetKeepAlive(true)
376-
c.SetKeepAlivePeriod(ka)
377-
}
378-
if c, ok := dst.(*net.TCPConn); ok {
379-
c.SetKeepAlive(true)
380-
c.SetKeepAlivePeriod(ka)
396+
for _, c := range []net.Conn{src, dst} {
397+
if c, ok := tcpConn(c); ok {
398+
c.SetKeepAlive(true)
399+
c.SetKeepAlivePeriod(ka)
400+
}
381401
}
382402
}
383403

384-
errc := make(chan error, 1)
404+
errc := make(chan error, 2)
385405
go proxyCopy(errc, src, dst)
386406
go proxyCopy(errc, dst, src)
387407
<-errc
408+
<-errc
388409
}
389410

390411
func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
@@ -420,6 +441,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
420441
// It's a named function instead of a func literal so users get
421442
// named goroutines in debug goroutine stack dumps.
422443
func proxyCopy(errc chan<- error, dst, src net.Conn) {
444+
defer closeRead(src)
445+
defer closeWrite(dst)
446+
423447
// Before we unwrap src and/or dst, copy any buffered data.
424448
if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 {
425449
if _, err := dst.Write(wc.Peeked); err != nil {

tcpproxy_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,45 @@ func testProxy(t *testing.T, front net.Listener) *Proxy {
174174
}
175175
}
176176

177+
func TestBufferedClose(t *testing.T) {
178+
front := newLocalListener(t)
179+
defer front.Close()
180+
back := newLocalListener(t)
181+
defer back.Close()
182+
183+
p := testProxy(t, front)
184+
p.AddRoute(testFrontAddr, To(back.Addr().String()))
185+
if err := p.Start(); err != nil {
186+
t.Fatal(err)
187+
}
188+
189+
toFront, err := net.Dial("tcp", front.Addr().String())
190+
if err != nil {
191+
t.Fatal(err)
192+
}
193+
defer toFront.Close()
194+
195+
fromProxy, err := back.Accept()
196+
if err != nil {
197+
t.Fatal(err)
198+
}
199+
defer fromProxy.Close()
200+
const msg = "message"
201+
if _, err := io.WriteString(toFront, msg); err != nil {
202+
t.Fatal(err)
203+
}
204+
// actively close toFront, the write should still make to the back.
205+
toFront.Close()
206+
207+
buf := make([]byte, len(msg))
208+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
209+
t.Fatal(err)
210+
}
211+
if string(buf) != msg {
212+
t.Fatalf("got %q; want %q", buf, msg)
213+
}
214+
}
215+
177216
func TestProxyAlwaysMatch(t *testing.T) {
178217
front := newLocalListener(t)
179218
defer front.Close()
@@ -196,6 +235,7 @@ func TestProxyAlwaysMatch(t *testing.T) {
196235
if err != nil {
197236
t.Fatal(err)
198237
}
238+
defer fromProxy.Close()
199239
const msg = "message"
200240
io.WriteString(toFront, msg)
201241

0 commit comments

Comments
 (0)