Skip to content

Commit

Permalink
[Python] Add locking to prevent concurrent access with asyncio
Browse files Browse the repository at this point in the history
Make sure that different asyncio tasks do not run the same function
concurrently. This is done by adding an asyncio lock to functions
which use callbacks.
  • Loading branch information
agners committed Jun 18, 2024
1 parent 316a9bb commit 4925772
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions src/controller/python/chip/ChipDeviceCtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,33 +227,36 @@ def wrapper(*args, **kwargs):


class CallbackContext:
def __init__(self) -> None:
def __init__(self, lock: asyncio.Lock) -> None:
self._lock = lock
self._future = None

def __enter__(self):
async def __aenter__(self):
await self._lock.acquire()
self._future = concurrent.futures.Future()
return self

@property
def future(self) -> concurrent.futures.Future | None:
def future(self) -> typing.Optional[concurrent.futures.Future]:
return self._future

def __exit__(self, exc_type, exc_value, traceback):
async def __aexit__(self, exc_type, exc_value, traceback):
self._future = None
self._lock.release()


class CommissioningContext(CallbackContext):
def __init__(self, devCtrl: ChipDeviceController) -> None:
super().__init__()
def __init__(self, devCtrl: ChipDeviceController, lock: asyncio.Lock) -> None:
super().__init__(lock)
self._devCtrl = devCtrl

def __enter__(self):
super().__enter__()
async def __aenter__(self):
await super().__aenter__()
self._devCtrl._fabricCheckNodeId = -1
return self

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
async def __aexit__(self, exc_type, exc_value, traceback):
await super().__aexit__(exc_type, exc_value, traceback)


class CommissionableNode(discovery.CommissionableNode):
Expand Down Expand Up @@ -372,10 +375,11 @@ def __init__(self, name: str = ''):

self._Cluster = ChipClusters(builtins.chipStack)
self._Cluster.InitLib(self._dmLib)
self._commissioning_context: CommissioningContext = CommissioningContext(self)
self._open_window_context: CallbackContext = CallbackContext()
self._unpair_device_context: CallbackContext = CallbackContext()
self._pase_establishment_context: CallbackContext = CallbackContext()
self._commissioning_lock: asyncio.Lock = asyncio.Lock()
self._commissioning_context: CommissioningContext = CommissioningContext(self, self._commissioning_lock)
self._open_window_context: CallbackContext = CallbackContext(asyncio.Lock())
self._unpair_device_context: CallbackContext = CallbackContext(asyncio.Lock())
self._pase_establishment_context: CallbackContext = CallbackContext(self._commissioning_lock)

def _set_dev_ctrl(self, devCtrl, pairingDelegate):
def HandleCommissioningComplete(nodeId: int, err: PyChipError):
Expand Down Expand Up @@ -579,7 +583,7 @@ async def ConnectBLE(self, discriminator: int, setupPinCode: int, nodeid: int, i
self.CheckIsActive()

self._enablePairingCompleteCallback(True)
with self._commissioning_context as ctx:
async with self._commissioning_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_ConnectBLE(
self.devCtrl, discriminator, isShortDiscriminator, setupPinCode, nodeid)
Expand All @@ -591,7 +595,7 @@ async def ConnectBLE(self, discriminator: int, setupPinCode: int, nodeid: int, i
async def UnpairDevice(self, nodeid: int) -> None:
self.CheckIsActive()

with self._unpair_device_context as ctx:
async with self._unpair_device_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_UnpairDevice(
self.devCtrl, nodeid, self.cbHandleDeviceUnpairCompleteFunct)
Expand Down Expand Up @@ -632,7 +636,7 @@ def CloseSession(self, nodeid):
async def _establishPASESession(self, callFunct):
self.CheckIsActive()

with self._pase_establishment_context as ctx:
async with self._pase_establishment_context as ctx:
res = await self._ChipStack.CallAsync(callFunct)
res.raise_on_error()
await asyncio.futures.wrap_future(ctx.future)
Expand Down Expand Up @@ -795,7 +799,7 @@ async def OpenCommissioningWindow(self, nodeid: int, timeout: int, iteration: in
'''
self.CheckIsActive()

with self._open_window_context as ctx:
async with self._open_window_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_OpenCommissioningWindow(
self.devCtrl, self.pairingDelegate, nodeid, timeout, iteration, discriminator, option)
Expand Down Expand Up @@ -1814,7 +1818,7 @@ def __init__(self, opCredsContext: ctypes.c_void_p, fabricId: int, nodeId: int,
f"caIndex({fabricAdmin.caIndex:x})/fabricId(0x{fabricId:016X})/nodeId(0x{nodeId:016X})"
)

self._issue_node_chain_context: CallbackContext = CallbackContext()
self._issue_node_chain_context: CallbackContext = CallbackContext(asyncio.Lock())
self._dmLib.pychip_DeviceController_SetIssueNOCChainCallbackPythonCallback(_IssueNOCChainCallbackPythonCallback)

pairingDelegate = c_void_p(None)
Expand Down Expand Up @@ -1869,7 +1873,7 @@ async def Commission(self, nodeid) -> int:
self.CheckIsActive()

self._enablePairingCompleteCallback(False)
with self._commissioning_context as ctx:
async with self._commissioning_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_Commission(
self.devCtrl, nodeid)
Expand Down Expand Up @@ -2017,7 +2021,7 @@ async def CommissionOnNetwork(self, nodeId: int, setupPinCode: int,
filter = str(filter)

self._enablePairingCompleteCallback(True)
with self._commissioning_context as ctx:
async with self._commissioning_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_OnNetworkCommission(
self.devCtrl, self.pairingDelegate, nodeId, setupPinCode, int(filterType), str(filter).encode("utf-8") if filter is not None else None, discoveryTimeoutMsec)
Expand All @@ -2038,7 +2042,7 @@ async def CommissionWithCode(self, setupPayload: str, nodeid: int, discoveryType
self.CheckIsActive()

self._enablePairingCompleteCallback(True)
with self._commissioning_context as ctx:
async with self._commissioning_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_ConnectWithCode(
self.devCtrl, setupPayload.encode("utf-8"), nodeid, discoveryType.value)
Expand All @@ -2058,7 +2062,7 @@ async def CommissionIP(self, ipaddr: str, setupPinCode: int, nodeid: int) -> int
self.CheckIsActive()

self._enablePairingCompleteCallback(True)
with self._commissioning_context as ctx:
async with self._commissioning_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_ConnectIP(
self.devCtrl, ipaddr.encode("utf-8"), setupPinCode, nodeid)
Expand All @@ -2079,7 +2083,7 @@ async def IssueNOCChain(self, csr: Clusters.OperationalCredentials.Commands.CSRR
The NOC chain will be provided in TLV cert format."""
self.CheckIsActive()

with self._issue_node_chain_context as ctx:
async with self._issue_node_chain_context as ctx:
res = await self._ChipStack.CallAsync(
lambda: self._dmLib.pychip_DeviceController_IssueNOCChain(
self.devCtrl, py_object(self), csr.NOCSRElements, len(csr.NOCSRElements), nodeId)
Expand Down

0 comments on commit 4925772

Please sign in to comment.