Skip to content

Commit

Permalink
Better tests for DNSHostProvider; fix data race
Browse files Browse the repository at this point in the history
- Added a test for DNSHostProvider that actually fails over the
  connection.
- Wrapped access to the connection's current server in a mutex.
  • Loading branch information
zellyn committed Oct 28, 2015
1 parent 57af1c8 commit bac02d3
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 6 deletions.
20 changes: 15 additions & 5 deletions zk/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ type Conn struct {

dialer Dialer
hostProvider HostProvider
server string // remember the address/port of the current server
serverMu sync.Mutex // protects server
server string // remember the address/port of the current server
conn net.Conn
eventChan chan Event
shouldQuit chan struct{}
Expand Down Expand Up @@ -253,7 +254,7 @@ func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
func (c *Conn) setState(state State) {
atomic.StoreInt32((*int32)(&c.state), int32(state))
select {
case c.eventChan <- Event{Type: EventSession, State: state, Server: c.server}:
case c.eventChan <- Event{Type: EventSession, State: state, Server: c.Server()}:
default:
// panic("zk: event channel full - it must be monitored and never allowed to be full")
}
Expand All @@ -262,7 +263,9 @@ func (c *Conn) setState(state State) {
func (c *Conn) connect() error {
var retryStart bool
for {
c.serverMu.Lock()
c.server, retryStart = c.hostProvider.Next()
c.serverMu.Unlock()
c.setState(StateConnecting)
if retryStart {
c.flushUnsentRequests(ErrNoServer)
Expand All @@ -276,15 +279,15 @@ func (c *Conn) connect() error {
}
}

zkConn, err := c.dialer("tcp", c.server, c.connectTimeout)
zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
if err == nil {
c.conn = zkConn
c.setState(StateConnected)
c.logger.Printf("Connected to %s", c.server)
c.logger.Printf("Connected to %s", c.Server())
return nil
}

c.logger.Printf("Failed to connect to %s: %+v", c.server, err)
c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err)
}
}

Expand Down Expand Up @@ -893,3 +896,10 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
}
return mr, err
}

// Server reports the address of the current or last-connected server.
// The value is read while holding serverMu, so it is safe to call from
// any goroutine.
func (c *Conn) Server() string {
	c.serverMu.Lock()
	addr := c.server
	c.serverMu.Unlock()
	return addr
}
121 changes: 120 additions & 1 deletion zk/dnshostprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zk

import (
"fmt"
"log"
"testing"
"time"
)
Expand Down Expand Up @@ -49,7 +50,125 @@ func TestDNSHostProviderCreate(t *testing.T) {
}
}

// Test the `retryStart` functionality of DNSHostProvider.
// localHostPortsFacade wraps a HostProvider, remapping the
// address/port combinations it returns to "localhost:$PORT" where
// $PORT is chosen from the provided ports.
//
// Ports are handed out in slice order, one per distinct address
// returned by the wrapped provider; once an address has been mapped,
// the same "localhost:$PORT" string is returned for it every time.
type localHostPortsFacade struct {
inner HostProvider // The wrapped HostProvider whose results are remapped
ports []int // The provided list of local ports available for assignment
nextPort int // Index into ports of the next port to assign
mapped map[string]string // Inner address/port -> "localhost:$PORT" mappings made so far
}

// newLocalHostPortsFacade builds a facade around inner that will hand
// out the given ports, starting with an empty address mapping.
func newLocalHostPortsFacade(inner HostProvider, ports []int) *localHostPortsFacade {
	facade := new(localHostPortsFacade)
	facade.inner = inner
	facade.ports = ports
	facade.mapped = map[string]string{}
	return facade
}

// Len delegates to the wrapped HostProvider.
func (lhpf *localHostPortsFacade) Len() int {
	return lhpf.inner.Len()
}

// Connected delegates to the wrapped HostProvider.
func (lhpf *localHostPortsFacade) Connected() {
	lhpf.inner.Connected()
}

// Init passes the server list to the wrapped HostProvider and returns
// whatever error it reports.
func (lhpf *localHostPortsFacade) Init(servers []string) error {
	return lhpf.inner.Init(servers)
}
// Next asks the wrapped HostProvider for its next server and returns
// the "localhost:$PORT" address that server is mapped to, together
// with the wrapped provider's retryStart flag unchanged.
//
// The first time a server is seen it is assigned the next unused port
// from the provided list; if no ports remain, the process is aborted
// via log.Fatalf.
func (lhpf *localHostPortsFacade) Next() (string, bool) {
	server, retryStart := lhpf.inner.Next()

	local := lhpf.mapped[server]
	if local == "" {
		// No mapping yet: claim the next port for this server.
		if lhpf.nextPort == len(lhpf.ports) {
			log.Fatalf("localHostPortsFacade out of ports to assign to %q; current config: %q", server, lhpf.mapped)
		}
		local = fmt.Sprintf("localhost:%d", lhpf.ports[lhpf.nextPort])
		lhpf.mapped[server] = local
		lhpf.nextPort++
	}
	return local, retryStart
}

// Compile-time check that *localHostPortsFacade satisfies HostProvider.
var _ HostProvider = (*localHostPortsFacade)(nil)

// TestDNSHostProviderReconnect checks that a zk.Conn reconnects
// correctly when the ZooKeeper instance it is connected to restarts.
// The DNSHostProvider is wrapped in a lightweight facade that remaps
// its addresses to the localhost:$PORT combinations of the test
// ZooKeeper instances.
func TestDNSHostProviderReconnect(t *testing.T) {
	cluster, err := StartTestCluster(3, nil, logWriter{t: t, p: "[ZKERR] "})
	if err != nil {
		t.Fatal(err)
	}
	defer cluster.Stop()

	// Stub DNS lookup: always resolves to three fixed addresses.
	lookup := func(host string) ([]string, error) {
		return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
	}
	dnsProvider := &DNSHostProvider{lookupHost: lookup}

	localPorts := make([]int, len(cluster.Servers))
	for i := range cluster.Servers {
		localPorts[i] = cluster.Servers[i].Port
	}
	provider := newLocalHostPortsFacade(dnsProvider, localPorts)

	conn, _, err := Connect([]string{"foo.example.com:12345"}, time.Second, WithHostProvider(provider))
	if err != nil {
		t.Fatalf("Connect returned error: %+v", err)
	}
	defer conn.Close()

	nodePath := "/gozk-test"

	// Initial operation to force connection; a missing node is fine.
	switch err := conn.Delete(nodePath, -1); err {
	case nil, ErrNoNode:
	default:
		t.Fatalf("Delete returned error: %+v", err)
	}

	// Work out which test server the connection landed on.
	before := conn.Server()
	t.Logf("Connected to %q. Finding test server index…", before)
	found := -1
	for i := range cluster.Servers {
		candidate := fmt.Sprintf("localhost:%d", cluster.Servers[i].Port)
		t.Logf("…trying %q", candidate)
		if candidate != before {
			continue
		}
		found = i
		t.Logf("…found at index %d", i)
		break
	}
	if found == -1 {
		t.Fatalf("Cannot determine test server index.")
	}

	// Restart the connected server.
	cluster.Servers[found].Srv.Stop()
	cluster.Servers[found].Srv.Start()

	// Continue with the basic TestCreate tests.
	created, err := conn.Create(nodePath, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll))
	if err != nil {
		t.Fatalf("Create returned error: %+v", err)
	}
	if created != nodePath {
		t.Fatalf("Create returned different path '%s' != '%s'", created, nodePath)
	}

	data, stat, err := conn.Get(nodePath)
	if err != nil {
		t.Fatalf("Get returned error: %+v", err)
	}
	if stat == nil {
		t.Fatal("Get returned nil stat")
	}
	if len(data) < 4 {
		t.Fatal("Get returned wrong size data")
	}

	// The connection should have moved to a different server.
	if conn.Server() == before {
		t.Errorf("Still connected to %q after restart.", before)
	}
}

// TestDNSHostProviderRetryStart tests the `retryStart` functionality
// of DNSHostProvider.
// It's also probably the clearest visual explanation of exactly how
// it works.
func TestDNSHostProviderRetryStart(t *testing.T) {
hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
Expand Down

0 comments on commit bac02d3

Please sign in to comment.