diff --git a/CHANGES.rst b/CHANGES.rst index 7e20db3d4a2..2f4579984de 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,8 @@ CHANGES ======= +- Fix multiple calls to client ws_connect when using a shared header dict #1643 + 1.3.1 (2017-02-09) ------------------ diff --git a/aiohttp/client.py b/aiohttp/client.py index c138f3dd4a9..936a242c683 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -317,8 +317,6 @@ def _ws_connect(self, url, *, proxy=None, proxy_auth=None): - sec_key = base64.b64encode(os.urandom(16)) - if headers is None: headers = CIMultiDict() @@ -326,13 +324,15 @@ def _ws_connect(self, url, *, hdrs.UPGRADE: hdrs.WEBSOCKET, hdrs.CONNECTION: hdrs.UPGRADE, hdrs.SEC_WEBSOCKET_VERSION: '13', - hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(), } for key, value in default_headers.items(): if key not in headers: headers[key] = value + sec_key = base64.b64encode(os.urandom(16)) + headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() + if protocols: headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) if origin is not None: diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index f016f31d35b..5a07df0ada7 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -215,6 +215,56 @@ def test_ws_connect_err_challenge(loop, ws_key, key_data): assert ctx.value.message == 'Invalid challenge response' +@asyncio.coroutine +def test_ws_connect_common_headers(ws_key, loop, key_data): + """Emulate a headers dict being reused for a second ws_connect. + + In this scenario, we need to ensure that the newly generated secret key + is sent to the server, not the stale key. + """ + headers = {} + + @asyncio.coroutine + def test_connection(): + @asyncio.coroutine + def mock_get(*args, **kwargs): + resp = mock.Mock() + resp.status = 101 + key = kwargs.get('headers').get(hdrs.SEC_WEBSOCKET_KEY) + accept = base64.b64encode( + hashlib.sha1(base64.b64encode(base64.b64decode(key)) + WS_KEY) + .digest()).decode() + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: accept, + hdrs.SEC_WEBSOCKET_PROTOCOL: 'chat' + } + return resp + with mock.patch('aiohttp.client.os') as m_os: + with mock.patch('aiohttp.client.ClientSession.get', + side_effect=mock_get) as m_req: + m_os.urandom.return_value = key_data + #m_req.return_value = helpers.create_future(loop) + #m_req.return_value.set_result(resp) + + res = yield from aiohttp.ClientSession(loop=loop).ws_connect( + 'http://test.org', + protocols=('t1', 't2', 'chat'), + headers=headers) + + assert isinstance(res, ClientWebSocketResponse) + assert res.protocol == 'chat' + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + + yield from test_connection() + # Generate a new ws key + key_data = os.urandom(16) + ws_key = base64.b64encode( + hashlib.sha1(base64.b64encode(key_data) + WS_KEY).digest()).decode() + yield from test_connection() + + @asyncio.coroutine def test_close(loop, ws_key, key_data): resp = mock.Mock()