Skip to content

Commit

Permalink
Merge pull request #1643 from armills/ws_connect_fix
Browse files Browse the repository at this point in the history
Fix multiple ws_connect with headers dict
  • Loading branch information
Nikolay Kim authored Feb 16, 2017
2 parents 55e1228 + 0b6681d commit 96a760d
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
CHANGES
=======

- Fix multiple calls to client ws_connect when using a shared header dict #1643

1.3.1 (2017-02-09)
------------------

Expand Down
6 changes: 3 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,22 +317,22 @@ def _ws_connect(self, url, *,
proxy=None,
proxy_auth=None):

sec_key = base64.b64encode(os.urandom(16))

if headers is None:
headers = CIMultiDict()

default_headers = {
hdrs.UPGRADE: hdrs.WEBSOCKET,
hdrs.CONNECTION: hdrs.UPGRADE,
hdrs.SEC_WEBSOCKET_VERSION: '13',
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
}

for key, value in default_headers.items():
if key not in headers:
headers[key] = value

sec_key = base64.b64encode(os.urandom(16))
headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode()

if protocols:
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
if origin is not None:
Expand Down
50 changes: 50 additions & 0 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,56 @@ def test_ws_connect_err_challenge(loop, ws_key, key_data):
assert ctx.value.message == 'Invalid challenge response'


@asyncio.coroutine
def test_ws_connect_common_headers(ws_key, loop, key_data):
"""Emulate a headers dict being reused for a second ws_connect.
In this scenario, we need to ensure that the newly generated secret key
is sent to the server, not the stale key.
"""
headers = {}

@asyncio.coroutine
def test_connection():
@asyncio.coroutine
def mock_get(*args, **kwargs):
resp = mock.Mock()
resp.status = 101
key = kwargs.get('headers').get(hdrs.SEC_WEBSOCKET_KEY)
accept = base64.b64encode(
hashlib.sha1(base64.b64encode(base64.b64decode(key)) + WS_KEY)
.digest()).decode()
resp.headers = {
hdrs.UPGRADE: hdrs.WEBSOCKET,
hdrs.CONNECTION: hdrs.UPGRADE,
hdrs.SEC_WEBSOCKET_ACCEPT: accept,
hdrs.SEC_WEBSOCKET_PROTOCOL: 'chat'
}
return resp
with mock.patch('aiohttp.client.os') as m_os:
with mock.patch('aiohttp.client.ClientSession.get',
side_effect=mock_get) as m_req:
m_os.urandom.return_value = key_data
#m_req.return_value = helpers.create_future(loop)
#m_req.return_value.set_result(resp)

res = yield from aiohttp.ClientSession(loop=loop).ws_connect(
'http://test.org',
protocols=('t1', 't2', 'chat'),
headers=headers)

assert isinstance(res, ClientWebSocketResponse)
assert res.protocol == 'chat'
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]

yield from test_connection()
# Generate a new ws key
key_data = os.urandom(16)
ws_key = base64.b64encode(
hashlib.sha1(base64.b64encode(key_data) + WS_KEY).digest()).decode()
yield from test_connection()


@asyncio.coroutine
def test_close(loop, ws_key, key_data):
resp = mock.Mock()
Expand Down

0 comments on commit 96a760d

Please sign in to comment.