diff --git a/etcdmain/gateway.go b/etcdmain/gateway.go index 1a72bddcf082..5487414ebd58 100644 --- a/etcdmain/gateway.go +++ b/etcdmain/gateway.go @@ -91,17 +91,28 @@ func stripSchema(eps []string) []string { return endpoints } -func startGateway(cmd *cobra.Command, args []string) { - endpoints := gatewayEndpoints - if eps := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery); len(eps) != 0 { - endpoints = eps +func startGateway(cmd *cobra.Command, args []string) { + srvs := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery) + if len(srvs.Endpoints) == 0 { + // no endpoints discovered, fall back to provided endpoints + srvs.Endpoints = gatewayEndpoints } - // Strip the schema from the endpoints because we start just a TCP proxy - endpoints = stripSchema(endpoints) + srvs.Endpoints = stripSchema(srvs.Endpoints) + if len(srvs.SRVs) == 0 { + for _, ep := range srvs.Endpoints { + h, p, err := net.SplitHostPort(ep) + if err != nil { + plog.Fatalf("error parsing endpoint %q", ep) + } + var port uint16 + fmt.Sscanf(p, "%d", &port) + srvs.SRVs = append(srvs.SRVs, &net.SRV{Target: h, Port: port}) + } + } - if len(endpoints) == 0 { + if len(srvs.Endpoints) == 0 { plog.Fatalf("no endpoints found") } @@ -113,7 +124,7 @@ func startGateway(cmd *cobra.Command, args []string) { tp := tcpproxy.TCPProxy{ Listener: l, - Endpoints: endpoints, + Endpoints: srvs.SRVs, MonitorInterval: getewayRetryDelay, } diff --git a/etcdmain/grpc_proxy.go b/etcdmain/grpc_proxy.go index 1f701ba12979..ae5af8bbf847 100644 --- a/etcdmain/grpc_proxy.go +++ b/etcdmain/grpc_proxy.go @@ -106,8 +106,9 @@ func startGRPCProxy(cmd *cobra.Command, args []string) { os.Exit(1) } - if eps := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery); len(eps) != 0 { - grpcProxyEndpoints = eps + srvs := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery) + if len(srvs.Endpoints) != 0 { + grpcProxyEndpoints = srvs.Endpoints } l, err := net.Listen("tcp", grpcProxyListenAddr) diff --git a/etcdmain/util.go b/etcdmain/util.go index 5de07275b5bf..9657271d53a5 100644 --- a/etcdmain/util.go +++ b/etcdmain/util.go @@ -22,19 +22,19 @@ import ( "github.com/coreos/etcd/pkg/transport" ) -func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string) { +func discoverEndpoints(dns string, ca string, insecure bool) (s srv.SRVClients) { if dns == "" { - return nil + return s } srvs, err := srv.GetClient("etcd-client", dns) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } - endpoints = srvs.Endpoints + endpoints := srvs.Endpoints plog.Infof("discovered the cluster %s from %s", endpoints, dns) if insecure { - return endpoints + return *srvs } // confirm TLS connections are good tlsInfo := transport.TLSInfo{ @@ -47,5 +47,19 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string plog.Warningf("%v", err) } plog.Infof("using discovered endpoints %v", endpoints) - return endpoints + + // map endpoints back to SRVClients struct with SRV data + eps := make(map[string]struct{}) + for _, ep := range endpoints { + eps[ep] = struct{}{} + } + for i := range srvs.Endpoints { + if _, ok := eps[srvs.Endpoints[i]]; !ok { + continue + } + s.Endpoints = append(s.Endpoints, srvs.Endpoints[i]) + s.SRVs = append(s.SRVs, srvs.SRVs[i]) + } + + return s } diff --git a/proxy/tcpproxy/userspace.go b/proxy/tcpproxy/userspace.go index 5de017a70de0..01e40a24c5e9 100644 --- a/proxy/tcpproxy/userspace.go +++ b/proxy/tcpproxy/userspace.go @@ -15,7 +15,9 @@ package tcpproxy import ( + "fmt" "io" + "math/rand" "net" "sync" "time" @@ -29,6 +31,7 @@ var ( type remote struct { mu sync.Mutex + srv *net.SRV addr string inactive bool } @@ -59,14 +62,14 @@ func (r *remote) isActive() bool { type TCPProxy struct { Listener net.Listener - Endpoints []string + Endpoints []*net.SRV MonitorInterval time.Duration donec chan struct{} - mu sync.Mutex // guards the following fields - remotes []*remote - nextRemote int + mu sync.Mutex // guards the following fields + remotes []*remote + pickCount int // for round robin } func (tp *TCPProxy) Run() error { @@ -74,11 +77,12 @@ func (tp *TCPProxy) Run() error { if tp.MonitorInterval == 0 { tp.MonitorInterval = 5 * time.Minute } - for _, ep := range tp.Endpoints { - tp.remotes = append(tp.remotes, &remote{addr: ep}) + for _, srv := range tp.Endpoints { + addr := fmt.Sprintf("%s:%d", srv.Target, srv.Port) + tp.remotes = append(tp.remotes, &remote{srv: srv, addr: addr}) } - plog.Printf("ready to proxy client requests to %v", tp.Endpoints) + plog.Printf("ready to proxy client requests to %+v", tp.Endpoints) go tp.runMonitor() for { in, err := tp.Listener.Accept() @@ -90,10 +94,61 @@ func (tp *TCPProxy) Run() error { } } -func (tp *TCPProxy) numRemotes() int { - tp.mu.Lock() - defer tp.mu.Unlock() - return len(tp.remotes) +func (tp *TCPProxy) pick() *remote { + var weighted []*remote + var unweighted []*remote + + bestPr := uint16(65535) + w := 0 + // find best priority class + for _, r := range tp.remotes { + switch { + case !r.isActive(): + case r.srv.Priority < bestPr: + bestPr = r.srv.Priority + w = 0 + weighted, unweighted = nil, nil + unweighted = []*remote{r} + fallthrough + case r.srv.Priority == bestPr: + if r.srv.Weight > 0 { + weighted = append(weighted, r) + w += int(r.srv.Weight) + } else { + unweighted = append(unweighted, r) + } + } + } + if weighted != nil { + if len(unweighted) > 0 && rand.Intn(100) == 1 { + // In the presence of records containing weights greater + // than 0, records with weight 0 should have a very small + // chance of being selected. + r := unweighted[tp.pickCount%len(unweighted)] + tp.pickCount++ + return r + } + // choose a uniform random number between 0 and the sum computed + // (inclusive), and select the RR whose running sum value is the + // first in the selected order + choose := rand.Intn(w) + for i := 0; i < len(weighted); i++ { + choose -= int(weighted[i].srv.Weight) + if choose <= 0 { + return weighted[i] + } + } + } + if unweighted != nil { + for i := 0; i < len(tp.remotes); i++ { + picked := tp.remotes[tp.pickCount%len(tp.remotes)] + tp.pickCount++ + if picked.isActive() { + return picked + } + } + } + return nil } func (tp *TCPProxy) serve(in net.Conn) { @@ -102,10 +157,12 @@ func (tp *TCPProxy) serve(in net.Conn) { out net.Conn ) - for i := 0; i < tp.numRemotes(); i++ { + for { + tp.mu.Lock() remote := tp.pick() - if !remote.isActive() { - continue + tp.mu.Unlock() + if remote == nil { + break } // TODO: add timeout out, err = net.Dial("tcp", remote.addr) @@ -132,16 +189,6 @@ func (tp *TCPProxy) serve(in net.Conn) { in.Close() } -// pick picks a remote in round-robin fashion -func (tp *TCPProxy) pick() *remote { - tp.mu.Lock() - defer tp.mu.Unlock() - - picked := tp.remotes[tp.nextRemote] - tp.nextRemote = (tp.nextRemote + 1) % len(tp.remotes) - return picked -} - func (tp *TCPProxy) runMonitor() { for { select { diff --git a/proxy/tcpproxy/userspace_test.go b/proxy/tcpproxy/userspace_test.go index e239c19c6624..bf65f570c214 100644 --- a/proxy/tcpproxy/userspace_test.go +++ b/proxy/tcpproxy/userspace_test.go @@ -42,9 +42,11 @@ func TestUserspaceProxy(t *testing.T) { t.Fatal(err) } + var port uint16 + fmt.Sscanf(u.Port(), "%d", &port) p := TCPProxy{ Listener: l, - Endpoints: []string{u.Host}, + Endpoints: []*net.SRV{{Target: u.Hostname(), Port: port}}, } go p.Run() defer p.Stop()