Skip to content

Commit

Permalink
Added PriorityDialer to choose best server based on estimated RTT.
Browse files Browse the repository at this point in the history
  • Loading branch information
riobard committed Oct 1, 2017
1 parent 2c8602a commit cbc8403
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 100 deletions.
118 changes: 118 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package main

import (
"net"
"sync/atomic"
"time"

"github.com/riobard/go-shadowsocks2/core"
"github.com/riobard/go-shadowsocks2/socks"
)

type Dialer interface {
Dial(network, address string) (net.Conn, error)
}

type dialer struct {
dialTime int64 // time to dial in nanoseconds (exponetially smoothed)
lastUpdated atomic.Value // of time.Time
server string
shadow func(net.Conn) net.Conn
}

func NewDialer(u string) (*dialer, error) {
addr, cipher, password, err := parseURL(u)
if err != nil {
return nil, err
}
ciph, err := core.PickCipher(cipher, nil, password)
if err != nil {
return nil, err
}
d := &dialer{server: addr, shadow: ciph.StreamConn}
d.lastUpdated.Store(time.Time{})
return d, nil
}

func (d *dialer) Dial(network, address string) (net.Conn, error) {
c, err := d.dial()
if err != nil {
return c, err
}
c.(*net.TCPConn).SetKeepAlive(true)
c = d.shadow(c)
_, err = c.Write(socks.ParseAddr(address))
return c, err
}

func (d *dialer) dial() (net.Conn, error) {
const timeout = 2 * time.Second
const wt = 4

t0 := time.Now()
c, err := net.DialTimeout("tcp", d.server, timeout)
td := time.Since(t0)
if err != nil {
td = timeout // penality
}

new := td.Nanoseconds()
if old := atomic.LoadInt64(&d.dialTime); old > 0 {
new = (wt*old + new) / (wt + 1) // Exponentially Weighted Moving Average
}
atomic.StoreInt64(&d.dialTime, new)
logf("probe %s [%d ms] err=%v", d.server, new/1e6, err)
d.lastUpdated.Store(time.Now())
return c, err
}

// Actively measure average dial time
func (d *dialer) probe() {
const interval = 10 * time.Second
for {
age := time.Since(d.lastUpdated.Load().(time.Time))
if age > interval {
if c, err := d.dial(); err == nil {
c.Close()
}
} else {
time.Sleep(interval - age)
}
}
}

type priorityDialer struct {
dialers []*dialer
}

func NewPriorityDialer(u ...string) (*priorityDialer, error) {
var dialers []*dialer

for _, each := range u {
d, err := NewDialer(each)
if err != nil {
return nil, err
}
dialers = append(dialers, d)
}

for _, d := range dialers {
go d.probe()
}

return &priorityDialer{dialers}, nil
}

const maxInt64 = int64(1<<63 - 1)

func (d *priorityDialer) Dial(network, address string) (net.Conn, error) {
tMin := maxInt64
var dMin *dialer
for _, d := range d.dialers {
if t := atomic.LoadInt64(&d.dialTime); t < tMin {
dMin, tMin = d, t
}
}
logf("best server %s [%d ms]", dMin.server, tMin/1e6)
return dMin.Dial(network, address)
}
140 changes: 68 additions & 72 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package main

import (
"crypto/rand"
"encoding/base64"
"flag"
"fmt"
"io"
"log"
"net/url"
"os"
Expand All @@ -31,120 +27,92 @@ func logf(f string, v ...interface{}) {
func main() {

var flags struct {
Client string
Server string
Cipher string
Key string
Password string
Keygen int
Client SpaceSeparatedList
Server SpaceSeparatedList
TCPTun PairList
UDPTun PairList
Socks string
RedirTCP string
RedirTCP6 string
TCPTun string
UDPTun string
}

listCiphers := flag.Bool("cipher", false, "List supported ciphers")
flag.BoolVar(&config.Verbose, "verbose", false, "verbose mode")
flag.StringVar(&flags.Cipher, "cipher", "AEAD_CHACHA20_POLY1305", "available ciphers: "+strings.Join(core.ListCipher(), " "))
flag.StringVar(&flags.Key, "key", "", "base64url-encoded key (derive from password if empty)")
flag.IntVar(&flags.Keygen, "keygen", 0, "generate a base64url-encoded random key of given length in byte")
flag.StringVar(&flags.Password, "password", "", "password")
flag.StringVar(&flags.Server, "s", "", "server listen address or url")
flag.StringVar(&flags.Client, "c", "", "client connect address or url")
flag.Var(&flags.Server, "s", "server listen url")
flag.Var(&flags.Client, "c", "client connect url")
flag.Var(&flags.TCPTun, "tcptun", "(client-only) TCP tunnel (laddr1=raddr1,laddr2=raddr2,...)")
flag.Var(&flags.UDPTun, "udptun", "(client-only) UDP tunnel (laddr1=raddr1,laddr2=raddr2,...)")
flag.StringVar(&flags.Socks, "socks", "", "(client-only) SOCKS listen address")
flag.StringVar(&flags.RedirTCP, "redir", "", "(client-only) redirect TCP from this address")
flag.StringVar(&flags.RedirTCP6, "redir6", "", "(client-only) redirect TCP IPv6 from this address")
flag.StringVar(&flags.TCPTun, "tcptun", "", "(client-only) TCP tunnel (laddr1=raddr1,laddr2=raddr2,...)")
flag.StringVar(&flags.UDPTun, "udptun", "", "(client-only) UDP tunnel (laddr1=raddr1,laddr2=raddr2,...)")
flag.DurationVar(&config.UDPTimeout, "udptimeout", 5*time.Minute, "UDP tunnel timeout")
flag.DurationVar(&config.UDPTimeout, "udptimeout", 120*time.Second, "UDP tunnel timeout")
flag.Parse()

if flags.Keygen > 0 {
key := make([]byte, flags.Keygen)
io.ReadFull(rand.Reader, key)
fmt.Println(base64.URLEncoding.EncodeToString(key))
if *listCiphers {
println(strings.Join(core.ListCipher(), " "))
return
}

if flags.Client == "" && flags.Server == "" {
if len(flags.Client) == 0 && len(flags.Server) == 0 {
flag.Usage()
return
}

var key []byte
if flags.Key != "" {
k, err := base64.URLEncoding.DecodeString(flags.Key)
if err != nil {
log.Fatal(err)
}
key = k
}

if flags.Client != "" { // client mode
addr := flags.Client
cipher := flags.Cipher
password := flags.Password
var err error
if len(flags.Client) > 0 { // client mode
if len(flags.UDPTun) > 0 { // use first server for UDP
addr, cipher, password, err := parseURL(flags.Client[0])
if err != nil {
log.Fatal(err)
}

if strings.HasPrefix(addr, "ss://") {
addr, cipher, password, err = parseURL(addr)
ciph, err := core.PickCipher(cipher, nil, password)
if err != nil {
log.Fatal(err)
}
for _, p := range flags.UDPTun {
go udpLocal(p[0], addr, p[1], ciph.PacketConn)
}
}

ciph, err := core.PickCipher(cipher, key, password)
d, err := NewPriorityDialer(flags.Client...)
if err != nil {
log.Fatal(err)
log.Fatalf("failed to create dialer: %v", err)
}

if flags.UDPTun != "" {
for _, tun := range strings.Split(flags.UDPTun, ",") {
p := strings.Split(tun, "=")
go udpLocal(p[0], addr, p[1], ciph.PacketConn)
}
}

if flags.TCPTun != "" {
for _, tun := range strings.Split(flags.TCPTun, ",") {
p := strings.Split(tun, "=")
go tcpTun(p[0], addr, p[1], ciph.StreamConn)
if len(flags.TCPTun) > 0 {
for _, p := range flags.TCPTun {
go tcpTun(p[0], p[1], d)
}
}

if flags.Socks != "" {
go socksLocal(flags.Socks, addr, ciph.StreamConn)
go socksLocal(flags.Socks, d)
}

if flags.RedirTCP != "" {
go redirLocal(flags.RedirTCP, addr, ciph.StreamConn)
go redirLocal(flags.RedirTCP, d)
}

if flags.RedirTCP6 != "" {
go redir6Local(flags.RedirTCP6, addr, ciph.StreamConn)
go redir6Local(flags.RedirTCP6, d)
}
}

if flags.Server != "" { // server mode
addr := flags.Server
cipher := flags.Cipher
password := flags.Password
var err error
if len(flags.Server) > 0 { // server mode
for _, each := range flags.Server {
addr, cipher, password, err := parseURL(each)
if err != nil {
log.Fatal(err)
}

if strings.HasPrefix(addr, "ss://") {
addr, cipher, password, err = parseURL(addr)
ciph, err := core.PickCipher(cipher, nil, password)
if err != nil {
log.Fatal(err)
}
}

ciph, err := core.PickCipher(cipher, key, password)
if err != nil {
log.Fatal(err)
go udpRemote(addr, ciph.PacketConn)
go tcpRemote(addr, ciph.StreamConn)
}

go udpRemote(addr, ciph.PacketConn)
go tcpRemote(addr, ciph.StreamConn)
}

sigCh := make(chan os.Signal, 1)
Expand All @@ -165,3 +133,31 @@ func parseURL(s string) (addr, cipher, password string, err error) {
}
return
}

type PairList [][2]string // key1=val1,key2=val2,...

func (l PairList) String() string {
s := make([]string, len(l))
for i, pair := range l {
s[i] = pair[0] + "=" + pair[1]
}
return strings.Join(s, ",")
}
func (l *PairList) Set(s string) error {
for _, item := range strings.Split(s, ",") {
pair := strings.Split(item, "=")
if len(pair) != 2 {
return nil
}
*l = append(*l, [2]string{pair[0], pair[1]})
}
return nil
}

type SpaceSeparatedList []string

func (l SpaceSeparatedList) String() string { return strings.Join(l, " ") }
func (l *SpaceSeparatedList) Set(s string) error {
*l = strings.Split(s, " ")
return nil
}
28 changes: 10 additions & 18 deletions tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,24 @@ import (
)

// Create a SOCKS server listening on addr and proxy to server.
func socksLocal(addr, server string, shadow func(net.Conn) net.Conn) {
logf("SOCKS proxy %s <-> %s", addr, server)
tcpLocal(addr, server, shadow, func(c net.Conn) (socks.Addr, error) { return socks.Handshake(c) })
func socksLocal(addr string, d Dialer) {
logf("SOCKS proxy %s", addr)
tcpLocal(addr, d, func(c net.Conn) (socks.Addr, error) { return socks.Handshake(c) })
}

// Create a TCP tunnel from addr to target via server.
func tcpTun(addr, server, target string, shadow func(net.Conn) net.Conn) {
func tcpTun(addr, target string, d Dialer) {
tgt := socks.ParseAddr(target)
if tgt == nil {
logf("invalid target address %q", target)
return
}
logf("TCP tunnel %s <-> %s <-> %s", addr, server, target)
tcpLocal(addr, server, shadow, func(net.Conn) (socks.Addr, error) { return tgt, nil })
logf("TCP tunnel %s <-> %s", addr, target)
tcpLocal(addr, d, func(net.Conn) (socks.Addr, error) { return tgt, nil })
}

// Listen on addr and proxy to server to reach target from getAddr.
func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func(net.Conn) (socks.Addr, error)) {
func tcpLocal(addr string, d Dialer, getAddr func(net.Conn) (socks.Addr, error)) {
l, err := net.Listen("tcp", addr)
if err != nil {
logf("failed to listen on %s: %v", addr, err)
Expand All @@ -51,21 +51,13 @@ func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func(
return
}

rc, err := net.Dial("tcp", server)
rc, err := d.Dial("tcp", tgt.String())
if err != nil {
logf("failed to connect to server %v: %v", server, err)
return
}
defer rc.Close()
rc.(*net.TCPConn).SetKeepAlive(true)
rc = shadow(rc)

if _, err = rc.Write(tgt); err != nil {
logf("failed to send target address: %v", err)
logf("failed to connect: %v", err)
return
}

logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, tgt)
logf("proxy %s <--[%s]--> %s", c.RemoteAddr(), rc.RemoteAddr(), tgt)
_, _, err = relay(rc, c)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
Expand Down
Loading

0 comments on commit cbc8403

Please sign in to comment.