Skip to content

Commit

Permalink
Add more type hints to helpers (home-assistant#20811)
Browse files Browse the repository at this point in the history
* Add type hints to helpers.aiohttp_client

* Add type hints to helpers.area_registry
  • Loading branch information
scop authored and balloob committed Feb 7, 2019
1 parent 16159cc commit d45f25c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 27 deletions.
48 changes: 32 additions & 16 deletions homeassistant/helpers/aiohttp_client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
"""Helper for aiohttp webclient stuff."""
import asyncio
import sys
from ssl import SSLContext # noqa: F401
from typing import Any, Awaitable, Optional, cast
from typing import Union # noqa: F401

import aiohttp
from aiohttp.hdrs import USER_AGENT, CONTENT_TYPE
from aiohttp import web
from aiohttp.web_exceptions import HTTPGatewayTimeout, HTTPBadGateway
import async_timeout

from homeassistant.core import callback
from homeassistant.core import callback, Event
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass
from homeassistant.util import ssl as ssl_util

Expand All @@ -23,7 +27,8 @@

@callback
@bind_hass
def async_get_clientsession(hass, verify_ssl=True):
def async_get_clientsession(hass: HomeAssistantType,
verify_ssl: bool = True) -> aiohttp.ClientSession:
"""Return default aiohttp ClientSession.
This method must be run in the event loop.
Expand All @@ -36,13 +41,15 @@ def async_get_clientsession(hass, verify_ssl=True):
if key not in hass.data:
hass.data[key] = async_create_clientsession(hass, verify_ssl)

return hass.data[key]
return cast(aiohttp.ClientSession, hass.data[key])


@callback
@bind_hass
def async_create_clientsession(hass, verify_ssl=True, auto_cleanup=True,
**kwargs):
def async_create_clientsession(hass: HomeAssistantType,
verify_ssl: bool = True,
auto_cleanup: bool = True,
**kwargs: Any) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies.
If auto_cleanup is False, you need to call detach() after the session
Expand All @@ -67,16 +74,18 @@ def async_create_clientsession(hass, verify_ssl=True, auto_cleanup=True,


@bind_hass
async def async_aiohttp_proxy_web(hass, request, web_coro,
buffer_size=102400, timeout=10):
async def async_aiohttp_proxy_web(
hass: HomeAssistantType, request: web.BaseRequest,
web_coro: Awaitable[aiohttp.ClientResponse], buffer_size: int = 102400,
timeout: int = 10) -> Optional[web.StreamResponse]:
"""Stream websession request to aiohttp web response."""
try:
with async_timeout.timeout(timeout, loop=hass.loop):
req = await web_coro

except asyncio.CancelledError:
# The user cancelled the request
return
return None

except asyncio.TimeoutError as err:
# Timeout trying to start the web request
Expand All @@ -98,8 +107,12 @@ async def async_aiohttp_proxy_web(hass, request, web_coro,


@bind_hass
async def async_aiohttp_proxy_stream(hass, request, stream, content_type,
buffer_size=102400, timeout=10):
async def async_aiohttp_proxy_stream(hass: HomeAssistantType,
request: web.BaseRequest,
stream: aiohttp.StreamReader,
content_type: str,
buffer_size: int = 102400,
timeout: int = 10) -> web.StreamResponse:
"""Stream a stream to aiohttp web response."""
response = web.StreamResponse()
response.content_type = content_type
Expand All @@ -122,13 +135,14 @@ async def async_aiohttp_proxy_stream(hass, request, stream, content_type,


@callback
def _async_register_clientsession_shutdown(hass, clientsession):
def _async_register_clientsession_shutdown(
hass: HomeAssistantType, clientsession: aiohttp.ClientSession) -> None:
"""Register ClientSession close on Home Assistant shutdown.
This method must be run in the event loop.
"""
@callback
def _async_close_websession(event):
def _async_close_websession(event: Event) -> None:
"""Close websession."""
clientsession.detach()

Expand All @@ -137,25 +151,27 @@ def _async_close_websession(event):


@callback
def _async_get_connector(hass, verify_ssl=True):
def _async_get_connector(hass: HomeAssistantType,
verify_ssl: bool = True) -> aiohttp.BaseConnector:
"""Return the connector pool for aiohttp.
This method must be run in the event loop.
"""
key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY

if key in hass.data:
return hass.data[key]
return cast(aiohttp.BaseConnector, hass.data[key])

if verify_ssl:
ssl_context = ssl_util.client_context()
ssl_context = \
ssl_util.client_context() # type: Union[bool, SSLContext]
else:
ssl_context = False

connector = aiohttp.TCPConnector(loop=hass.loop, ssl=ssl_context)
hass.data[key] = connector

async def _async_close_connector(event):
async def _async_close_connector(event: Event) -> None:
"""Close connector pool."""
await connector.close()

Expand Down
22 changes: 12 additions & 10 deletions homeassistant/helpers/area_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import logging
import uuid
from collections import OrderedDict
from typing import List, Optional
from typing import MutableMapping # noqa: F401
from typing import Iterable, Optional, cast

import attr

from homeassistant.core import callback
from homeassistant.loader import bind_hass
from .typing import HomeAssistantType

_LOGGER = logging.getLogger(__name__)

Expand All @@ -29,14 +31,14 @@ class AreaEntry:
class AreaRegistry:
"""Class to hold a registry of areas."""

def __init__(self, hass) -> None:
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the area registry."""
self.hass = hass
self.areas = None
self.areas = {} # type: MutableMapping[str, AreaEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)

@callback
def async_list_areas(self) -> List[AreaEntry]:
def async_list_areas(self) -> Iterable[AreaEntry]:
"""Get all areas."""
return self.areas.values()

Expand Down Expand Up @@ -81,18 +83,18 @@ def async_update(self, area_id: str, name: str) -> AreaEntry:
return new

@callback
def _async_is_registered(self, name) -> Optional[AreaEntry]:
def _async_is_registered(self, name: str) -> Optional[AreaEntry]:
"""Check if a name is currently registered."""
for area in self.areas.values():
if name == area.name:
return area
return False
return None

async def async_load(self) -> None:
"""Load the area registry."""
data = await self._store.async_load()

areas = OrderedDict()
areas = OrderedDict() # type: OrderedDict[str, AreaEntry]

if data is not None:
for area in data['areas']:
Expand Down Expand Up @@ -124,16 +126,16 @@ def _data_to_save(self) -> dict:


@bind_hass
async def async_get_registry(hass) -> AreaRegistry:
async def async_get_registry(hass: HomeAssistantType) -> AreaRegistry:
"""Return area registry instance."""
task = hass.data.get(DATA_REGISTRY)

if task is None:
async def _load_reg():
async def _load_reg() -> AreaRegistry:
registry = AreaRegistry(hass)
await registry.async_load()
return registry

task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())

return await task
return cast(AreaRegistry, await task)
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
deps =
-r{toxinidir}/requirements_test.txt
commands =
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,condition,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,aiohttp_client,area_registry,condition,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'

0 comments on commit d45f25c

Please sign in to comment.