Skip to content

Commit 1d86e19

Browse files
authored
Fix encryption disconnect race (#355)
1 parent 92da6db commit 1d86e19

File tree

2 files changed

+474
-60
lines changed

2 files changed

+474
-60
lines changed

switchbot/devices/device.py

Lines changed: 107 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,54 @@ def _commandkey(self, key: str) -> str:
203203
key_suffix = key[4:]
204204
return KEY_PASSWORD_PREFIX + key_action + self._password_encoded + key_suffix
205205

206+
async def _send_command_locked_with_retry(
207+
self, key: str, command: bytes, retry: int, max_attempts: int
208+
) -> bytes | None:
209+
for attempt in range(max_attempts):
210+
try:
211+
return await self._send_command_locked(key, command)
212+
except BleakNotFoundError:
213+
_LOGGER.error(
214+
"%s: device not found, no longer in range, or poor RSSI: %s",
215+
self.name,
216+
self.rssi,
217+
exc_info=True,
218+
)
219+
raise
220+
except CharacteristicMissingError as ex:
221+
if attempt == retry:
222+
_LOGGER.error(
223+
"%s: characteristic missing: %s; Stopping trying; RSSI: %s",
224+
self.name,
225+
ex,
226+
self.rssi,
227+
exc_info=True,
228+
)
229+
raise
230+
231+
_LOGGER.debug(
232+
"%s: characteristic missing: %s; RSSI: %s",
233+
self.name,
234+
ex,
235+
self.rssi,
236+
exc_info=True,
237+
)
238+
except BLEAK_RETRY_EXCEPTIONS:
239+
if attempt == retry:
240+
_LOGGER.error(
241+
"%s: communication failed; Stopping trying; RSSI: %s",
242+
self.name,
243+
self.rssi,
244+
exc_info=True,
245+
)
246+
raise
247+
248+
_LOGGER.debug(
249+
"%s: communication failed with:", self.name, exc_info=True
250+
)
251+
252+
raise RuntimeError("Unreachable")
253+
206254
async def _send_command(self, key: str, retry: int | None = None) -> bytes | None:
207255
"""Send command to device and read response."""
208256
if retry is None:
@@ -217,50 +265,9 @@ async def _send_command(self, key: str, retry: int | None = None) -> bytes | Non
217265
self.rssi,
218266
)
219267
async with self._operation_lock:
220-
for attempt in range(max_attempts):
221-
try:
222-
return await self._send_command_locked(key, command)
223-
except BleakNotFoundError:
224-
_LOGGER.error(
225-
"%s: device not found, no longer in range, or poor RSSI: %s",
226-
self.name,
227-
self.rssi,
228-
exc_info=True,
229-
)
230-
raise
231-
except CharacteristicMissingError as ex:
232-
if attempt == retry:
233-
_LOGGER.error(
234-
"%s: characteristic missing: %s; Stopping trying; RSSI: %s",
235-
self.name,
236-
ex,
237-
self.rssi,
238-
exc_info=True,
239-
)
240-
raise
241-
242-
_LOGGER.debug(
243-
"%s: characteristic missing: %s; RSSI: %s",
244-
self.name,
245-
ex,
246-
self.rssi,
247-
exc_info=True,
248-
)
249-
except BLEAK_RETRY_EXCEPTIONS:
250-
if attempt == retry:
251-
_LOGGER.error(
252-
"%s: communication failed; Stopping trying; RSSI: %s",
253-
self.name,
254-
self.rssi,
255-
exc_info=True,
256-
)
257-
raise
258-
259-
_LOGGER.debug(
260-
"%s: communication failed with:", self.name, exc_info=True
261-
)
262-
263-
raise RuntimeError("Unreachable")
268+
return await self._send_command_locked_with_retry(
269+
key, command, retry, max_attempts
270+
)
264271

265272
@property
266273
def name(self) -> str:
@@ -832,37 +839,73 @@ async def _send_command(
832839
if not encrypt:
833840
return await super()._send_command(key[:2] + "000000" + key[2:], retry)
834841

835-
result = await self._ensure_encryption_initialized()
836-
if not result:
837-
_LOGGER.error("Failed to initialize encryption")
838-
return None
842+
if retry is None:
843+
retry = self._retry_count
839844

840-
encrypted = (
841-
key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:])
842-
)
843-
result = await super()._send_command(encrypted, retry)
844-
return result[:1] + self._decrypt(result[4:])
845+
if self._operation_lock.locked():
846+
_LOGGER.debug(
847+
"%s: Operation already in progress, waiting for it to complete; RSSI: %s",
848+
self.name,
849+
self.rssi,
850+
)
851+
852+
async with self._operation_lock:
853+
if not (result := await self._ensure_encryption_initialized()):
854+
_LOGGER.error("Failed to initialize encryption")
855+
return None
856+
857+
encrypted = (
858+
key[:2] + self._key_id + self._iv[0:2].hex() + self._encrypt(key[2:])
859+
)
860+
command = bytearray.fromhex(self._commandkey(encrypted))
861+
_LOGGER.debug("%s: Scheduling command %s", self.name, command.hex())
862+
max_attempts = retry + 1
863+
864+
result = await self._send_command_locked_with_retry(
865+
encrypted, command, retry, max_attempts
866+
)
867+
if result is None:
868+
return None
869+
return result[:1] + self._decrypt(result[4:])
845870

846871
async def _ensure_encryption_initialized(self) -> bool:
872+
"""Ensure encryption is initialized, must be called with operation lock held."""
873+
assert self._operation_lock.locked(), "Operation lock must be held"
874+
847875
if self._iv is not None:
848876
return True
849877

850-
result = await self._send_command(
851-
COMMAND_GET_CK_IV + self._key_id, encrypt=False
878+
_LOGGER.debug("%s: Initializing encryption", self.name)
879+
# Call parent's _send_command_locked_with_retry directly since we already hold the lock
880+
key = COMMAND_GET_CK_IV + self._key_id
881+
command = bytearray.fromhex(self._commandkey(key[:2] + "000000" + key[2:]))
882+
883+
result = await self._send_command_locked_with_retry(
884+
key[:2] + "000000" + key[2:],
885+
command,
886+
self._retry_count,
887+
self._retry_count + 1,
852888
)
853-
ok = self._check_command_result(result, 0, {1})
854-
if ok:
889+
if result is None:
890+
return False
891+
892+
if ok := self._check_command_result(result, 0, {1}):
855893
self._iv = result[4:]
894+
self._cipher = None # Reset cipher when IV changes
895+
_LOGGER.debug("%s: Encryption initialized successfully", self.name)
856896

857897
return ok
858898

859899
async def _execute_disconnect(self) -> None:
860-
await super()._execute_disconnect()
861-
self._iv = None
862-
self._cipher = None
900+
async with self._connect_lock:
901+
self._iv = None
902+
self._cipher = None
903+
await self._execute_disconnect_with_lock()
863904

864905
def _get_cipher(self) -> Cipher:
865906
if self._cipher is None:
907+
if self._iv is None:
908+
raise RuntimeError("Cannot create cipher: IV is None")
866909
self._cipher = Cipher(
867910
algorithms.AES128(self._encryption_key), modes.CTR(self._iv)
868911
)
@@ -871,12 +914,16 @@ def _get_cipher(self) -> Cipher:
871914
def _encrypt(self, data: str) -> str:
872915
if len(data) == 0:
873916
return ""
917+
if self._iv is None:
918+
raise RuntimeError("Cannot encrypt: IV is None")
874919
encryptor = self._get_cipher().encryptor()
875920
return (encryptor.update(bytearray.fromhex(data)) + encryptor.finalize()).hex()
876921

877922
def _decrypt(self, data: bytearray) -> bytes:
878923
if len(data) == 0:
879924
return b""
925+
if self._iv is None:
926+
raise RuntimeError("Cannot decrypt: IV is None")
880927
decryptor = self._get_cipher().decryptor()
881928
return decryptor.update(data) + decryptor.finalize()
882929

0 commit comments

Comments
 (0)