Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/server_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ def get_commandline():

# set defaults
comm_defaults = {
"tcp": ("socket", 5020),
"udp": ("socket", 5020),
"serial": ("rtu", "/dev/ptyp0"),
"tls": ("tls", 5020),
"tcp": ["socket", 5020],
"udp": ["socket", 5020],
"serial": ["rtu", "/dev/ptyp0"],
"tls": ["tls", 5020],
}
framers = {
"ascii": ModbusAsciiFramer,
Expand Down
8 changes: 4 additions & 4 deletions examples/server_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,10 @@ def get_commandline():

# set defaults
comm_defaults = {
"tcp": ("socket", 5020),
"udp": ("socket", 5020),
"serial": ("rtu", "/dev/ptyp0"),
"tls": ("tls", 5020),
"tcp": ["socket", 5020],
"udp": ["socket", 5020],
"serial": ["rtu", "/dev/ptyp0"],
"tls": ["tls", 5020],
}
framers = {
"ascii": ModbusAsciiFramer,
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def _connect(self):
_logger.debug("Connecting.")
try:
transport, protocol = await self.loop.create_connection(
self._create_protocol, self.params.host, self.params.port
self._create_protocol, host=self.params.host, port=self.params.port
)
return transport, protocol
except Exception as exc: # pylint: disable=broad-except
Expand Down
116 changes: 74 additions & 42 deletions pymodbus/server/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import traceback
import ssl
from time import sleep

import serial
from serial_asyncio import create_serial_connection
Expand Down Expand Up @@ -32,7 +33,8 @@
# --------------------------------------------------------------------------- #
# Allow access to server object, to e.g. make a shutdown
# --------------------------------------------------------------------------- #
ServerObject = None # pylint: disable=invalid-name
_server_stopped = None # pylint: disable=invalid-name
_server_stop = None # pylint: disable=invalid-name


def sslctx_provider(
Expand Down Expand Up @@ -551,11 +553,15 @@ async def serve_forever(self):
try:
await self.server.serve_forever()
except asyncio.exceptions.CancelledError:
pass
raise
except Exception as exc: # pylint: disable=broad-except
txt = f"Server unexpected exception {exc}"
_logger.error(txt)
else:
raise RuntimeError(
"Can't call serve_forever on an already running server object"
)
_logger.info("Server graceful shutdown.")

async def shutdown(self):
"""Shutdown server."""
Expand Down Expand Up @@ -892,6 +898,32 @@ async def serve_forever(self):
# Creation Factories
# --------------------------------------------------------------------------- #

async def _helper_run_server(server, custom_functions):
"""Help starting/stopping server."""
global _server_stopped, _server_stop # pylint: disable=global-statement,invalid-name

for func in custom_functions:
server.decoder.register(func)
_server_stopped = asyncio.Event()
_server_stop = asyncio.Event()
try:
server_task = asyncio.create_task(server.serve_forever())
except Exception as exc: # pylint: disable=broad-except
txt = f"Server caught exception: {exc}"
_logger.error(txt)
await _server_stop.wait()
await server.shutdown()
server_task.cancel()
owntask = asyncio.current_task()
for task in asyncio.all_tasks():
if task != owntask:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
_server_stopped.set()


async def StartAsyncTcpServer( # pylint: disable=invalid-name,dangerous-default-value
context=None,
Expand All @@ -914,17 +946,18 @@ async def StartAsyncTcpServer( # pylint: disable=invalid-name,dangerous-default
:param kwargs: The rest
:return: an initialized but inactive server object coroutine
"""
global ServerObject # pylint: disable=global-statement

framer = kwargs.pop("framer", ModbusSocketFramer)
ServerObject = ModbusTcpServer(context, framer, identity, address, **kwargs)

for func in custom_functions:
ServerObject.decoder.register(func) # pragma: no cover
server = ModbusTcpServer(
context,
framer,
identity,
address,
**kwargs
)

if defer_start:
return ServerObject
await ServerObject.serve_forever()
return server
await _helper_run_server(server, custom_functions)


async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default-value,too-many-arguments
Expand Down Expand Up @@ -963,10 +996,8 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default
:param kwargs: The rest
:return: an initialized but inactive server object coroutine
"""
global ServerObject # pylint: disable=global-statement

framer = kwargs.pop("framer", ModbusTlsFramer)
ServerObject = ModbusTlsServer(
server = ModbusTlsServer(
context,
framer,
identity,
Expand All @@ -980,13 +1011,9 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default
allow_reuse_port=allow_reuse_port,
**kwargs,
)

for func in custom_functions:
ServerObject.decoder.register(func) # pragma: no cover

if defer_start:
return ServerObject
await ServerObject.serve_forever()
return server
await _helper_run_server(server, custom_functions)


async def StartAsyncUdpServer( # pylint: disable=invalid-name,dangerous-default-value
Expand All @@ -1009,17 +1036,17 @@ async def StartAsyncUdpServer( # pylint: disable=invalid-name,dangerous-default
up without the ability to shut it off
:param kwargs:
"""
global ServerObject # pylint: disable=global-statement

framer = kwargs.pop("framer", ModbusSocketFramer)
ServerObject = ModbusUdpServer(context, framer, identity, address, **kwargs)

for func in custom_functions:
ServerObject.decoder.register(func) # pragma: no cover

server = ModbusUdpServer(
context,
framer,
identity,
address,
**kwargs
)
if defer_start:
return ServerObject
await ServerObject.serve_forever()
return server
await _helper_run_server(server, custom_functions)


async def StartAsyncSerialServer( # pylint: disable=invalid-name,dangerous-default-value
Expand All @@ -1040,17 +1067,17 @@ async def StartAsyncSerialServer( # pylint: disable=invalid-name,dangerous-defa
up without the ability to shut it off
:param kwargs: The rest
"""
global ServerObject # pylint: disable=global-statement

framer = kwargs.pop("framer", ModbusAsciiFramer)
ServerObject = ModbusSerialServer(context, framer, identity=identity, **kwargs)
for func in custom_functions:
ServerObject.decoder.register(func)

server = ModbusSerialServer(
context,
framer,
identity=identity,
**kwargs
)
if defer_start:
return ServerObject
await ServerObject.start()
await ServerObject.serve_forever()
return server
await server.start()
await _helper_run_server(server, custom_functions)


def StartSerialServer(**kwargs): # pylint: disable=invalid-name
Expand All @@ -1075,13 +1102,18 @@ def StartUdpServer(**kwargs): # pylint: disable=invalid-name

async def ServerAsyncStop(): # pylint: disable=invalid-name
"""Terminate server."""
global ServerObject # pylint: disable=global-statement,invalid-name
global _server_stopped, _server_stop # pylint: disable=invalid-name,global-variable-not-assigned

if ServerObject:
await ServerObject.shutdown()
ServerObject = None
_server_stop.set()
try:
await _server_stopped.wait()
except asyncio.exceptions.CancelledError:
pass


def ServerStop(): # pylint: disable=invalid-name
"""Terminate server."""
asyncio.run(ServerAsyncStop())
global _server_stopped, _server_stop # pylint: disable=invalid-name,global-variable-not-assigned

_server_stop.set()
sleep(10)
55 changes: 27 additions & 28 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from time import sleep
import logging

from unittest.mock import patch, MagicMock
import pytest
import pytest_asyncio

Expand All @@ -30,11 +29,11 @@
_logger.setLevel("DEBUG")

TEST_COMMS_FRAMER = [
("tcp", ModbusSocketFramer, 5020),
("tcp", ModbusRtuFramer, 5021),
("tls", ModbusTlsFramer, 5030),
("udp", ModbusSocketFramer, 5040),
("udp", ModbusRtuFramer, 5041),
("tcp", ModbusSocketFramer, 5021),
("tcp", ModbusRtuFramer, 5022),
("tls", ModbusTlsFramer, 5023),
("udp", ModbusSocketFramer, 5024),
("udp", ModbusRtuFramer, 5025),
("serial", ModbusRtuFramer, "dummy"),
("serial", ModbusAsciiFramer, "dummy"),
("serial", ModbusBinaryFramer, "dummy"),
Expand All @@ -52,34 +51,25 @@ class Commandline:
slaves = None


@pytest_asyncio.fixture(name="mock_libs")
def _helper_libs():
"""Patch ssl and pyserial-async libs."""
with patch('pymodbus.server.async_io.create_serial_connection') as mock_serial:
mock_serial.return_value = (MagicMock(), MagicMock())
yield True


@pytest_asyncio.fixture(name="mock_run_server")
async def _helper_server( # pylint: disable=unused-argument
mock_libs,
async def _helper_server(
test_comm,
test_framer,
test_port_offset,
test_port,
):
"""Run server."""
if test_comm in ("serial"):
yield
return
args = Commandline
args.comm = test_comm
args.framer = test_framer
args.port = test_port
args.port = test_port + test_port_offset
asyncio.create_task(run_async_server(args))
await asyncio.sleep(0.1)
yield True
yield
await ServerAsyncStop()
tasks = asyncio.all_tasks()
owntask = asyncio.current_task()
for i in [i for i in tasks if not (i.done() or i.cancelled() or i == owntask)]:
i.cancel()


async def run_client(
Expand All @@ -98,34 +88,43 @@ async def run_client(
await asyncio.sleep(0.1)


@pytest.mark.parametrize("test_port_offset", [10])
@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
async def test_exp_async_simple( # pylint: disable=unused-argument
test_comm,
test_framer,
test_port_offset,
test_port,
mock_run_server,
):
"""Run async client and server."""


@pytest.mark.parametrize("test_port_offset", [20])
@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
def test_exp_sync_simple( # pylint: disable=unused-argument
mock_libs,
def test_exp_sync_simple(
test_comm,
test_framer,
test_port_offset,
test_port,
):
"""Run sync client and server."""
if test_comm == "serial":
# missing mock of port
return
args = Commandline
args.comm = test_comm
args.port = test_port
args.port = test_port + test_port_offset
args.framer = test_framer
thread = Thread(target=run_sync_server, args=(args,))
thread.daemon = True
thread.start()
sleep(0.1)
sleep(1)
ServerStop()
_logger.error("jan igen")


@pytest.mark.parametrize("test_port_offset", [30])
@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
@pytest.mark.parametrize(
"test_type",
Expand All @@ -139,6 +138,7 @@ def test_exp_sync_simple( # pylint: disable=unused-argument
async def test_exp_async_framer( # pylint: disable=unused-argument
test_comm,
test_framer,
test_port_offset,
test_port,
mock_run_server,
test_type
Expand All @@ -147,11 +147,10 @@ async def test_exp_async_framer( # pylint: disable=unused-argument
if test_type == run_async_ext_calls and test_framer == ModbusRtuFramer: # pylint: disable=comparison-with-callable
return
if test_comm == "serial":
# mocking serial needs to pass data between send/receive
return

args = Commandline
args.framer = test_framer
args.comm = test_comm
args.port = test_port
args.port = test_port + test_port_offset
await run_client(test_comm, test_type, args=args)
Loading