Skip to content

Commit

Permalink
Merge pull request #400 from tonyseek/hotfix/xid-mismatch
Browse files Browse the repository at this point in the history
Fix the client.add_auth hangs by xids mismatch.
  • Loading branch information
bbangert authored May 31, 2017
2 parents b1f3d61 + 61a3576 commit c85969a
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 2 deletions.
4 changes: 3 additions & 1 deletion kazoo/protocol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,10 @@ def _read_response(self, header, buffer, offset):
if header.zxid and header.zxid > 0:
client.last_zxid = header.zxid
if header.xid != xid:
raise RuntimeError('xids do not match, expected %r '
exc = RuntimeError('xids do not match, expected %r '
'received %r', xid, header.xid)
async_object.set_exception(exc)
raise exc

# Determine if its an exists request and a no node error
exists_error = (header.err == NoNodeError.code and
Expand Down
73 changes: 72 additions & 1 deletion kazoo/tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections import namedtuple
from collections import namedtuple, deque
import os
import threading
import time
import uuid
import struct
import sys

from nose import SkipTest
from nose.tools import eq_
Expand Down Expand Up @@ -304,3 +305,73 @@ def testit():
client.remove_listener(listen)
self.cluster[1].run()
self.cluster[2].run()


class TestUnorderedXids(KazooTestCase):

def setUp(self):
super(TestUnorderedXids, self).setUp()

self.connection = self.client._connection
self.connection_routine = self.connection._connection_routine

self._pending = self.client._pending
self.client._pending = _naughty_deque()

def tearDown(self):
self.client._pending = self._pending
super(TestUnorderedXids, self).tearDown()

def _get_client(self, **kwargs):
# overrides for patching zk_loop
c = KazooTestCase._get_client(self, **kwargs)
self._zk_loop = c._connection.zk_loop
self._zk_loop_errors = []
c._connection.zk_loop = self._zk_loop_func
return c

def _zk_loop_func(self, *args, **kwargs):
# patched zk_loop which will catch and collect all RuntimeError
try:
self._zk_loop(*args, **kwargs)
except RuntimeError as e:
self._zk_loop_errors.append(e)

def test_xids_mismatch(self):
from kazoo.protocol.states import KeeperState

ev = threading.Event()
error_stack = []

@self.client.add_listener
def listen(state):
if self.client.client_state == KeeperState.CLOSED:
ev.set()

def log_exception(*args):
error_stack.append((args, sys.exc_info()))

self.connection.logger.exception = log_exception

ev.clear()
self.assertRaises(RuntimeError, self.client.get_children, '/')

ev.wait()
eq_(self.client.connected, False)
eq_(self.client.state, 'LOST')
eq_(self.client.client_state, KeeperState.CLOSED)

args, exc_info = error_stack[-1]
eq_(args, ('Unhandled exception in connection loop',))
eq_(exc_info[0], RuntimeError)

self.client.handler.sleep_func(0.2)
assert not self.connection_routine.is_alive()
assert len(self._zk_loop_errors) == 1
assert self._zk_loop_errors[0] == exc_info[1]


class _naughty_deque(deque):
def append(self, s):
request, async_object, xid = s
return deque.append(self, (request, async_object, xid + 1)) # +1s

0 comments on commit c85969a

Please sign in to comment.