Skip to content

Commit

Permalink
Merge pull request #1119 from voith/v4-add-websocket-timeout
Browse files Browse the repository at this point in the history
[BACKPORT TO V4] Add timeout for WebsocketProvider
  • Loading branch information
carver authored Oct 22, 2018
2 parents 8fae96b + bcde527 commit e3e7510
Show file tree
Hide file tree
Showing 17 changed files with 110 additions and 63 deletions.
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@
identity,
)

from .utils import (
get_open_port,
)


@pytest.fixture(scope="module", params=[lambda x: to_bytes(hexstr=x), identity])
def address_conversion_func(request):
return request.param


@pytest.fixture()
def open_port():
return get_open_port()
49 changes: 49 additions & 0 deletions tests/core/providers/test_websocket_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
import asyncio
from concurrent.futures import (
TimeoutError,
)
import pytest
from threading import (
Thread,
)

import websockets

from tests.utils import (
wait_for_ws,
)
from web3 import Web3
from web3.exceptions import (
ValidationError,
)
Expand All @@ -8,6 +21,42 @@
)


@pytest.yield_fixture
def start_websocket_server(open_port):
event_loop = asyncio.new_event_loop()

def run_server():
async def empty_server(websocket, path):
data = await websocket.recv()
await asyncio.sleep(0.02)
await websocket.send(data)
server = websockets.serve(empty_server, '127.0.0.1', open_port, loop=event_loop)
event_loop.run_until_complete(server)
event_loop.run_forever()

thd = Thread(target=run_server)
thd.start()
try:
yield
finally:
event_loop.call_soon_threadsafe(event_loop.stop)


@pytest.fixture()
def w3(open_port, start_websocket_server):
# need new event loop as the one used by server is already running
event_loop = asyncio.new_event_loop()
endpoint_uri = 'ws://127.0.0.1:{}'.format(open_port)
event_loop.run_until_complete(wait_for_ws(endpoint_uri, event_loop))
provider = WebsocketProvider(endpoint_uri, websocket_timeout=0.01)
return Web3(provider)


def test_websocket_provider_timeout(w3):
with pytest.raises(TimeoutError):
w3.eth.accounts


def test_restricted_websocket_kwargs():
invalid_kwargs = {'uri': 'ws://127.0.0.1:8546'}
re_exc_message = r'.*found: {0}*'.format(set(invalid_kwargs.keys()))
Expand Down
11 changes: 3 additions & 8 deletions tests/generate_go_ethereum_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
to_wei,
)

from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.module_testing.emitter_contract import (
EMITTER_ABI,
Expand Down Expand Up @@ -100,14 +103,6 @@ def tempdir():
shutil.rmtree(dir_path)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


def get_geth_binary():
from geth.install import (
get_executable_path,
Expand Down
8 changes: 0 additions & 8 deletions tests/integration/generate_fixtures/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ def tempdir():
shutil.rmtree(dir_path)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


def get_geth_binary():
from geth.install import (
get_executable_path,
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/generate_fixtures/go_ethereum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
)

import common
from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.module_testing.emitter_contract import (
EMITTER_ABI,
Expand Down Expand Up @@ -42,7 +45,7 @@ def generate_go_ethereum_fixture(destination_dir):
geth_ipc_path_dir = stack.enter_context(common.tempdir())
geth_ipc_path = os.path.join(geth_ipc_path_dir, 'geth.ipc')

geth_port = common.get_open_port()
geth_port = get_open_port()
geth_binary = common.get_geth_binary()

geth_proc = stack.enter_context(common.get_geth_process( # noqa: F841
Expand Down
7 changes: 5 additions & 2 deletions tests/integration/generate_fixtures/parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

import common
import go_ethereum
from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.toolz import (
merge,
Expand Down Expand Up @@ -176,7 +179,7 @@ def generate_parity_fixture(destination_dir):

geth_datadir = stack.enter_context(common.tempdir())

geth_port = common.get_open_port()
geth_port = get_open_port()

geth_ipc_path_dir = stack.enter_context(common.tempdir())
geth_ipc_path = os.path.join(geth_ipc_path_dir, 'geth.ipc')
Expand Down Expand Up @@ -221,7 +224,7 @@ def generate_parity_fixture(destination_dir):
parity_ipc_path_dir = stack.enter_context(common.tempdir())
parity_ipc_path = os.path.join(parity_ipc_path_dir, 'jsonrpc.ipc')

parity_port = common.get_open_port()
parity_port = get_open_port()
parity_binary = get_parity_binary()

parity_proc = stack.enter_context(get_parity_process( # noqa: F841
Expand Down
9 changes: 0 additions & 9 deletions tests/integration/go_ethereum/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import socket

from web3.utils.module_testing import (
EthModuleTest,
Expand All @@ -10,14 +9,6 @@
)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


class GoEthereumTest(Web3ModuleTest):
def _check_web3_clientVersion(self, client_version):
assert client_version.startswith('Geth/')
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

from tests.utils import (
get_open_port,
)
from web3 import Web3

from .common import (
Expand All @@ -8,7 +11,6 @@
GoEthereumPersonalModuleTest,
GoEthereumTest,
GoEthereumVersionModuleTest,
get_open_port,
)
from .utils import (
wait_for_http,
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/go_ethereum/test_goethereum_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import pytest
import tempfile

from tests.utils import (
get_open_port,
)
from web3 import Web3

from .common import (
Expand All @@ -12,7 +15,6 @@
GoEthereumVersionModuleTest,
)
from .utils import (
get_open_port,
wait_for_socket,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/go_ethereum/test_goethereum_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from tests.integration.common import (
MiscWebsocketTest,
)
from tests.integration.utils import (
from tests.utils import (
get_open_port,
wait_for_ws,
)
from web3 import Web3
Expand All @@ -14,7 +15,6 @@
GoEthereumPersonalModuleTest,
GoEthereumTest,
GoEthereumVersionModuleTest,
get_open_port,
)


Expand Down
8 changes: 0 additions & 8 deletions tests/integration/go_ethereum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,3 @@ def kill_proc_gracefully(proc):
if proc.poll() is None:
proc.kill()
wait_for_popen(proc, 2)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)
9 changes: 0 additions & 9 deletions tests/integration/parity/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import socket

from flaky import (
flaky,
Expand All @@ -16,14 +15,6 @@
MAX_FLAKY_RUNS = 3


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


class ParityWeb3ModuleTest(Web3ModuleTest):
def _check_web3_clientVersion(self, client_version):
assert client_version.startswith('Parity/')
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/parity/test_parity_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from tests.integration.parity.utils import (
wait_for_http,
)
from tests.utils import (
get_open_port,
)
from web3 import Web3
from web3.utils.module_testing import (
NetModuleTest,
Expand All @@ -15,7 +18,6 @@
ParityPersonalModuleTest,
ParityTraceModuleTest,
ParityWeb3ModuleTest,
get_open_port,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/integration/parity/test_parity_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from tests.integration.common import (
MiscWebsocketTest,
)
from tests.integration.utils import (
from tests.utils import (
get_open_port,
wait_for_ws,
)
from web3 import Web3
Expand All @@ -18,7 +19,6 @@
ParityPersonalModuleTest,
ParityTraceModuleTest,
ParityWeb3ModuleTest,
get_open_port,
)


Expand Down
8 changes: 0 additions & 8 deletions tests/integration/parity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,3 @@ def kill_proc_gracefully(proc):
if proc.poll() is None:
proc.kill()
wait_for_popen(proc, 2)


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)
9 changes: 9 additions & 0 deletions tests/integration/utils.py → tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import asyncio
import socket
import time

import websockets


def get_open_port():
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
return str(port)


async def wait_for_ws(endpoint_uri, event_loop, timeout=60):
start = time.time()
while time.time() < start + timeout:
Expand Down
21 changes: 18 additions & 3 deletions web3/providers/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

RESTRICTED_WEBSOCKET_KWARGS = {'uri', 'loop'}
DEFAULT_WEBSOCKET_TIMEOUT = 10


def _start_event_loop(loop):
Expand Down Expand Up @@ -63,8 +64,14 @@ class WebsocketProvider(JSONBaseProvider):
logger = logging.getLogger("web3.providers.WebsocketProvider")
_loop = None

def __init__(self, endpoint_uri=None, websocket_kwargs=None):
def __init__(
self,
endpoint_uri=None,
websocket_kwargs=None,
websocket_timeout=DEFAULT_WEBSOCKET_TIMEOUT
):
self.endpoint_uri = endpoint_uri
self.websocket_timeout = websocket_timeout
if self.endpoint_uri is None:
self.endpoint_uri = get_default_endpoint()
if WebsocketProvider._loop is None:
Expand All @@ -90,8 +97,16 @@ def __str__(self):

async def coro_make_request(self, request_data):
async with self.conn as conn:
await conn.send(request_data)
return json.loads(await conn.recv())
await asyncio.wait_for(
conn.send(request_data),
timeout=self.websocket_timeout
)
return json.loads(
await asyncio.wait_for(
conn.recv(),
timeout=self.websocket_timeout
)
)

def make_request(self, method, params):
self.logger.debug("Making request WebSocket. URI: %s, "
Expand Down

0 comments on commit e3e7510

Please sign in to comment.