Merge pull request #11 from Snawoot/refactoring
Factor out connection wrapping
Snawoot authored Oct 8, 2024
2 parents 79ad2bc + 6bd308f commit b1e14b4
Showing 6 changed files with 37 additions and 94 deletions.
74 changes: 0 additions & 74 deletions conn/connection.go

This file was deleted.

4 changes: 4 additions & 0 deletions conn/dialer.go
@@ -6,3 +6,7 @@ import (
 )
 
 type ContextDialer = func(ctx context.Context, network, address string) (net.Conn, error)
+
+type Factory interface {
+	DialContext(ctx context.Context) (net.Conn, error)
+}
8 changes: 4 additions & 4 deletions conn/plainfactory.go
@@ -12,19 +12,19 @@ type PlainConnFactory struct {
 	dialer ContextDialer
 }
 
+var _ Factory = &PlainConnFactory{}
+
 func NewPlainConnFactory(host string, port uint16, dialer ContextDialer) *PlainConnFactory {
 	return &PlainConnFactory{
 		addr: net.JoinHostPort(host, strconv.Itoa(int(port))),
 		dialer: dialer,
 	}
 }
 
-func (cf *PlainConnFactory) DialContext(ctx context.Context) (WrappedConn, error) {
+func (cf *PlainConnFactory) DialContext(ctx context.Context) (net.Conn, error) {
 	conn, err := cf.dialer(ctx, "tcp", cf.addr)
 	if err != nil {
 		return nil, fmt.Errorf("cf.dialer.DialContext(ctx, \"tcp\", %q) failed: %v", cf.addr, err)
 	}
-	return &wrappedConn{
-		conn: conn,
-	}, nil
+	return conn, nil
 }
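The added var _ Factory = &PlainConnFactory{} line (and its twin in tlsfactory.go below) is Go's usual compile-time interface assertion: assigning a value to the blank identifier at package scope makes the build fail if the type ever stops implementing the interface, and it costs nothing at runtime. A minimal sketch of the idiom, using hypothetical names (Greeter, englishGreeter) that are not part of this repository:

package main

import "fmt"

// Greeter is a hypothetical interface used only to illustrate the idiom.
type Greeter interface {
	Greet() string
}

type englishGreeter struct{}

func (englishGreeter) Greet() string { return "hello" }

// Compile-time assertion: the build breaks here if englishGreeter ever
// stops satisfying Greeter; the blank identifier discards the value.
var _ Greeter = englishGreeter{}

func main() {
	fmt.Println(englishGreeter{}.Greet())
}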
8 changes: 4 additions & 4 deletions conn/tlsfactory.go
@@ -22,6 +22,8 @@ type TLSConnFactory struct {
 	sem *semaphore.Weighted
 }
 
+var _ Factory = &TLSConnFactory{}
+
 func NewTLSConnFactory(host string, port uint16, dialer ContextDialer,
 	certfile, keyfile string, cafile string, hostname_check bool,
 	tls_servername string, dialers uint, sessionCache tls.ClientSessionCache, logger *clog.CondLogger) (*TLSConnFactory, error) {
@@ -92,7 +94,7 @@ func NewTLSConnFactory(host string, port uint16, dialer ContextDialer,
 	}, nil
 }
 
-func (cf *TLSConnFactory) DialContext(ctx context.Context) (WrappedConn, error) {
+func (cf *TLSConnFactory) DialContext(ctx context.Context) (net.Conn, error) {
 	if cf.sem.Acquire(ctx, 1) != nil {
 		return nil, errors.New("Context was cancelled")
 	}
@@ -107,7 +109,5 @@ func (cf *TLSConnFactory) DialContext(ctx context.Context) (WrappedConn, error)
 		netConn.Close()
 		return nil, fmt.Errorf("tlsConn.HandshakeContext(ctx) failed: %v", err)
 	}
-	return &wrappedConn{
-		conn: tlsConn,
-	}, nil
+	return tlsConn, nil
 }
4 changes: 2 additions & 2 deletions main.go
@@ -121,7 +121,7 @@ func main() {
 
 	var (
 		dialer conn.ContextDialer
-		connfactory pool.ConnFactory
+		connfactory conn.Factory
 		err error
 	)
 	dialer = (&net.Dialer{
@@ -154,7 +154,7 @@ func main() {
 	} else {
 		connfactory = conn.NewPlainConnFactory(args.host, uint16(args.port), dialer)
 	}
-	connPool := pool.NewConnPool(args.pool_size, args.ttl, args.backoff, connfactory, poolLogger)
+	connPool := pool.NewConnPool(args.pool_size, args.ttl, args.backoff, connfactory.DialContext, poolLogger)
 	connPool.Start()
 	defer connPool.Stop()
 
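Because pool.ConnFactory is now a plain function type (see connpool.go below), main.go hands the pool the bound method value connfactory.DialContext rather than the factory object itself: in Go, a method value whose signature matches a function type can be used wherever that function type is expected. A small sketch of that mechanism, with hypothetical names (dialFunc, fixedDialer) not taken from this repository:

package main

import (
	"context"
	"fmt"
	"net"
)

// dialFunc has the same shape as pool.ConnFactory after this change.
type dialFunc = func(ctx context.Context) (net.Conn, error)

// fixedDialer is a hypothetical stand-in for a factory such as *conn.PlainConnFactory.
type fixedDialer struct{ addr string }

func (d *fixedDialer) DialContext(ctx context.Context) (net.Conn, error) {
	var nd net.Dialer
	return nd.DialContext(ctx, "tcp", d.addr)
}

func main() {
	d := &fixedDialer{addr: "127.0.0.1:80"}
	// The bound method value d.DialContext satisfies dialFunc directly,
	// just as connfactory.DialContext is passed to pool.NewConnPool above.
	var dial dialFunc = d.DialContext
	fmt.Printf("%T\n", dial)
}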
33 changes: 23 additions & 10 deletions pool/connpool.go
@@ -7,19 +7,16 @@ import (
 	"time"
 
 	"github.com/Snawoot/steady-tun/clock"
-	"github.com/Snawoot/steady-tun/conn"
 	clog "github.com/Snawoot/steady-tun/log"
 	"github.com/Snawoot/steady-tun/queue"
 )
 
-type ConnFactory interface {
-	DialContext(context.Context) (conn.WrappedConn, error)
-}
+type ConnFactory = func(context.Context) (net.Conn, error)
 
 type ConnPool struct {
 	size uint
 	ttl, backoff time.Duration
-	connfactory ConnFactory
+	connFactory ConnFactory
 	waiters, prepared *queue.RAQueue
 	qmux sync.Mutex
 	logger *clog.CondLogger
@@ -35,13 +32,13 @@ type watchedConn struct {
 }
 
 func NewConnPool(size uint, ttl, backoff time.Duration,
-	connfactory ConnFactory, logger *clog.CondLogger) *ConnPool {
+	connFactory ConnFactory, logger *clog.CondLogger) *ConnPool {
 	ctx, cancel := context.WithCancel(context.Background())
 	return &ConnPool{
 		size: size,
 		ttl: ttl,
 		backoff: backoff,
-		connfactory: connfactory,
+		connFactory: connFactory,
 		waiters: queue.NewRAQueue(),
 		prepared: queue.NewRAQueue(),
 		logger: logger,
@@ -87,7 +84,7 @@ func (p *ConnPool) worker() {
 			return
 		default:
 		}
-		conn, err := p.connfactory.DialContext(p.ctx)
+		conn, err := p.connFactory(p.ctx)
 		if err != nil {
 			select {
 			case <-p.ctx.Done():
@@ -113,8 +110,8 @@ func (p *ConnPool) worker() {
 		readctx, readcancel := context.WithCancel(p.ctx)
 		readdone := make(chan struct{}, 1)
 		go func() {
-			conn.ReadContext(readctx, dummybuf)
-			readdone <- struct{}{}
+			connReadContext(readctx, conn, dummybuf)
+			close(readdone)
 		}()
 		watched := &watchedConn{conn, readcancel, readdone}
 		select {
@@ -174,3 +171,19 @@ func (p *ConnPool) Stop() {
 	p.cancel()
 	p.shutdown.Wait()
 }
+
+func connReadContext(ctx context.Context, conn net.Conn, p []byte) (n int, err error) {
+	readDone := make(chan struct{})
+	go func() {
+		defer close(readDone)
+		n, err = conn.Read(p)
+	}()
+	select {
+	case <-ctx.Done():
+		conn.SetReadDeadline(time.Unix(0, 0))
+		<-readDone
+		conn.SetReadDeadline(time.Time{})
+	case <-readDone:
+	}
+	return
+}
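The new connReadContext helper takes over from the deleted wrappedConn.ReadContext: it runs the blocking Read in a goroutine and, if the context finishes first, forces that Read to return by setting an already-expired read deadline, then clears the deadline once the read has unblocked. Below is a minimal, self-contained sketch of the same deadline trick, assuming a connection type that supports read deadlines (net.Pipe does); the name readWithContext and the timeout value are illustrative only:

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

// readWithContext mirrors the pattern used by connReadContext above.
func readWithContext(ctx context.Context, c net.Conn, p []byte) (int, error) {
	var n int
	var err error
	done := make(chan struct{})
	go func() {
		defer close(done)
		n, err = c.Read(p)
	}()
	select {
	case <-ctx.Done():
		// An already-expired deadline makes the pending Read return at once.
		c.SetReadDeadline(time.Unix(0, 0))
		<-done
		c.SetReadDeadline(time.Time{}) // restore "no deadline"
	case <-done:
	}
	return n, err
}

func main() {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	// Nothing ever writes to server, so this read only returns because
	// the context expires and the deadline trick unblocks it.
	buf := make([]byte, 1)
	n, err := readWithContext(ctx, client, buf)
	fmt.Println(n, err)
}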
