Skip to content

Commit

Permalink
Switch from attrs to dataclasses (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
OttoWinter authored Jun 29, 2021
1 parent 61cefdb commit 872c643
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 278 deletions.
70 changes: 18 additions & 52 deletions aioesphomeapi/client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import asyncio
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)

import attr
import zeroconf
from google.protobuf import message

Expand Down Expand Up @@ -79,6 +78,7 @@
CoverState,
DeviceInfo,
EntityInfo,
EntityState,
FanDirection,
FanInfo,
FanSpeed,
Expand All @@ -87,6 +87,7 @@
LegacyCoverCommand,
LightInfo,
LightState,
LogLevel,
NumberInfo,
NumberState,
SensorInfo,
Expand All @@ -96,13 +97,9 @@
TextSensorInfo,
TextSensorState,
UserService,
UserServiceArg,
UserServiceArgType,
)

if TYPE_CHECKING:
from aioesphomeapi.api_pb2 import LogLevel # type: ignore

_LOGGER = logging.getLogger(__name__)

ExecuteServiceDataType = Dict[
Expand Down Expand Up @@ -192,21 +189,13 @@ async def device_info(self) -> DeviceInfo:
resp = await self._connection.send_message_await_response(
DeviceInfoRequest(), DeviceInfoResponse
)
return DeviceInfo(
uses_password=resp.uses_password,
name=resp.name,
mac_address=resp.mac_address,
esphome_version=resp.esphome_version,
compilation_time=resp.compilation_time,
model=resp.model,
has_deep_sleep=resp.has_deep_sleep,
)
return DeviceInfo.from_pb(resp)

async def list_entities_services(
self,
) -> Tuple[List[EntityInfo], List[UserService]]:
self._check_authenticated()
response_types = {
response_types: Dict[Any, Optional[Type[EntityInfo]]] = {
ListEntitiesBinarySensorResponse: BinarySensorInfo,
ListEntitiesCoverResponse: CoverInfo,
ListEntitiesFanResponse: FanInfo,
Expand Down Expand Up @@ -234,39 +223,22 @@ def do_stop(msg: message.Message) -> bool:
services: List[UserService] = []
for msg in resp:
if isinstance(msg, ListEntitiesServicesResponse):
args = []
for arg in msg.args:
args.append(
UserServiceArg(
name=arg.name,
type_=arg.type,
)
)
services.append(
UserService(
name=msg.name,
key=msg.key,
args=args, # type: ignore
)
)
services.append(UserService.from_pb(msg))
continue
cls = None
for resp_type, cls in response_types.items():
if isinstance(msg, resp_type):
break
else:
continue
cls = cast(type, cls)
kwargs = {}
for key, _ in attr.fields_dict(cls).items():
kwargs[key] = getattr(msg, key)
entities.append(cls(**kwargs))
assert cls is not None
entities.append(cls.from_pb(msg))
return entities, services

async def subscribe_states(self, on_state: Callable[[Any], None]) -> None:
async def subscribe_states(self, on_state: Callable[[EntityState], None]) -> None:
self._check_authenticated()

response_types = {
response_types: Dict[Any, Type[EntityState]] = {
BinarySensorStateResponse: BinarySensorState,
CoverStateResponse: CoverState,
FanStateResponse: FanState,
Expand All @@ -284,7 +256,7 @@ def on_msg(msg: message.Message) -> None:
if isinstance(msg, CameraImageResponse):
data = image_stream.pop(msg.key, bytes()) + msg.data
if msg.done:
on_state(CameraState(key=msg.key, image=data))
on_state(CameraState.from_pb(msg))
else:
image_stream[msg.key] = data
return
Expand All @@ -295,11 +267,8 @@ def on_msg(msg: message.Message) -> None:
else:
return

kwargs = {}
# pylint: disable=undefined-loop-variable
for key, _ in attr.fields_dict(cls).items():
kwargs[key] = getattr(msg, key)
on_state(cls(**kwargs))
on_state(cls.from_pb(msg))

assert self._connection is not None
await self._connection.send_message_callback_response(
Expand All @@ -309,7 +278,7 @@ def on_msg(msg: message.Message) -> None:
async def subscribe_logs(
self,
on_log: Callable[[SubscribeLogsResponse], None],
log_level: Optional["LogLevel"] = None,
log_level: Optional[LogLevel] = None,
) -> None:
self._check_authenticated()

Expand All @@ -330,10 +299,7 @@ async def subscribe_service_calls(

def on_msg(msg: message.Message) -> None:
if isinstance(msg, HomeassistantServiceResponse):
kwargs = {}
for key, _ in attr.fields_dict(HomeassistantServiceCall).items():
kwargs[key] = getattr(msg, key)
on_service_call(HomeassistantServiceCall(**kwargs))
on_service_call(HomeassistantServiceCall.from_pb(msg))

assert self._connection is not None
await self._connection.send_message_callback_response(
Expand Down Expand Up @@ -571,12 +537,12 @@ async def execute_service(
UserServiceArgType.FLOAT_ARRAY: "float_array",
UserServiceArgType.STRING_ARRAY: "string_array",
}
# pylint: disable=redefined-outer-name
if arg_desc.type_ in map_array:
attr = getattr(arg, map_array[arg_desc.type_])
if arg_desc.type in map_array:
attr = getattr(arg, map_array[arg_desc.type])
attr.extend(val)
else:
setattr(arg, map_single[arg_desc.type_], val)
assert arg_desc.type in map_single
setattr(arg, map_single[arg_desc.type], val)

args.append(arg)
# pylint: disable=no-member
Expand Down
18 changes: 9 additions & 9 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import logging
import socket
import time
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, List, Optional, cast

import attr
import zeroconf
from google.protobuf import message

Expand All @@ -27,15 +27,15 @@
_LOGGER = logging.getLogger(__name__)


@attr.s
@dataclass
class ConnectionParams:
eventloop = attr.ib(type=asyncio.events.AbstractEventLoop)
address = attr.ib(type=str)
port = attr.ib(type=int)
password = attr.ib(type=Optional[str])
client_info = attr.ib(type=str)
keepalive = attr.ib(type=float)
zeroconf_instance = attr.ib(type=Optional[zeroconf.Zeroconf])
eventloop: asyncio.events.AbstractEventLoop
address: str
port: int
password: Optional[str]
client_info: str
keepalive: float
zeroconf_instance: Optional[zeroconf.Zeroconf]


class APIConnection:
Expand Down
Loading

0 comments on commit 872c643

Please sign in to comment.