Skip to content

Commit e387889

Browse files
committed
Add TargetListener
1 parent 2eb0155 commit e387889

File tree

3 files changed

+210
-24
lines changed

3 files changed

+210
-24
lines changed

listener.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright 2017 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package tcpproxy
6+
7+
import (
8+
"io"
9+
"net"
10+
"sync"
11+
)
12+
13+
// TargetListener implements both net.Listener and Target.
14+
// Matched Targets become accepted connections.
15+
type TargetListener struct {
16+
Address string // Address is the string reported by TargetListener.Addr().String().
17+
18+
mu sync.Mutex
19+
cond *sync.Cond
20+
closed bool
21+
nextConn net.Conn
22+
}
23+
24+
var (
25+
_ net.Listener = (*TargetListener)(nil)
26+
_ Target = (*TargetListener)(nil)
27+
)
28+
29+
func (tl *TargetListener) lock() {
30+
tl.mu.Lock()
31+
if tl.cond == nil {
32+
tl.cond = sync.NewCond(&tl.mu)
33+
}
34+
}
35+
36+
type tcpAddr string
37+
38+
func (a tcpAddr) Network() string { return "tcp" }
39+
func (a tcpAddr) String() string { return string(a) }
40+
41+
func (tl *TargetListener) Addr() net.Addr { return tcpAddr(tl.Address) }
42+
43+
func (tl *TargetListener) Close() error {
44+
tl.lock()
45+
if tl.closed {
46+
tl.mu.Unlock()
47+
return nil
48+
}
49+
tl.closed = true
50+
tl.mu.Unlock()
51+
tl.cond.Broadcast()
52+
return nil
53+
}
54+
55+
// HandleConn implements the Target interface. It blocks until tl is
56+
// closed or another goroutine has called Accept and received c.
57+
func (tl *TargetListener) HandleConn(c net.Conn) {
58+
tl.lock()
59+
defer tl.mu.Unlock()
60+
for tl.nextConn != nil && !tl.closed {
61+
tl.cond.Wait()
62+
}
63+
if tl.closed {
64+
c.Close()
65+
return
66+
}
67+
tl.nextConn = c
68+
tl.cond.Broadcast() // Signal might be sufficient; verify.
69+
for tl.nextConn == c && !tl.closed {
70+
tl.cond.Wait()
71+
}
72+
if tl.closed {
73+
c.Close()
74+
return
75+
}
76+
}
77+
78+
func (tl *TargetListener) Accept() (net.Conn, error) {
79+
tl.lock()
80+
for tl.nextConn == nil && !tl.closed {
81+
tl.cond.Wait()
82+
}
83+
if tl.closed {
84+
tl.mu.Unlock()
85+
return nil, io.EOF
86+
}
87+
c := tl.nextConn
88+
tl.nextConn = nil
89+
tl.mu.Unlock()
90+
tl.cond.Broadcast() // Signal might be sufficient; verify.
91+
92+
return c, nil
93+
}

listener_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright 2017 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package tcpproxy
6+
7+
import (
8+
"io"
9+
"testing"
10+
)
11+
12+
func TestListenerAccept(t *testing.T) {
13+
tl := new(TargetListener)
14+
ch := make(chan interface{}, 1)
15+
go func() {
16+
for {
17+
conn, err := tl.Accept()
18+
if err != nil {
19+
ch <- err
20+
return
21+
} else {
22+
ch <- conn
23+
}
24+
}
25+
}()
26+
27+
for i := 0; i < 3; i++ {
28+
conn := new(Conn)
29+
tl.HandleConn(conn)
30+
got := <-ch
31+
if got != conn {
32+
t.Errorf("Accept conn = %v; want %v", got, conn)
33+
}
34+
}
35+
tl.Close()
36+
got := <-ch
37+
if got != io.EOF {
38+
t.Errorf("Accept error post-Close = %v; want io.EOF", got)
39+
}
40+
}

tcpproxy.go

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

5-
// Package tcpproxy lets users build TCP, HTTP/1, and TLS+SNI proxies.
5+
// Package tcpproxy lets users build TCP proxies, optionally making
6+
// routing decisions based on HTTP/1 Host headers and the SNI hostname
7+
// in TLS connections.
68
//
79
// Typical usage:
810
//
@@ -14,11 +16,31 @@
1416
// p.AddSNIHostRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432"))
1517
// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback
1618
// log.Fatal(p.Run())
19+
//
20+
// Calling Run (or Start) on a proxy also starts all the necessary
21+
// listeners.
22+
//
23+
// For each accepted connection, the rules for that ipPort are
24+
// matched, in order. If one matches (currently HTTP Host, SNI, or
25+
// always), then the connection is handed to the target.
26+
//
27+
// The two predefined Target implementations are:
28+
//
29+
// 1) DialProxy, proxying to another address (use the To func to return a
30+
// DialProxy value),
31+
//
32+
// 2) TargetListener, making the matched connection available via a
33+
// net.Listener.Accept call.
34+
//
35+
// But Target is an interface, so you can also write your own.
36+
//
37+
// Note that tcpproxy does not do any TLS encryption or decryption. It
38+
// only (via DialProxy) copies bytes around. The SNI hostname in the TLS
39+
// header is unencrypted, for better or worse.
1740
package tcpproxy
1841

1942
import (
2043
"bufio"
21-
"bytes"
2244
"context"
2345
"errors"
2446
"io"
@@ -28,7 +50,10 @@ import (
2850
)
2951

3052
// Proxy is a proxy. Its zero value is a valid proxy that does
31-
// nothing. Call methods to add routes before calling Run.
53+
// nothing. Call methods to add routes before calling Start or Run.
54+
//
55+
// The order that routes are added in matters; each is matched in the order
56+
// registered.
3257
type Proxy struct {
3358
routes map[string][]route // ip:port => route
3459

@@ -156,11 +181,14 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
156181
br := bufio.NewReader(c)
157182
for _, route := range routes {
158183
if route.matcher.match(br) {
159-
buffered, _ := br.Peek(br.Buffered())
160-
route.target.HandleConn(changeReaderConn{
161-
r: io.MultiReader(bytes.NewReader(buffered), c),
162-
Conn: c,
163-
}, c)
184+
if n := br.Buffered(); n > 0 {
185+
peeked, _ := br.Peek(br.Buffered())
186+
c = &Conn{
187+
Peeked: peeked,
188+
Conn: c,
189+
}
190+
}
191+
route.target.HandleConn(c)
164192
return true
165193
}
166194
}
@@ -170,32 +198,48 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
170198
return false
171199
}
172200

173-
// changeReaderConn is a net.Conn wrapper with a separate reader function.
174-
type changeReaderConn struct {
175-
r io.Reader
201+
// Conn is an incoming connection that has had some bytes read from it
202+
// to determine how to route the connection. The Read method stitches
203+
// the peeked bytes and unread bytes back together.
204+
type Conn struct {
205+
// Peeked are the bytes that have been read from Conn for the
206+
// purposes of route matching, but have not yet been consumed
207+
// by Read calls. It set to nil by Read when fully consumed.
208+
Peeked []byte
209+
210+
// Conn is the underlying connection.
211+
// It can be type asserted against *net.TCPConn or other types
212+
// as needed. It should not be read from directly unless
213+
// Peeked is nil.
176214
net.Conn
177215
}
178216

179-
func (c changeReaderConn) Read(p []byte) (int, error) { return c.r.Read(p) }
217+
func (c *Conn) Read(p []byte) (n int, err error) {
218+
if len(c.Peeked) > 0 {
219+
n = copy(p, c.Peeked)
220+
c.Peeked = c.Peeked[n:]
221+
if len(c.Peeked) == 0 {
222+
c.Peeked = nil
223+
}
224+
return n, nil
225+
}
226+
return c.Conn.Read(p)
227+
}
180228

181229
// Target is what an incoming matched connection is sent to.
182230
type Target interface {
183231
// HandleConn is called when an incoming connection is
184232
// matched. After the call to HandleConn, the tcpproxy
185233
// package never touches the conn again. Implementations are
186-
// responsible for closing the conn when needed.
187-
//
188-
// The c Conn acts like a new conn, without any bytes consumed,
189-
// but it has an unexported concrete type and cannot be type
190-
// asserted to *net.TCPConn, etc.
234+
// responsible for closing the connection when needed.
191235
//
192-
// The rawConn represents the underlying connections (with
193-
// some bytes removed) and should only be used for type
194-
// assertions and setting deadlines, not reading.
195-
HandleConn(c net.Conn, rawConn net.Conn)
236+
// The concrete type of conn will be of type *Conn if any
237+
// bytes have been consumed for the purposes of route
238+
// matching.
239+
HandleConn(net.Conn)
196240
}
197241

198-
// To is shorthand way of writing &DialProxy{Addr: addr}.
242+
// To is shorthand way of writing &tlsproxy.DialProxy{Addr: addr}.
199243
func To(addr string) *DialProxy {
200244
return &DialProxy{Addr: addr}
201245
}
@@ -229,7 +273,16 @@ type DialProxy struct {
229273
OnDialError func(src net.Conn, dstDialErr error)
230274
}
231275

232-
func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) {
276+
// UnderlyingConn returns c.Conn if c of type *Conn,
277+
// otherwise it returns c.
278+
func UnderlyingConn(c net.Conn) net.Conn {
279+
if wrap, ok := c.(*Conn); ok {
280+
return wrap.Conn
281+
}
282+
return c
283+
}
284+
285+
func (dp *DialProxy) HandleConn(src net.Conn) {
233286
ctx := context.Background()
234287
var cancel context.CancelFunc
235288
if dp.DialTimeout >= 0 {
@@ -246,7 +299,7 @@ func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) {
246299
defer src.Close()
247300
defer dst.Close()
248301
if ka := dp.keepAlivePeriod(); ka > 0 {
249-
if c, ok := rawSrc.(*net.TCPConn); ok {
302+
if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
250303
c.SetKeepAlive(true)
251304
c.SetKeepAlivePeriod(ka)
252305
}

0 commit comments

Comments
 (0)