Skip to content

Commit d01eed0

Browse files
committed
Fix pymodbus TLS module conflicts in 3.0.0
Developers add/fix features at the same time, then produce the conflicts in pymodbus' TLS module. This patch tries to fix the conflicts.
1 parent 9e37031 commit d01eed0

File tree

4 files changed

+41
-39
lines changed

4 files changed

+41
-39
lines changed

pymodbus/client/asynchronous/async_io/__init__.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,8 @@ def __init__(self, protocol_class=None, loop=None, framer=None):
450450
self.framer = framer
451451
ReconnectingAsyncioModbusTcpClient.__init__(self, protocol_class, loop)
452452

453-
async def start(self, host, port=802, sslctx=None,
454-
server_hostname=None, certfile=None, keyfile=None,
455-
password=None, **kwargs):
453+
async def start(self, host='localhost', port=802, sslctx=None,
454+
certfile=None, keyfile=None, password=None, **kwargs):
456455
"""
457456
Initiates connection to start client
458457
:param host: The host to connect to (default localhost)
@@ -463,7 +462,6 @@ async def start(self, host, port=802, sslctx=None,
463462
:param password: The password for for decrypting client's private key file
464463
"""
465464
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password)
466-
self.server_hostname = server_hostname
467465
return await ReconnectingAsyncioModbusTcpClient.start(self, host, port)
468466

469467
async def _connect(self):
@@ -840,8 +838,8 @@ async def init_tcp_client(proto_cls, loop, host, port, **kwargs):
840838

841839

842840
async def init_tls_client(proto_cls, loop, host, port, sslctx=None,
843-
server_hostname=None, certfile=None, keyfile=None,
844-
password=None, framer=None, **kwargs):
841+
certfile=None, keyfile=None, password=None,
842+
framer=None, **kwargs):
845843
"""
846844
Helper function to initialize tcp client
847845
:param proto_cls:
@@ -858,9 +856,7 @@ async def init_tls_client(proto_cls, loop, host, port, sslctx=None,
858856
"""
859857
client = ReconnectingAsyncioModbusTlsClient(protocol_class=proto_cls,
860858
loop=loop, framer=framer)
861-
await client.start(host, port, sslctx, server_hostname=server_hostname,
862-
certfile=certfile, keyfile=keyfile, password=password,
863-
**kwargs)
859+
await client.start(host, port, sslctx, certfile, keyfile, password)
864860
return client
865861

866862

pymodbus/client/asynchronous/factory/tls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None,
4848
framer=framer)
4949
client = loop.run_until_complete(asyncio.gather(cor))[0]
5050
elif loop is asyncio.get_event_loop():
51-
return loop, init_tls_client(proto_cls, loop, host, port)
51+
return loop, init_tls_client(proto_cls, loop, host, port,
52+
sslctx, certfile, keyfile, password,
53+
framer)
5254
else:
5355
cor = init_tls_client(proto_cls, loop, host, port,
5456
sslctx, certfile, keyfile, password, framer)

pymodbus/server/sync.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,10 @@ class ModbusTlsServer(ModbusTcpServer):
376376
server context instance.
377377
"""
378378

379-
def __init__(self, context, framer=None, identity=None,
380-
address=None, handler=None, allow_reuse_address=False,
381-
sslctx=None, certfile=None, keyfile=None, **kwargs):
379+
def __init__(self, context, framer=None, identity=None, address=None,
380+
sslctx=None, certfile=None, keyfile=None, password=None,
381+
reqclicert=False, handler=None, allow_reuse_address=False,
382+
**kwargs):
382383
""" Overloaded initializer for the ModbusTcpServer
383384
384385
If the identify structure is not passed in, the ModbusControlBlock
@@ -388,32 +389,24 @@ def __init__(self, context, framer=None, identity=None,
388389
:param framer: The framer strategy to use
389390
:param identity: An optional identify structure
390391
:param address: An optional (interface, port) to bind to.
391-
:param handler: A handler for each client session; default is
392-
ModbusConnectedRequestHandler
393-
:param allow_reuse_address: Whether the server will allow the
394-
reuse of an address.
395392
:param sslctx: The SSLContext to use for TLS (default None and auto
396393
create)
397394
:param certfile: The cert file path for TLS (used if sslctx is None)
398395
:param keyfile: The key file path for TLS (used if sslctx is None)
396+
:param password: The password for for decrypting the private key file
397+
:param reqclicert: Force the sever request client's certificate
398+
:param handler: A handler for each client session; default is
399+
ModbusConnectedRequestHandler
400+
:param allow_reuse_address: Whether the server will allow the
401+
reuse of an address.
399402
:param ignore_missing_slaves: True to not send errors on a request
400403
to a missing slave
401404
:param broadcast_enable: True to treat unit_id 0 as broadcast address,
402405
False to treat 0 as any other unit_id
403406
"""
404407
framer = framer or ModbusTlsFramer
405-
self.sslctx = sslctx
406-
if self.sslctx is None:
407-
self.sslctx = ssl.create_default_context()
408-
self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
409-
# According to MODBUS/TCP Security Protocol Specification, it is
410-
# TLSv2 at least
411-
self.sslctx.options |= ssl.OP_NO_TLSv1_1
412-
self.sslctx.options |= ssl.OP_NO_TLSv1
413-
self.sslctx.options |= ssl.OP_NO_SSLv3
414-
self.sslctx.options |= ssl.OP_NO_SSLv2
415-
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
416-
self.sslctx.check_hostname = False
408+
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password,
409+
reqclicert)
417410

418411
ModbusTcpServer.__init__(self, context, framer, identity, address,
419412
handler, allow_reuse_address, **kwargs)
@@ -627,7 +620,8 @@ def StartTcpServer(context=None, identity=None, address=None,
627620

628621

629622
def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
630-
certfile=None, keyfile=None, custom_functions=[], **kwargs):
623+
certfile=None, keyfile=None, password=None, reqclicert=False,
624+
custom_functions=[], **kwargs):
631625
""" A factory to start and run a tls modbus server
632626
633627
:param context: The ModbusServerContext datastore
@@ -636,14 +630,16 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
636630
:param sslctx: The SSLContext to use for TLS (default None and auto create)
637631
:param certfile: The cert file path for TLS (used if sslctx is None)
638632
:param keyfile: The key file path for TLS (used if sslctx is None)
633+
:param password: The password for for decrypting the private key file
634+
:param reqclicert: Force the sever request client's certificate
639635
:param custom_functions: An optional list of custom function classes
640636
supported by server instance.
641637
:param ignore_missing_slaves: True to not send errors on a request to a
642638
missing slave
643639
"""
644640
framer = kwargs.pop("framer", ModbusTlsFramer)
645-
server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx,
646-
certfile=certfile, keyfile=keyfile, **kwargs)
641+
server = ModbusTlsServer(context, framer, identity, address, sslctx,
642+
certfile, keyfile, password, reqclicert, **kwargs)
647643

648644
for f in custom_functions:
649645
server.decoder.register(f)

test/test_server_sync.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pymodbus.server.sync import ModbusDisconnectedRequestHandler
1717
from pymodbus.server.sync import ModbusTcpServer, ModbusTlsServer, ModbusUdpServer, ModbusSerialServer
1818
from pymodbus.server.sync import StartTcpServer, StartTlsServer, StartUdpServer, StartSerialServer
19+
from pymodbus.server.tls_helper import sslctx_provider
1920
from pymodbus.exceptions import NotImplementedException
2021
from pymodbus.bit_read_message import ReadCoilsRequest, ReadCoilsResponse
2122
from pymodbus.datastore import ModbusServerContext
@@ -278,23 +279,30 @@ def testTcpServerProcess(self):
278279
#-----------------------------------------------------------------------#
279280
# Test TLS Server
280281
#-----------------------------------------------------------------------#
282+
def testTlsSSLCTX_Provider(self):
283+
''' test that sslctx_provider() produce SSLContext correctly '''
284+
with patch.object(ssl.SSLContext, 'load_cert_chain'):
285+
sslctx = sslctx_provider(reqclicert=True)
286+
self.assertIsNotNone(sslctx)
287+
self.assertEqual(type(sslctx), ssl.SSLContext)
288+
self.assertEqual(sslctx.verify_mode, ssl.CERT_REQUIRED)
289+
290+
sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
291+
sslctx_new = sslctx_provider(sslctx=sslctx_old)
292+
self.assertEqual(sslctx_new, sslctx_old)
293+
281294
def testTlsServerInit(self):
282295
''' test that the synchronous TLS server initial correctly '''
283296
with patch.object(socketserver.TCPServer, 'server_activate'):
284297
with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method:
285298
identity = ModbusDeviceIdentification(info={0x00: 'VendorName'})
286299
server = ModbusTlsServer(context=None, identity=identity,
300+
reqclicert=True,
287301
bind_and_activate=False)
288302
self.assertIs(server.framer, ModbusTlsFramer)
289303
server.server_activate()
290304
self.assertIsNotNone(server.sslctx)
291-
self.assertEqual(type(server.socket), ssl.SSLSocket)
292-
server.server_close()
293-
sslctx = ssl.create_default_context()
294-
server = ModbusTlsServer(context=None, identity=identity,
295-
sslctx=sslctx, bind_and_activate=False)
296-
server.server_activate()
297-
self.assertEqual(server.sslctx, sslctx)
305+
self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED)
298306
self.assertEqual(type(server.socket), ssl.SSLSocket)
299307
server.server_close()
300308

0 commit comments

Comments
 (0)