Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions pymodbus/client/asynchronous/async_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,8 @@ def __init__(self, protocol_class=None, loop=None, framer=None):
self.framer = framer
ReconnectingAsyncioModbusTcpClient.__init__(self, protocol_class, loop)

async def start(self, host, port=802, sslctx=None,
server_hostname=None, certfile=None, keyfile=None,
password=None, **kwargs):
async def start(self, host='localhost', port=802, sslctx=None,
certfile=None, keyfile=None, password=None, **kwargs):
"""
Initiates connection to start client
:param host: The host to connect to (default localhost)
Expand All @@ -463,7 +462,6 @@ async def start(self, host, port=802, sslctx=None,
:param password: The password for for decrypting client's private key file
"""
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password)
self.server_hostname = server_hostname
return await ReconnectingAsyncioModbusTcpClient.start(self, host, port)

async def _connect(self):
Expand Down Expand Up @@ -840,8 +838,8 @@ async def init_tcp_client(proto_cls, loop, host, port, **kwargs):


async def init_tls_client(proto_cls, loop, host, port, sslctx=None,
server_hostname=None, certfile=None, keyfile=None,
password=None, framer=None, **kwargs):
certfile=None, keyfile=None, password=None,
framer=None, **kwargs):
"""
Helper function to initialize tcp client
:param proto_cls:
Expand All @@ -858,9 +856,7 @@ async def init_tls_client(proto_cls, loop, host, port, sslctx=None,
"""
client = ReconnectingAsyncioModbusTlsClient(protocol_class=proto_cls,
loop=loop, framer=framer)
await client.start(host, port, sslctx, server_hostname=server_hostname,
certfile=certfile, keyfile=keyfile, password=password,
**kwargs)
await client.start(host, port, sslctx, certfile, keyfile, password)
return client


Expand Down
4 changes: 3 additions & 1 deletion pymodbus/client/asynchronous/factory/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None,
framer=framer)
client = loop.run_until_complete(asyncio.gather(cor))[0]
elif loop is asyncio.get_event_loop():
return loop, init_tls_client(proto_cls, loop, host, port)
return loop, init_tls_client(proto_cls, loop, host, port,
sslctx, certfile, keyfile, password,
framer)
else:
cor = init_tls_client(proto_cls, loop, host, port,
sslctx, certfile, keyfile, password, framer)
Expand Down
40 changes: 18 additions & 22 deletions pymodbus/server/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,10 @@ class ModbusTlsServer(ModbusTcpServer):
server context instance.
"""

def __init__(self, context, framer=None, identity=None,
address=None, handler=None, allow_reuse_address=False,
sslctx=None, certfile=None, keyfile=None, **kwargs):
def __init__(self, context, framer=None, identity=None, address=None,
sslctx=None, certfile=None, keyfile=None, password=None,
reqclicert=False, handler=None, allow_reuse_address=False,
**kwargs):
""" Overloaded initializer for the ModbusTcpServer

If the identify structure is not passed in, the ModbusControlBlock
Expand All @@ -388,32 +389,24 @@ def __init__(self, context, framer=None, identity=None,
:param framer: The framer strategy to use
:param identity: An optional identify structure
:param address: An optional (interface, port) to bind to.
:param handler: A handler for each client session; default is
ModbusConnectedRequestHandler
:param allow_reuse_address: Whether the server will allow the
reuse of an address.
:param sslctx: The SSLContext to use for TLS (default None and auto
create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param password: The password for for decrypting the private key file
:param reqclicert: Force the sever request client's certificate
:param handler: A handler for each client session; default is
ModbusConnectedRequestHandler
:param allow_reuse_address: Whether the server will allow the
reuse of an address.
:param ignore_missing_slaves: True to not send errors on a request
to a missing slave
:param broadcast_enable: True to treat unit_id 0 as broadcast address,
False to treat 0 as any other unit_id
"""
framer = framer or ModbusTlsFramer
self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
self.sslctx.check_hostname = False
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password,
reqclicert)

ModbusTcpServer.__init__(self, context, framer, identity, address,
handler, allow_reuse_address, **kwargs)
Expand Down Expand Up @@ -627,7 +620,8 @@ def StartTcpServer(context=None, identity=None, address=None,


def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
certfile=None, keyfile=None, custom_functions=[], **kwargs):
certfile=None, keyfile=None, password=None, reqclicert=False,
custom_functions=[], **kwargs):
""" A factory to start and run a tls modbus server

:param context: The ModbusServerContext datastore
Expand All @@ -636,14 +630,16 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param password: The password for for decrypting the private key file
:param reqclicert: Force the sever request client's certificate
:param custom_functions: An optional list of custom function classes
supported by server instance.
:param ignore_missing_slaves: True to not send errors on a request to a
missing slave
"""
framer = kwargs.pop("framer", ModbusTlsFramer)
server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx,
certfile=certfile, keyfile=keyfile, **kwargs)
server = ModbusTlsServer(context, framer, identity, address, sslctx,
certfile, keyfile, password, reqclicert, **kwargs)

for f in custom_functions:
server.decoder.register(f)
Expand Down
22 changes: 15 additions & 7 deletions test/test_server_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pymodbus.server.sync import ModbusDisconnectedRequestHandler
from pymodbus.server.sync import ModbusTcpServer, ModbusTlsServer, ModbusUdpServer, ModbusSerialServer
from pymodbus.server.sync import StartTcpServer, StartTlsServer, StartUdpServer, StartSerialServer
from pymodbus.server.tls_helper import sslctx_provider
from pymodbus.exceptions import NotImplementedException
from pymodbus.bit_read_message import ReadCoilsRequest, ReadCoilsResponse
from pymodbus.datastore import ModbusServerContext
Expand Down Expand Up @@ -278,23 +279,30 @@ def testTcpServerProcess(self):
#-----------------------------------------------------------------------#
# Test TLS Server
#-----------------------------------------------------------------------#
def testTlsSSLCTX_Provider(self):
''' test that sslctx_provider() produce SSLContext correctly '''
with patch.object(ssl.SSLContext, 'load_cert_chain'):
sslctx = sslctx_provider(reqclicert=True)
self.assertIsNotNone(sslctx)
self.assertEqual(type(sslctx), ssl.SSLContext)
self.assertEqual(sslctx.verify_mode, ssl.CERT_REQUIRED)

sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
sslctx_new = sslctx_provider(sslctx=sslctx_old)
self.assertEqual(sslctx_new, sslctx_old)

def testTlsServerInit(self):
''' test that the synchronous TLS server initial correctly '''
with patch.object(socketserver.TCPServer, 'server_activate'):
with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method:
identity = ModbusDeviceIdentification(info={0x00: 'VendorName'})
server = ModbusTlsServer(context=None, identity=identity,
reqclicert=True,
bind_and_activate=False)
self.assertIs(server.framer, ModbusTlsFramer)
server.server_activate()
self.assertIsNotNone(server.sslctx)
self.assertEqual(type(server.socket), ssl.SSLSocket)
server.server_close()
sslctx = ssl.create_default_context()
server = ModbusTlsServer(context=None, identity=identity,
sslctx=sslctx, bind_and_activate=False)
server.server_activate()
self.assertEqual(server.sslctx, sslctx)
self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED)
self.assertEqual(type(server.socket), ssl.SSLSocket)
server.server_close()

Expand Down