From 699ee9c47adcb44f1a8150d8c92a9555d07f7b5b Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 10 Sep 2023 17:22:09 +0100 Subject: [PATCH] Context manager interface for the simple clients --- docs/client.rst | 22 +++++++++++++-- examples/simple-client/async/fiddle_client.py | 7 ++--- .../simple-client/async/latency_client.py | 8 ++---- examples/simple-client/sync/fiddle_client.py | 7 ++--- examples/simple-client/sync/latency_client.py | 8 ++---- src/socketio/asyncio_simple_client.py | 28 +++++++++++++------ src/socketio/simple_client.py | 26 +++++++++++------ tests/asyncio/test_asyncio_simple_client.py | 24 ++++++++++++++++ tests/common/test_simple_client.py | 17 +++++++++++ 9 files changed, 109 insertions(+), 38 deletions(-) diff --git a/docs/client.rst b/docs/client.rst index 3344bd35..aea9aba3 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -35,8 +35,26 @@ the application. Creating a Client Instance ~~~~~~~~~~~~~~~~~~~~~~~~~~ -To instantiate a Socket.IO client, create an instance of the appropriate client -class:: +The easiest way to create a Socket.IO client is to use the context manager +interface:: + + import socketio + + # standard Python + with socketio.SimpleClient() as sio: + # ... connect to a server and use the client + # ... no need to manually disconnect! + + # asyncio + async with socketio.AsyncSimpleClient() as sio: + # ... connect to a server and use the client + # ... no need to manually disconnect! + + +With this usage the context manager will ensure that the client is properly +disconnected before exiting the ``with`` or ``async with`` block. + +If preferred, a client can be manually instantiated:: import socketio diff --git a/examples/simple-client/async/fiddle_client.py b/examples/simple-client/async/fiddle_client.py index 305e71f6..d9744803 100644 --- a/examples/simple-client/async/fiddle_client.py +++ b/examples/simple-client/async/fiddle_client.py @@ -3,10 +3,9 @@ async def main(): - sio = socketio.AsyncSimpleClient() - await sio.connect('http://localhost:5000', auth={'token': 'my-token'}) - print(await sio.receive()) - await sio.disconnect() + async with socketio.AsyncSimpleClient() as sio: + await sio.connect('http://localhost:5000', auth={'token': 'my-token'}) + print(await sio.receive()) if __name__ == '__main__': diff --git a/examples/simple-client/async/latency_client.py b/examples/simple-client/async/latency_client.py index 1139cccd..96387c65 100644 --- a/examples/simple-client/async/latency_client.py +++ b/examples/simple-client/async/latency_client.py @@ -4,10 +4,8 @@ async def main(): - sio = socketio.AsyncSimpleClient() - await sio.connect('http://localhost:5000') - - try: + async with socketio.AsyncSimpleClient() as sio: + await sio.connect('http://localhost:5000') while True: start_timer = time.time() await sio.emit('ping_from_client') @@ -17,8 +15,6 @@ async def main(): print('latency is {0:.2f} ms'.format(latency * 1000)) await asyncio.sleep(1) - except (KeyboardInterrupt, asyncio.CancelledError): - await sio.disconnect() if __name__ == '__main__': diff --git a/examples/simple-client/sync/fiddle_client.py b/examples/simple-client/sync/fiddle_client.py index 2f79e97c..1be759cb 100644 --- a/examples/simple-client/sync/fiddle_client.py +++ b/examples/simple-client/sync/fiddle_client.py @@ -2,10 +2,9 @@ def main(): - sio = socketio.SimpleClient() - sio.connect('http://localhost:5000', auth={'token': 'my-token'}) - print(sio.receive()) - sio.disconnect() + with socketio.SimpleClient() as sio: + sio.connect('http://localhost:5000', auth={'token': 'my-token'}) + print(sio.receive()) if __name__ == '__main__': diff --git a/examples/simple-client/sync/latency_client.py b/examples/simple-client/sync/latency_client.py index 2bf76577..d5cd853e 100644 --- a/examples/simple-client/sync/latency_client.py +++ b/examples/simple-client/sync/latency_client.py @@ -3,10 +3,8 @@ def main(): - sio = socketio.SimpleClient() - sio.connect('http://localhost:5000') - - try: + with socketio.SimpleClient() as sio: + sio.connect('http://localhost:5000') while True: start_timer = time.time() sio.emit('ping_from_client') @@ -16,8 +14,6 @@ def main(): print('latency is {0:.2f} ms'.format(latency * 1000)) time.sleep(1) - except KeyboardInterrupt: - sio.disconnect() if __name__ == '__main__': diff --git a/src/socketio/asyncio_simple_client.py b/src/socketio/asyncio_simple_client.py index f0066efa..68dce66f 100644 --- a/src/socketio/asyncio_simple_client.py +++ b/src/socketio/asyncio_simple_client.py @@ -59,21 +59,21 @@ async def connect(self, url, headers={}, auth=None, transports=None, self.input_event.clear() self.client = AsyncClient(*self.client_args, **self.client_kwargs) - @self.client.event + @self.client.event(namespace=self.namespace) def connect(): # pragma: no cover self.connected = True self.connected_event.set() - @self.client.event + @self.client.event(namespace=self.namespace) def disconnect(): # pragma: no cover self.connected_event.clear() - @self.client.event + @self.client.event(namespace=self.namespace) def __disconnect_final(): # pragma: no cover self.connected = False self.connected_event.set() - @self.client.on('*') + @self.client.on('*', namespace=self.namespace) def on_event(event, *args): # pragma: no cover self.input_buffer.append([event, *args]) self.input_event.set() @@ -172,8 +172,12 @@ async def receive(self, timeout=None): the server included arguments with the event, they are returned as additional list elements. """ - if not self.input_buffer: - await self.connected_event.wait() + while not self.input_buffer: + try: + await asyncio.wait_for(self.connected_event.wait(), + timeout=timeout) + except asyncio.TimeoutError: # pragma: no cover + raise TimeoutError() if not self.connected: raise DisconnectedError() try: @@ -189,5 +193,13 @@ async def disconnect(self): Note: this method is a coroutine. i """ - await self.client.disconnect() - self.client = None + if self.connected: + await self.client.disconnect() + self.client = None + self.connected = False + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.disconnect() diff --git a/src/socketio/simple_client.py b/src/socketio/simple_client.py index 9a58cba1..4a883806 100644 --- a/src/socketio/simple_client.py +++ b/src/socketio/simple_client.py @@ -57,21 +57,21 @@ def connect(self, url, headers={}, auth=None, transports=None, self.input_event.clear() self.client = Client(*self.client_args, **self.client_kwargs) - @self.client.event + @self.client.event(namespace=self.namespace) def connect(): # pragma: no cover self.connected = True self.connected_event.set() - @self.client.event + @self.client.event(namespace=self.namespace) def disconnect(): # pragma: no cover self.connected_event.clear() - @self.client.event + @self.client.event(namespace=self.namespace) def __disconnect_final(): # pragma: no cover self.connected = False self.connected_event.set() - @self.client.on('*') + @self.client.on('*', namespace=self.namespace) def on_event(event, *args): # pragma: no cover self.input_buffer.append([event, *args]) self.input_event.set() @@ -162,8 +162,10 @@ def receive(self, timeout=None): the server included arguments with the event, they are returned as additional list elements. """ - if not self.input_buffer: - self.connected_event.wait() + while not self.input_buffer: + if not self.connected_event.wait( + timeout=timeout): # pragma: no cover + raise TimeoutError() if not self.connected: raise DisconnectedError() if not self.input_event.wait(timeout=timeout): @@ -173,5 +175,13 @@ def receive(self, timeout=None): def disconnect(self): """Disconnect from the server.""" - self.client.disconnect() - self.client = None + if self.connected: + self.client.disconnect() + self.client = None + self.connected = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disconnect() diff --git a/tests/asyncio/test_asyncio_simple_client.py b/tests/asyncio/test_asyncio_simple_client.py index 9188c4fe..6a935978 100644 --- a/tests/asyncio/test_asyncio_simple_client.py +++ b/tests/asyncio/test_asyncio_simple_client.py @@ -34,6 +34,28 @@ def test_connect(self): assert client.namespace == 'n' assert not client.input_event.is_set() + def test_connect_context_manager(self): + async def _t(): + async with AsyncSimpleClient(123, a='b') as client: + with mock.patch('socketio.asyncio_simple_client.AsyncClient') \ + as mock_client: + mock_client.return_value.connect = AsyncMock() + + await client.connect('url', headers='h', auth='a', + transports='t', namespace='n', + socketio_path='s') + mock_client.assert_called_once_with(123, a='b') + assert client.client == mock_client() + mock_client().connect.mock.assert_called_once_with( + 'url', headers='h', auth='a', transports='t', + namespaces=['n'], socketio_path='s') + mock_client().event.call_count == 3 + mock_client().on.called_once_with('*') + assert client.namespace == 'n' + assert not client.input_event.is_set() + + _run(_t()) + def test_connect_twice(self): client = AsyncSimpleClient(123, a='b') client.client = mock.MagicMock() @@ -158,6 +180,8 @@ def test_disconnect(self): mc = mock.MagicMock() mc.disconnect = AsyncMock() client.client = mc + client.connected = True + _run(client.disconnect()) _run(client.disconnect()) mc.disconnect.mock.assert_called_once_with() assert client.client is None diff --git a/tests/common/test_simple_client.py b/tests/common/test_simple_client.py index f445ff85..2a0b7b7d 100644 --- a/tests/common/test_simple_client.py +++ b/tests/common/test_simple_client.py @@ -29,6 +29,21 @@ def test_connect(self): assert client.namespace == 'n' assert not client.input_event.is_set() + def test_connect_context_manager(self): + with SimpleClient(123, a='b') as client: + with mock.patch('socketio.simple_client.Client') as mock_client: + client.connect('url', headers='h', auth='a', transports='t', + namespace='n', socketio_path='s') + mock_client.assert_called_once_with(123, a='b') + assert client.client == mock_client() + mock_client().connect.assert_called_once_with( + 'url', headers='h', auth='a', transports='t', + namespaces=['n'], socketio_path='s') + mock_client().event.call_count == 3 + mock_client().on.called_once_with('*') + assert client.namespace == 'n' + assert not client.input_event.is_set() + def test_connect_twice(self): client = SimpleClient(123, a='b') client.client = mock.MagicMock() @@ -141,6 +156,8 @@ def test_disconnect(self): client = SimpleClient() mc = mock.MagicMock() client.client = mc + client.connected = True + client.disconnect() client.disconnect() mc.disconnect.assert_called_once_with() assert client.client is None