diff --git a/kazoo/client.py b/kazoo/client.py index ca25890f..a939a74f 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -453,7 +453,8 @@ def _session_callback(self, state): self._live.clear() self._notify_pending(state) self._make_state_change(KazooState.SUSPENDED) - self._reset_watchers() + if state != KeeperState.CONNECTING: + self._reset_watchers() def _notify_pending(self, state): """Used to clear a pending response queue and request queue diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 08304c1d..5c7a2536 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -26,6 +26,7 @@ Ping, PingInstance, ReplyHeader, + SetWatches, Transaction, Watch, int_struct @@ -59,6 +60,7 @@ WATCH_XID = -1 PING_XID = -2 AUTH_XID = -4 +SET_WATCHES_XID = -8 CLOSE_RESPONSE = Close.type @@ -406,6 +408,8 @@ def _read_socket(self, read_timeout): async_object.set(True) elif header.xid == WATCH_XID: self._read_watch_event(buffer, offset) + elif header.xid == SET_WATCHES_XID: + self.logger.log(BLATHER, 'Received SetWatches reply') else: self.logger.log(BLATHER, 'Reading for header %r', header) @@ -438,6 +442,8 @@ def _send_request(self, read_timeout, connect_timeout): # Special case for auth packets if request.type == Auth.type: xid = AUTH_XID + elif request.type == SetWatches.type: + xid = SET_WATCHES_XID else: self._xid += 1 xid = self._xid @@ -588,6 +594,10 @@ def _connect(self, host, port): client._session_id or 0, client._session_passwd, client.read_only) + # save the client's last_zxid before it gets overwritten by the server's. + # we'll need this to reset watches via SetWatches further below. + last_zxid = client.last_zxid + connect_result, zxid = self._invoke( client._session_timeout / 1000.0, connect) @@ -626,4 +636,15 @@ def _connect(self, host, port): zxid = self._invoke(connect_timeout / 1000.0, ap, xid=AUTH_XID) if zxid: client.last_zxid = zxid + + # TODO: separate exist from data watches + if client._data_watchers or client._child_watchers.keys(): + sw = SetWatches(last_zxid, + client._data_watchers.keys(), + client._data_watchers.keys(), + client._child_watchers.keys()) + zxid = self._invoke(connect_timeout / 1000.0, sw, xid=SET_WATCHES_XID) + if zxid: + client.last_zxid = zxid + return read_timeout, connect_timeout diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index f44f49a3..db13cba9 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -14,6 +14,7 @@ int_int_long_struct = struct.Struct('!iiq') int_long_int_long_struct = struct.Struct('!iqiq') +long_struct = struct.Struct('!q') multiheader_struct = struct.Struct('!iBi') reply_header_struct = struct.Struct('!iqi') stat_struct = struct.Struct('!qqqqiiiqiiq') @@ -53,6 +54,14 @@ def write_string(bytes): return int_struct.pack(len(utf8_str)) + utf8_str +def write_string_vector(v): + b = bytearray() + b.extend(int_struct.pack(len(v))) + for s in v: + b.extend(write_string(s)) + return b + + def write_buffer(bytes): if bytes is None: return int_struct.pack(-1) @@ -360,6 +369,20 @@ def serialize(self): write_string(self.auth)) +class SetWatches( + namedtuple('SetWatches', + 'relativeZxid, dataWatches, existWatches, childWatches')): + type = 101 + + def serialize(self): + b = bytearray() + b.extend(long_struct.pack(self.relativeZxid)) + b.extend(write_string_vector(self.dataWatches)) + b.extend(write_string_vector(self.existWatches)) + b.extend(write_string_vector(self.childWatches)) + return b + + class Watch(namedtuple('Watch', 'type state path')): @classmethod def deserialize(cls, bytes, offset): diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index a6e52e51..6d97ace9 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -915,6 +915,42 @@ def test_update_host_list(self): finally: self.cluster[0].run() + def test_set_watches_on_reconnect(self): + client = self.client + watch_event = client.handler.event_object() + + client.create("/tacos") + + # set the watch + def w(we): + eq_(we.path, "/tacos") + watch_event.set() + + client.get_children("/tacos", watch=w) + + # force a reconnect + states = [] + rc = client.handler.event_object() + @client.add_listener + def listener(state): + if state == KazooState.CONNECTED: + states.append(state) + rc.set() + + client._connection._socket.shutdown(socket.SHUT_RDWR) + + rc.wait(10) + eq_(states, [KazooState.CONNECTED]) + + # watches should still be there + self.assertTrue(len(client._child_watchers) == 1) + + # ... and they should fire + client.create("/tacos/hello_", b"", ephemeral=True, sequence=True) + + watch_event.wait(1) + self.assertTrue(watch_event.is_set()) + dummy_dict = { 'aversion': 1, 'ctime': 0, 'cversion': 1,