-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.go
176 lines (161 loc) · 5.34 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package main
import (
"crypto/tls"
"flag"
"fmt"
"log"
"net"
"os"
"os/signal"
"runtime"
"syscall"
"time"
conn "github.com/Snawoot/steady-tun/conn"
"github.com/Snawoot/steady-tun/dnscache"
clog "github.com/Snawoot/steady-tun/log"
"github.com/Snawoot/steady-tun/pool"
"github.com/Snawoot/steady-tun/server"
)
// version is the string printed when the program is run with -version.
// It defaults to "undefined"; presumably release builds override it at link
// time (e.g. -ldflags "-X main.version=...") — TODO confirm in build scripts.
var version = "undefined"
// perror writes msg to standard error, preceded by an empty line and
// followed by a newline. Write errors are ignored.
func perror(msg string) {
	fmt.Fprintf(os.Stderr, "\n%s\n", msg)
}
// arg_fail reports a command-line usage error and terminates the process.
// It prints msg and a "Usage:" header through perror (each preceded by a
// blank line), then the registered flag defaults, and exits with status 2.
func arg_fail(msg string) {
	for _, line := range [...]string{msg, "Usage:"} {
		perror(line)
	}
	flag.PrintDefaults()
	os.Exit(2)
}
// CLIArgs holds the parsed command-line configuration; parse_args fills it
// in. The trailing comment on each field names the flag that sets it.
type CLIArgs struct {
	// Destination (upstream) endpoint.
	host string // -dsthost
	port uint   // -dstport

	// Logging threshold: 10 debug, 20 info, 30 warning, 40 error, 50 critical.
	verbosity int // -verbosity

	// Local listening endpoint.
	bind_address string // -bind-address
	bind_port    uint   // -bind-port

	// Connection pool and dialing behaviour.
	pool_size uint          // -pool-size
	dialers   uint          // -dialers: concurrency limit for TLS connection attempts
	backoff   time.Duration // -backoff: delay between connection attempts
	ttl       time.Duration // -ttl: lifetime of an idle pool connection
	timeout   time.Duration // -timeout: server connect timeout

	// TLS client settings.
	cert            string // -cert: client certificate for TLS auth
	key             string // -key: key for the client certificate
	cafile          string // -cafile: CA bundle overriding the defaults
	hostname_check  bool   // -hostname-check
	tls_servername  string // -tls-servername: hostname expected in server cert
	tlsSessionCache bool   // -tls-session-cache
	tlsEnabled      bool   // -tls-enabled: when false, pool connections are plain TCP

	// DNS cache settings; main wraps the dialer with a cache only when
	// dnsCacheTTL > 0.
	dnsCacheTTL    time.Duration // -dns-cache-ttl
	dnsNegCacheTTL time.Duration // -dns-neg-cache-ttl

	showVersion bool // -version
}
// parse_args registers all command-line flags on the default flag set,
// parses os.Args and validates the result.
//
// Validation failures are reported through arg_fail, which prints the
// message plus usage and exits with status 2. When -version is given,
// validation is skipped so the version can be printed without supplying
// the otherwise-required -dsthost/-dstport flags.
func parse_args() CLIArgs {
	args := CLIArgs{}
	flag.StringVar(&args.host, "dsthost", "", "destination server hostname")
	flag.UintVar(&args.port, "dstport", 0, "destination server port")
	flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity "+
		"(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)")
	flag.StringVar(&args.bind_address, "bind-address", "127.0.0.1", "bind address")
	flag.UintVar(&args.bind_port, "bind-port", 57800, "bind port")
	flag.UintVar(&args.pool_size, "pool-size", 50, "connection pool size")
	flag.UintVar(&args.dialers, "dialers", uint(4*runtime.GOMAXPROCS(0)), "concurrency limit for TLS connection attempts")
	flag.DurationVar(&args.backoff, "backoff", 5*time.Second, "delay between connection attempts")
	// -ttl takes a Go duration string (e.g. "30s", "2m"); a bare number of
	// seconds is rejected by the flag package, so the help must not say
	// "in seconds".
	flag.DurationVar(&args.ttl, "ttl", 30*time.Second, "lifetime of idle pool connection")
	flag.DurationVar(&args.timeout, "timeout", 4*time.Second, "server connect timeout")
	flag.StringVar(&args.cert, "cert", "", "use certificate for client TLS auth")
	flag.StringVar(&args.key, "key", "", "key for TLS certificate")
	flag.StringVar(&args.cafile, "cafile", "", "override default CA certs by specified in file")
	flag.BoolVar(&args.hostname_check, "hostname-check", true, "check hostname in server cert subject")
	flag.StringVar(&args.tls_servername, "tls-servername", "", "specifies hostname to expect in server cert")
	flag.BoolVar(&args.tlsSessionCache, "tls-session-cache", true, "enable TLS session cache")
	flag.BoolVar(&args.showVersion, "version", false, "show program version and exit")
	flag.BoolVar(&args.tlsEnabled, "tls-enabled", true, "enable TLS client for pool connections")
	flag.DurationVar(&args.dnsCacheTTL, "dns-cache-ttl", 30*time.Second, "DNS cache TTL")
	flag.DurationVar(&args.dnsNegCacheTTL, "dns-neg-cache-ttl", 1*time.Second, "negative DNS cache TTL")
	flag.Parse()

	if args.showVersion {
		return args
	}
	if args.host == "" {
		arg_fail("Destination host argument is required!")
	}
	if args.port == 0 {
		arg_fail("Destination port argument is required!")
	}
	// main narrows both ports to uint16, so anything above 65535 is invalid.
	if args.port >= 65536 {
		arg_fail("Bad destination port!")
	}
	if args.bind_port >= 65536 {
		arg_fail("Bad bind port!")
	}
	if args.dialers < 1 {
		arg_fail("dialers parameter should be not less than 1")
	}
	return args
}
// main wires together logging, the upstream connection factory (TLS or
// plain TCP), the connection pool and the local TCP listener, then blocks
// until SIGINT or SIGTERM arrives. Deferred calls run in reverse order on
// return: the listener is stopped, then the pool, then the log writer is
// closed.
func main() {
	args := parse_args()
	if args.showVersion {
		fmt.Println(version)
		return
	}

	// Register for termination signals before starting anything. A signal
	// that arrives during startup is then buffered in sigs and leads to an
	// orderly shutdown (running the deferred Stop/Close calls) as soon as
	// startup completes, instead of the default disposition killing the
	// process with no cleanup.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

	logWriter := clog.NewLogWriter(os.Stderr)
	defer logWriter.Close()

	mainLogger := clog.NewCondLogger(log.New(logWriter, "MAIN : ", log.LstdFlags|log.Lshortfile),
		args.verbosity)
	listenerLogger := clog.NewCondLogger(log.New(logWriter, "LISTENER: ", log.LstdFlags|log.Lshortfile),
		args.verbosity)
	handlerLogger := clog.NewCondLogger(log.New(logWriter, "HANDLER : ", log.LstdFlags|log.Lshortfile),
		args.verbosity)
	connLogger := clog.NewCondLogger(log.New(logWriter, "CONN : ", log.LstdFlags|log.Lshortfile),
		args.verbosity)
	poolLogger := clog.NewCondLogger(log.New(logWriter, "POOL : ", log.LstdFlags|log.Lshortfile),
		args.verbosity)

	var (
		dialer      conn.ContextDialer
		connfactory conn.Factory
		err         error
	)
	dialer = (&net.Dialer{
		Timeout: args.timeout,
	}).DialContext
	if args.dnsCacheTTL > 0 {
		// NOTE(review): 128 looks like the DNS cache capacity — confirm
		// against dnscache.WrapDialer's signature.
		dialer = dnscache.WrapDialer(dialer, net.DefaultResolver, 128, args.dnsCacheTTL, args.dnsNegCacheTTL, args.timeout)
	}

	if args.tlsEnabled {
		var sessionCache tls.ClientSessionCache
		if args.tlsSessionCache {
			// LRU capacity is sized at twice the configured pool size.
			sessionCache = tls.NewLRUClientSessionCache(2 * int(args.pool_size))
		}
		connfactory, err = conn.NewTLSConnFactory(args.host,
			uint16(args.port),
			dialer,
			args.cert,
			args.key,
			args.cafile,
			args.hostname_check,
			args.tls_servername,
			args.dialers,
			sessionCache,
			connLogger)
		if err != nil {
			panic(fmt.Errorf("creating TLS connection factory: %w", err))
		}
	} else {
		connfactory = conn.NewPlainConnFactory(args.host, uint16(args.port), dialer)
	}

	connPool := pool.NewConnPool(args.pool_size, args.ttl, args.backoff, connfactory.DialContext, poolLogger)
	connPool.Start()
	defer connPool.Stop()

	listener := server.NewTCPListener(args.bind_address,
		uint16(args.bind_port),
		server.NewConnHandler(connPool, handlerLogger).Handle,
		listenerLogger)
	// Reuse the outer err rather than shadowing it with :=.
	if err = listener.Start(); err != nil {
		panic(fmt.Errorf("starting listener on %s: %w",
			net.JoinHostPort(args.bind_address, fmt.Sprint(args.bind_port)), err))
	}
	defer listener.Stop()
	mainLogger.Info("Listener started.")

	<-sigs
	mainLogger.Info("Shutting down...")
}