Skip to content

Commit

Permalink
Properly handle unknown broadcast messages (pschmitt#231)
Browse files Browse the repository at this point in the history
* Properly handle unknown broadcast messages

* Fix typing for capabilities

* Add review suggestions
  • Loading branch information
Orhideous authored Feb 15, 2024
1 parent 6cdf5a8 commit b1d4364
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 61 deletions.
123 changes: 122 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ roomba-password = "roombapy.entry_points:password"
python = ">=3.10,<4.0"
paho-mqtt = ">=1.5.1,<3.0.0"
orjson = ">=3.9.13"
pydantic = "^2.6.1"

[tool.poetry.dev-dependencies]
pytest = "^8.0"
Expand Down
51 changes: 23 additions & 28 deletions roombapy/discovery.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
import socket
from typing import Optional

import orjson
from pydantic import ValidationError

from roombapy.roomba_info import RoombaInfo

Expand Down Expand Up @@ -51,26 +52,15 @@ def _get_response(self, ip=None):
self.log.debug(
"Received response: %s, address: %s", raw_response, addr
)
data = raw_response.decode()
if self._is_from_irobot(data):
return _decode_data(data)
response = _decode_data(raw_response)
if not response:
continue
else:
return response
except socket.timeout:
self.log.info("Socket timeout")
return None

def _is_from_irobot(self, data):
if data == self.roomba_message:
return False

json_response = orjson.loads(data)
if (
"Roomba" in json_response["hostname"]
or "iRobot" in json_response["hostname"]
):
return True

return False

def _broadcast_message(self, amount):
for i in range(amount):
self.server_socket.sendto(
Expand All @@ -89,17 +79,22 @@ def _start_server(self):
self.log.debug("Socket server started, port %s", self.udp_port)


def _decode_data(data):
json_response = orjson.loads(data)
return RoombaInfo(
hostname=json_response["hostname"],
robot_name=json_response["robotname"],
ip=json_response["ip"],
mac=json_response["mac"],
firmware=json_response["sw"],
sku=json_response["sku"],
capabilities=json_response["cap"],
)
def _decode_data(raw_response: bytes) -> Optional[RoombaInfo]:
try:
data = raw_response.decode()
except UnicodeDecodeError:
# Unknown ND response (routers, etc.)
return None

if data == RoombaDiscovery.roomba_message:
# Filter our own messages
return None

try:
return RoombaInfo.model_validate_json(data)
except ValidationError:
# Malformed json from robots
return None


def _get_socket():
Expand Down
62 changes: 30 additions & 32 deletions roombapy/roomba_info.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,38 @@
class RoombaInfo:
hostname = None
firmware = None
ip = None
mac = None
robot_name = None
sku = None
capabilities = None
blid = None
password = None
from functools import cached_property
from typing import Dict, Optional

def __init__(
self, hostname, robot_name, ip, mac, firmware, sku, capabilities
):
"""Create object with information about roomba."""
self.hostname = hostname
self.firmware = firmware
self.ip = ip
self.mac = mac
self.robot_name = robot_name
self.sku = sku
self.capabilities = capabilities
self.blid = hostname.split("-")[1]
from pydantic import BaseModel, Field, computed_field, field_validator

def __str__(self) -> str:
"""Nice output to console."""
return ", ".join(
[
"{key}={value}".format(key=key, value=self.__dict__.get(key))
for key in self.__dict__
]
)

class RoombaInfo(BaseModel):
hostname: str
firmware: str = Field(alias="sw")
ip: str
mac: str
robot_name: str = Field(alias="robotname")
sku: str
capabilities: Dict[str, int] = Field(alias="cap")
password: Optional[str] = None

@field_validator("hostname")
@classmethod
def hostname_validator(cls, value: str) -> str:
if "-" not in value:
raise ValueError(f"hostname does not contain a dash: {value}")
model_name, blid = value.split("-")
if blid == "":
raise ValueError(f"empty blid: {value}")
if model_name.lower() not in {"roomba", "irobot"}:
raise ValueError(f"unsupported model in hostname: {value}")
return value

@computed_field
@cached_property
def blid(self) -> str:
return self.hostname.split("-")[1]

def __hash__(self) -> int:
"""Hashcode."""
return hash(self.mac)

def __eq__(self, o: object) -> bool:
"""Equals."""
return isinstance(o, RoombaInfo) and self.mac == o.mac
Loading

0 comments on commit b1d4364

Please sign in to comment.