Skip to content

Commit

Permalink
clientv3: Fix endpoint resolver to create a new resolver for each grp…
Browse files Browse the repository at this point in the history
…c client connection
  • Loading branch information
jpbetz authored and gyuho committed Jun 15, 2018
1 parent 9304d1a commit 8569b9c
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 126 deletions.
42 changes: 25 additions & 17 deletions clientv3/balancer/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -58,14 +57,17 @@ func TestRoundRobinBalancedResolvableNoFailover(t *testing.T) {
}
defer ms.Stop()

var resolvedAddrs []resolver.Address
var eps []string
for _, svr := range ms.Servers {
resolvedAddrs = append(resolvedAddrs, svr.ResolverAddress())
eps = append(eps, svr.ResolverAddress().Addr)
}

rsv := endpoint.EndpointResolver("nofailover")
rsv, err := endpoint.NewResolverGroup("nofailover")
if err != nil {
t.Fatal(err)
}
defer rsv.Close()
rsv.InitialAddrs(resolvedAddrs)
rsv.SetEndpoints(eps)

name := genName()
cfg := Config{
Expand Down Expand Up @@ -121,14 +123,17 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
t.Fatalf("failed to start mock servers: %s", err)
}
defer ms.Stop()
var resolvedAddrs []resolver.Address
var eps []string
for _, svr := range ms.Servers {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address})
eps = append(eps, svr.ResolverAddress().Addr)
}

rsv := endpoint.EndpointResolver("serverfail")
rsv, err := endpoint.NewResolverGroup("serverfail")
if err != nil {
t.Fatal(err)
}
defer rsv.Close()
rsv.InitialAddrs(resolvedAddrs)
rsv.SetEndpoints(eps)

name := genName()
cfg := Config{
Expand Down Expand Up @@ -158,7 +163,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
ms.StopAt(0)
available := make(map[string]struct{})
for i := 1; i < serverCount; i++ {
available[resolvedAddrs[i].Addr] = struct{}{}
available[eps[i]] = struct{}{}
}

reqN := 10
Expand All @@ -169,8 +174,8 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
continue
}
if prev == "" { // first failover
if resolvedAddrs[0].Addr == picked {
t.Fatalf("expected failover from %q, picked %q", resolvedAddrs[0].Addr, picked)
if eps[0] == picked {
t.Fatalf("expected failover from %q, picked %q", eps[0], picked)
}
prev = picked
continue
Expand All @@ -194,7 +199,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
time.Sleep(time.Second)

prev, switches = "", 0
recoveredAddr, recovered := resolvedAddrs[0].Addr, 0
recoveredAddr, recovered := eps[0], 0
available[recoveredAddr] = struct{}{}

for i := 0; i < 2*reqN; i++ {
Expand Down Expand Up @@ -234,15 +239,18 @@ func TestRoundRobinBalancedResolvableFailoverFromRequestFail(t *testing.T) {
t.Fatalf("failed to start mock servers: %s", err)
}
defer ms.Stop()
var resolvedAddrs []resolver.Address
var eps []string
available := make(map[string]struct{})
for _, svr := range ms.Servers {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address})
eps = append(eps, svr.ResolverAddress().Addr)
available[svr.Address] = struct{}{}
}
rsv := endpoint.EndpointResolver("requestfail")
rsv, err := endpoint.NewResolverGroup("requestfail")
if err != nil {
t.Fatal(err)
}
defer rsv.Close()
rsv.InitialAddrs(resolvedAddrs)
rsv.SetEndpoints(eps)

name := genName()
cfg := Config{
Expand Down
180 changes: 104 additions & 76 deletions clientv3/balancer/resolver/endpoint/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Package endpoint resolves etcd entpoints using grpc targets of the form 'endpoint://<clientId>/<endpoint>'.
// Package endpoint resolves etcd entpoints using grpc targets of the form 'endpoint://<id>/<endpoint>'.
package endpoint

import (
Expand All @@ -36,91 +36,140 @@ var (

func init() {
bldr = &builder{
clientResolvers: make(map[string]*Resolver),
resolverGroups: make(map[string]*ResolverGroup),
}
resolver.Register(bldr)
}

type builder struct {
clientResolvers map[string]*Resolver
resolverGroups map[string]*ResolverGroup
sync.RWMutex
}

// NewResolverGroup creates a new ResolverGroup with the given id.
func NewResolverGroup(id string) (*ResolverGroup, error) {
return bldr.newResolverGroup(id)
}

// ResolverGroup keeps all endpoints of resolvers using a common endpoint://<id>/ target
// up-to-date.
type ResolverGroup struct {
id string
endpoints []string
resolvers []*Resolver
sync.RWMutex
}

func (e *ResolverGroup) addResolver(r *Resolver) {
e.Lock()
addrs := epsToAddrs(e.endpoints...)
e.resolvers = append(e.resolvers, r)
e.Unlock()
r.cc.NewAddress(addrs)
}

func (e *ResolverGroup) removeResolver(r *Resolver) {
e.Lock()
for i, er := range e.resolvers {
if er == r {
e.resolvers = append(e.resolvers[:i], e.resolvers[i+1:]...)
break
}
}
e.Unlock()
}

// SetEndpoints updates the endpoints for ResolverGroup. All registered resolver are updated
// immediately with the new endpoints.
func (e *ResolverGroup) SetEndpoints(endpoints []string) {
addrs := epsToAddrs(endpoints...)
e.Lock()
e.endpoints = endpoints
for _, r := range e.resolvers {
r.cc.NewAddress(addrs)
}
e.Unlock()
}

// Target constructs a endpoint target using the endpoint id of the ResolverGroup.
func (e *ResolverGroup) Target(endpoint string) string {
return Target(e.id, endpoint)
}

// Target constructs a endpoint resolver target.
func Target(id, endpoint string) string {
return fmt.Sprintf("%s://%s/%s", scheme, id, endpoint)
}

// IsTarget checks if a given target string in an endpoint resolver target.
func IsTarget(target string) bool {
return strings.HasPrefix(target, "endpoint://")
}

func (e *ResolverGroup) Close() {
bldr.close(e.id)
}

// Build creates or reuses an etcd resolver for the etcd cluster name identified by the authority part of the target.
func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
if len(target.Authority) < 1 {
return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to")
}
r := b.getResolver(target.Authority)
r.cc = cc
if r.addrs != nil {
r.NewAddress(r.addrs)
id := target.Authority
es, err := b.getResolverGroup(id)
if err != nil {
return nil, fmt.Errorf("failed to build resolver: %v", err)
}
r := &Resolver{
endpointId: id,
cc: cc,
}
es.addResolver(r)
return r, nil
}

func (b *builder) getResolver(clientId string) *Resolver {
func (b *builder) newResolverGroup(id string) (*ResolverGroup, error) {
b.RLock()
r, ok := b.clientResolvers[clientId]
es, ok := b.resolverGroups[id]
b.RUnlock()
if !ok {
r = &Resolver{
clientId: clientId,
}
es = &ResolverGroup{id: id}
b.Lock()
b.clientResolvers[clientId] = r
b.resolverGroups[id] = es
b.Unlock()
} else {
return nil, fmt.Errorf("Endpoint already exists for id: %s", id)
}
return r
return es, nil
}

func (b *builder) addResolver(r *Resolver) {
bldr.Lock()
bldr.clientResolvers[r.clientId] = r
bldr.Unlock()
func (b *builder) getResolverGroup(id string) (*ResolverGroup, error) {
b.RLock()
es, ok := b.resolverGroups[id]
b.RUnlock()
if !ok {
return nil, fmt.Errorf("ResolverGroup not found for id: %s", id)
}
return es, nil
}

func (b *builder) removeResolver(r *Resolver) {
bldr.Lock()
delete(bldr.clientResolvers, r.clientId)
bldr.Unlock()
func (b *builder) close(id string) {
b.Lock()
delete(b.resolverGroups, id)
b.Unlock()
}

func (r *builder) Scheme() string {
return scheme
}

// EndpointResolver gets the resolver for given etcd cluster name.
func EndpointResolver(clientId string) *Resolver {
return bldr.getResolver(clientId)
}

// Resolver provides a resolver for a single etcd cluster, identified by name.
type Resolver struct {
clientId string
cc resolver.ClientConn
addrs []resolver.Address
endpointId string
cc resolver.ClientConn
sync.RWMutex
}

// InitialAddrs sets the initial endpoint addresses for the resolver.
func (r *Resolver) InitialAddrs(addrs []resolver.Address) {
r.Lock()
r.addrs = addrs
r.Unlock()
}

// InitialEndpoints sets the initial endpoints to for the resolver.
// This should be called before dialing. The endpoints may be updated after the dial using NewAddress.
// At least one endpoint is required.
func (r *Resolver) InitialEndpoints(eps []string) error {
if len(eps) < 1 {
return fmt.Errorf("At least one endpoint is required, but got: %v", eps)
}
r.InitialAddrs(epsToAddrs(eps...))
return nil
}

// TODO: use balancer.epsToAddrs
func epsToAddrs(eps ...string) (addrs []resolver.Address) {
addrs = make([]resolver.Address, 0, len(eps))
Expand All @@ -130,35 +179,14 @@ func epsToAddrs(eps ...string) (addrs []resolver.Address) {
return addrs
}

// NewAddress updates the addresses of the resolver.
func (r *Resolver) NewAddress(addrs []resolver.Address) {
r.Lock()
r.addrs = addrs
r.Unlock()
if r.cc != nil {
r.cc.NewAddress(addrs)
}
}

func (*Resolver) ResolveNow(o resolver.ResolveNowOption) {}

func (r *Resolver) Close() {
bldr.removeResolver(r)
}

// Target constructs a endpoint target with current resolver's clientId.
func (r *Resolver) Target(endpoint string) string {
return Target(r.clientId, endpoint)
}

// Target constructs a endpoint resolver target.
func Target(clientId, endpoint string) string {
return fmt.Sprintf("%s://%s/%s", scheme, clientId, endpoint)
}

// IsTarget checks if a given target string in an endpoint resolver target.
func IsTarget(target string) bool {
return strings.HasPrefix(target, "endpoint://")
es, err := bldr.getResolverGroup(r.endpointId)
if err != nil {
return
}
es.removeResolver(r)
}

// Parse endpoint parses a endpoint of the form (http|https)://<host>*|(unix|unixs)://<path>) and returns a
Expand All @@ -185,7 +213,7 @@ func ParseEndpoint(endpoint string) (proto string, host string, scheme string) {
return proto, host, scheme
}

// ParseTarget parses a endpoint://<clientId>/<endpoint> string and returns the parsed clientId and endpoint.
// ParseTarget parses a endpoint://<id>/<endpoint> string and returns the parsed id and endpoint.
// If the target is malformed, an error is returned.
func ParseTarget(target string) (string, string, error) {
noPrefix := strings.TrimPrefix(target, targetPrefix)
Expand All @@ -194,7 +222,7 @@ func ParseTarget(target string) (string, string, error) {
}
parts := strings.SplitN(noPrefix, "/", 2)
if len(parts) != 2 {
return "", "", fmt.Errorf("malformed target, expected %s://<clientId>/<endpoint>, but got %s", scheme, target)
return "", "", fmt.Errorf("malformed target, expected %s://<id>/<endpoint>, but got %s", scheme, target)
}
return parts[0], parts[1], nil
}
Loading

0 comments on commit 8569b9c

Please sign in to comment.