diff --git a/zk/cluster_test.go b/zk/cluster_test.go index b8571734..445f305c 100644 --- a/zk/cluster_test.go +++ b/zk/cluster_test.go @@ -1,8 +1,7 @@ package zk import ( - "fmt" - "strings" + "sync" "testing" "time" ) @@ -46,58 +45,41 @@ func TestBasicCluster(t *testing.T) { } } +// If the current leader dies, then the session is reestablished with the new one. func TestClientClusterFailover(t *testing.T) { - ts, err := StartTestCluster(3, nil, logWriter{t: t, p: "[ZKERR] "}) + tc, err := StartTestCluster(3, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil { t.Fatal(err) } - defer ts.Stop() - zk, evCh, err := ts.ConnectAll() + defer tc.Stop() + zk, evCh, err := tc.ConnectAll() if err != nil { t.Fatalf("Connect returned error: %+v", err) } defer zk.Close() - hasSession := make(chan string, 1) - go func() { - for ev := range evCh { - if ev.Type == EventSession && ev.State == StateHasSession { - select { - case hasSession <- ev.Server: - default: - } - } - } - }() + sl := NewStateLogger(evCh) - waitSession := func() string { - select { - case srv := <-hasSession: - return srv - case <-time.After(time.Second * 8): - t.Fatal("Failed to connect and get a session") - } - return "" + hasSessionEvent1 := sl.NewWatcher(sessionStateMatcher(StateHasSession)).Wait(8 * time.Second) + if hasSessionEvent1 == nil { + t.Fatalf("Failed to connect and get session") } - srv := waitSession() if _, err := zk.Create("/gozk-test", []byte("foo-cluster"), 0, WorldACL(PermAll)); err != nil { t.Fatalf("Create failed on node 1: %+v", err) } - stopped := false - for _, s := range ts.Servers { - if strings.HasSuffix(srv, fmt.Sprintf(":%d", s.Port)) { - s.Srv.Stop() - stopped = true - break - } - } - if !stopped { - t.Fatal("Failed to stop server") + hasSessionWatcher2 := sl.NewWatcher(sessionStateMatcher(StateHasSession)) + + // Kill the current leader + tc.StopServer(hasSessionEvent1.Server) + + // Wait for the session to be reconnected with the new leader. 
+ hasSessionEvent2 := hasSessionWatcher2.Wait(8 * time.Second) + if hasSessionEvent2 == nil { + t.Fatalf("Failover failed") } - waitSession() if by, _, err := zk.Get("/gozk-test"); err != nil { t.Fatalf("Get failed on node 2: %+v", err) } else if string(by) != "foo-cluster" { @@ -105,6 +87,83 @@ func TestClientClusterFailover(t *testing.T) { } } +// If a ZooKeeper cluster loses quorum then a session is reconnected as soon +// as the quorum is restored. +func TestNoQuorum(t *testing.T) { + tc, err := StartTestCluster(3, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + defer tc.Stop() + zk, evCh, err := tc.ConnectAllTimeout(4 * time.Second) + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + defer zk.Close() + sl := NewStateLogger(evCh) + + // Wait for initial session to be established + hasSessionEvent1 := sl.NewWatcher(sessionStateMatcher(StateHasSession)).Wait(8 * time.Second) + if hasSessionEvent1 == nil { + t.Fatalf("Failed to connect and get session") + } + initialSessionID := zk.sessionID + DefaultLogger.Printf("    Session established: id=%d, timeout=%d", zk.sessionID, zk.sessionTimeoutMs) + + // Kill the ZooKeeper leader and wait for the session to reconnect. + DefaultLogger.Printf("    Kill the leader") + hasSessionWatcher2 := sl.NewWatcher(sessionStateMatcher(StateHasSession)) + tc.StopServer(hasSessionEvent1.Server) + hasSessionEvent2 := hasSessionWatcher2.Wait(8 * time.Second) + if hasSessionEvent2 == nil { + t.Fatalf("Failover failed") + } + + // Kill the ZooKeeper leader leaving the cluster without quorum. + DefaultLogger.Printf("    Kill the leader") + tc.StopServer(hasSessionEvent2.Server) + + // Make sure that we keep retrying connecting to the only remaining + // ZooKeeper server, but the attempts are being dropped because there is + // no quorum. 
+ DefaultLogger.Printf(" Retrying no luck...") + var firstDisconnect *Event + begin := time.Now() + for time.Now().Sub(begin) < 6*time.Second { + disconnectedEvent := sl.NewWatcher(sessionStateMatcher(StateDisconnected)).Wait(4 * time.Second) + if disconnectedEvent == nil { + t.Fatalf("Disconnected event expected") + } + if firstDisconnect == nil { + firstDisconnect = disconnectedEvent + continue + } + if disconnectedEvent.Server != firstDisconnect.Server { + t.Fatalf("Disconnect from wrong server: expected=%s, actual=%s", + firstDisconnect.Server, disconnectedEvent.Server) + } + } + + // Start a ZooKeeper node to restore quorum. + hasSessionWatcher3 := sl.NewWatcher(sessionStateMatcher(StateHasSession)) + tc.StartServer(hasSessionEvent1.Server) + + // Make sure that session is reconnected with the same ID. + hasSessionEvent3 := hasSessionWatcher3.Wait(8 * time.Second) + if hasSessionEvent3 == nil { + t.Fatalf("Session has not been reconnected") + } + if zk.sessionID != initialSessionID { + t.Fatalf("Wrong session ID: expected=%d, actual=%d", initialSessionID, zk.sessionID) + } + + // Make sure that the session is not dropped soon after reconnect + e := sl.NewWatcher(sessionStateMatcher(StateDisconnected)).Wait(6 * time.Second) + if e != nil { + t.Fatalf("Unexpected disconnect") + } +} + func TestWaitForClose(t *testing.T) { ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil { @@ -164,3 +223,72 @@ func TestBadSession(t *testing.T) { t.Fatalf("Delete returned error: %+v", err) } } + +type EventLogger struct { + events []Event + watchers []*EventWatcher + lock sync.Mutex + wg sync.WaitGroup +} + +func NewStateLogger(eventCh <-chan Event) *EventLogger { + el := &EventLogger{} + el.wg.Add(1) + go func() { + defer el.wg.Done() + for event := range eventCh { + el.lock.Lock() + for _, sw := range el.watchers { + if !sw.triggered && sw.matcher(event) { + sw.triggered = true + sw.matchCh <- event + } + } + DefaultLogger.Printf(" event 
received: %v\n", event) + el.events = append(el.events, event) + el.lock.Unlock() + } + }() + return el +} + +func (el *EventLogger) NewWatcher(matcher func(Event) bool) *EventWatcher { + ew := &EventWatcher{matcher: matcher, matchCh: make(chan Event, 1)} + el.lock.Lock() + el.watchers = append(el.watchers, ew) + el.lock.Unlock() + return ew +} + +func (el *EventLogger) Events() []Event { + el.lock.Lock() + transitions := make([]Event, len(el.events)) + copy(transitions, el.events) + el.lock.Unlock() + return transitions +} + +func (el *EventLogger) Wait4Stop() { + el.wg.Wait() +} + +type EventWatcher struct { + matcher func(Event) bool + matchCh chan Event + triggered bool +} + +func (ew *EventWatcher) Wait(timeout time.Duration) *Event { + select { + case event := <-ew.matchCh: + return &event + case <-time.After(timeout): + return nil + } +} + +func sessionStateMatcher(s State) func(Event) bool { + return func(e Event) bool { + return e.Type == EventSession && e.State == s + } +} diff --git a/zk/conn.go b/zk/conn.go index 4c003668..b12bcc9e 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -62,12 +62,12 @@ type Logger interface { } type Conn struct { - lastZxid int64 - sessionID int64 - state State // must be 32-bit aligned - xid uint32 - timeout int32 // session timeout in milliseconds - passwd []byte + lastZxid int64 + sessionID int64 + state State // must be 32-bit aligned + xid uint32 + sessionTimeoutMs int32 // session timeout in milliseconds + passwd []byte dialer Dialer servers []string @@ -140,8 +140,6 @@ func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Di return nil, nil, errors.New("zk: server list must not be empty") } - recvTimeout := sessionTimeout * 2 / 3 - srvs := make([]string, len(servers)) for i, addr := range servers { @@ -168,19 +166,18 @@ func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Di state: StateDisconnected, eventChan: ec, shouldQuit: make(chan struct{}), - recvTimeout: recvTimeout, 
- pingInterval: recvTimeout / 2, connectTimeout: 1 * time.Second, sendChan: make(chan *request, sendChanSize), requests: make(map[int32]*request), watchers: make(map[watchPathType][]chan Event), passwd: emptyPassword, - timeout: int32(sessionTimeout.Nanoseconds() / 1e6), logger: DefaultLogger, // Debug reconnectDelay: 0, } + conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) + go func() { conn.loop() conn.flushRequests(ErrClosing) @@ -215,6 +212,13 @@ func (c *Conn) SetLogger(l Logger) { c.logger = l } +func (c *Conn) setTimeouts(sessionTimeoutMs int32) { + c.sessionTimeoutMs = sessionTimeoutMs + sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond + c.recvTimeout = sessionTimeout * 2 / 3 + c.pingInterval = c.recvTimeout / 2 +} + func (c *Conn) setState(state State) { atomic.StoreInt32((*int32)(&c.state), int32(state)) select { @@ -225,9 +229,9 @@ func (c *Conn) setState(state State) { } func (c *Conn) connect() error { - c.setState(StateConnecting) for { c.serverIndex = (c.serverIndex + 1) % len(c.servers) + c.setState(StateConnecting) if c.serverIndex == c.lastServerIndex { c.flushUnsentRequests(ErrNoServer) select { @@ -247,6 +251,7 @@ func (c *Conn) connect() error { if err == nil { c.conn = zkConn c.setState(StateConnected) + c.logger.Printf("Connected to %s", c.servers[c.serverIndex]) return nil } @@ -264,24 +269,29 @@ func (c *Conn) loop() { err := c.authenticate() switch { case err == ErrSessionExpired: + c.logger.Printf("Authentication failed: %s", err) c.invalidateWatches(err) case err != nil && c.conn != nil: + c.logger.Printf("Authentication failed: %s", err) c.conn.Close() case err == nil: + c.logger.Printf("Authenticated: id=%d, timeout=%d", c.sessionID, c.sessionTimeoutMs) c.lastServerIndex = c.serverIndex closeChan := make(chan struct{}) // channel to tell send loop stop var wg sync.WaitGroup wg.Add(1) go func() { - c.sendLoop(c.conn, closeChan) + err := c.sendLoop(c.conn, closeChan) + c.logger.Printf("Send loop terminated: 
err=%v", err) c.conn.Close() // causes recv loop to EOF/exit wg.Done() }() wg.Add(1) go func() { - err = c.recvLoop(c.conn) + err := c.recvLoop(c.conn) + c.logger.Printf("Recv loop terminated: err=%v", err) if err == nil { panic("zk: recvLoop should never return nil error") } @@ -289,16 +299,12 @@ func (c *Conn) loop() { wg.Done() }() + c.sendSetWatches() wg.Wait() } c.setState(StateDisconnected) - // Yeesh - if err != io.EOF && err != ErrSessionExpired && !strings.Contains(err.Error(), "use of closed network connection") { - c.logger.Printf(err.Error()) - } - select { case <-c.shouldQuit: c.flushRequests(ErrClosing) @@ -404,12 +410,11 @@ func (c *Conn) sendSetWatches() { func (c *Conn) authenticate() error { buf := make([]byte, 256) - // connect request - + // Encode and send a connect request. n, err := encodePacket(buf[4:], &connectRequest{ ProtocolVersion: protocolVersion, LastZxidSeen: c.lastZxid, - TimeOut: c.timeout, + TimeOut: c.sessionTimeoutMs, SessionID: c.sessionID, Passwd: c.passwd, }) @@ -426,23 +431,12 @@ func (c *Conn) authenticate() error { return err } - c.sendSetWatches() - - // connect response - - // package length + // Receive and decode a connect response. 
c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10)) _, err = io.ReadFull(c.conn, buf[:4]) c.conn.SetReadDeadline(time.Time{}) if err != nil { - // Sometimes zookeeper just drops connection on invalid session data, - // we prefer to drop session and start from scratch when that event - // occurs instead of dropping into loop of connect/disconnect attempts - atomic.StoreInt64(&c.sessionID, int64(0)) - c.passwd = emptyPassword - c.lastZxid = 0 - c.setState(StateExpired) - return ErrSessionExpired + return err } blen := int(binary.BigEndian.Uint32(buf[:4])) @@ -468,8 +462,8 @@ func (c *Conn) authenticate() error { return ErrSessionExpired } - c.timeout = r.TimeOut atomic.StoreInt64(&c.sessionID, r.SessionID) + c.setTimeouts(r.TimeOut) c.passwd = r.Passwd c.setState(StateHasSession) @@ -859,7 +853,7 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { case *CheckVersionRequest: opCode = opCheck default: - return nil, fmt.Errorf("uknown operation type %T", op) + return nil, fmt.Errorf("unknown operation type %T", op) } req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op}) } diff --git a/zk/server_help.go b/zk/server_help.go index 4a53772b..a0e12cf4 100644 --- a/zk/server_help.go +++ b/zk/server_help.go @@ -7,6 +7,7 @@ import ( "math/rand" "os" "path/filepath" + "strings" "time" ) @@ -88,7 +89,7 @@ func StartTestCluster(size int, stdout, stderr io.Writer) (*TestCluster, error) }) } success = true - time.Sleep(time.Second) // Give the server time to become active. Should probably actually attempt to connect to verify. + time.Sleep(3 * time.Second) // Give the server time to become active. Should probably actually attempt to connect to verify. 
return cluster, nil } @@ -117,3 +118,23 @@ func (ts *TestCluster) Stop() error { defer os.RemoveAll(ts.Path) return nil } + +func (tc *TestCluster) StartServer(server string) { + for _, s := range tc.Servers { + if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { + s.Srv.Start() + return + } + } + panic(fmt.Sprintf("Unknown server: %s", server)) +} + +func (tc *TestCluster) StopServer(server string) { + for _, s := range tc.Servers { + if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { + s.Srv.Stop() + return + } + } + panic(fmt.Sprintf("Unknown server: %s", server)) +} diff --git a/zk/zk_test.go b/zk/zk_test.go index fdbe5172..10e0b586 100644 --- a/zk/zk_test.go +++ b/zk/zk_test.go @@ -297,10 +297,12 @@ func TestSetWatchers(t *testing.T) { t.Fatal("Children should return at least 1 child") } + // Simulate network error by brutally closing the network connection. zk.conn.Close() if err := zk2.Delete(testPath, -1); err != nil && err != ErrNoNode { t.Fatalf("Delete returned error: %+v", err) } + // Allow some time for the `zk` session to reconnect and set watches. time.Sleep(time.Millisecond * 100) if path, err := zk2.Create("/gozk-test", []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil {