diff --git a/homeassistant/auth/providers/trusted_networks.py b/homeassistant/auth/providers/trusted_networks.py index 8a7e1d67c6d25..d0bc45c326a1a 100644 --- a/homeassistant/auth/providers/trusted_networks.py +++ b/homeassistant/auth/providers/trusted_networks.py @@ -3,18 +3,23 @@ It shows list of users if access from trusted network. Abort login flow if not access from trusted network. """ -from typing import Any, Dict, Optional, cast +from ipaddress import ip_network, IPv4Address, IPv6Address, IPv4Network,\ + IPv6Network +from typing import Any, Dict, List, Optional, Union, cast import voluptuous as vol -from homeassistant.components.http import HomeAssistantHTTP # noqa: F401 +import homeassistant.helpers.config_validation as cv from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError - from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow from ..models import Credentials, UserMeta +IPAddress = Union[IPv4Address, IPv6Address] +IPNetwork = Union[IPv4Network, IPv6Network] + CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ + vol.Required('trusted_networks'): vol.All(cv.ensure_list, [ip_network]) }, extra=vol.PREVENT_EXTRA) @@ -35,6 +40,11 @@ class TrustedNetworksAuthProvider(AuthProvider): DEFAULT_TITLE = 'Trusted Networks' + @property + def trusted_networks(self) -> List[IPNetwork]: + """Return trusted networks.""" + return cast(List[IPNetwork], self.config['trusted_networks']) + @property def support_mfa(self) -> bool: """Trusted Networks auth provider does not support MFA.""" @@ -49,7 +59,7 @@ async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: if not user.system_generated and user.is_active} return TrustedNetworksLoginFlow( - self, cast(str, context.get('ip_address')), available_users) + self, cast(IPAddress, context.get('ip_address')), available_users) async def async_get_or_create_credentials( self, flow_result: Dict[str, str]) -> Credentials: @@ -80,19 +90,17 @@ async def async_user_meta_for_credentials( raise NotImplementedError @callback - def async_validate_access(self, ip_address: str) -> None: + def async_validate_access(self, ip_addr: IPAddress) -> None: """Make sure the access from trusted networks. Raise InvalidAuthError if not. Raise InvalidAuthError if trusted_networks is not configured. """ - hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP - - if not hass_http or not hass_http.trusted_networks: + if not self.trusted_networks: raise InvalidAuthError('trusted_networks is not configured') - if not any(ip_address in trusted_network for trusted_network - in hass_http.trusted_networks): + if not any(ip_addr in trusted_network for trusted_network + in self.trusted_networks): raise InvalidAuthError('Not in trusted_networks') @@ -100,12 +108,12 @@ class TrustedNetworksLoginFlow(LoginFlow): """Handler for the login flow.""" def __init__(self, auth_provider: TrustedNetworksAuthProvider, - ip_address: str, available_users: Dict[str, Optional[str]]) \ - -> None: + ip_addr: IPAddress, + available_users: Dict[str, Optional[str]]) -> None: """Initialize the login flow.""" super().__init__(auth_provider) self._available_users = available_users - self._ip_address = ip_address + self._ip_address = ip_addr async def async_step_init( self, user_input: Optional[Dict[str, str]] = None) \ diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index a018d5400338b..ca01610bcf949 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -86,13 +86,12 @@ async def async_from_config_dict(config: Dict[str, Any], log_no_color) core_config = config.get(core.DOMAIN, {}) - has_api_password = bool((config.get('http') or {}).get('api_password')) - has_trusted_networks = bool((config.get('http') or {}) - .get('trusted_networks')) + has_api_password = bool(config.get('http', {}).get('api_password')) + trusted_networks = config.get('http', {}).get('trusted_networks') try: await conf_util.async_process_ha_core_config( - hass, core_config, has_api_password, has_trusted_networks) + hass, core_config, has_api_password, trusted_networks) except vol.Invalid as config_err: conf_util.async_log_exception( config_err, 'homeassistant', core_config, hass) diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 7dca332058c1c..f57068081a56d 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -52,6 +52,17 @@ DEFAULT_DEVELOPMENT = '0' NO_LOGIN_ATTEMPT_THRESHOLD = -1 + +def trusted_networks_deprecated(value): + """Warn user trusted_networks config is deprecated.""" + _LOGGER.warning( + "Configuring trusted_networks via the http component has been" + " deprecated. Use the trusted networks auth provider instead." + " For instructions, see https://www.home-assistant.io/docs/" + "authentication/providers/#trusted-networks") + return value + + HTTP_SCHEMA = vol.Schema({ vol.Optional(CONF_API_PASSWORD): cv.string, vol.Optional(CONF_SERVER_HOST, default=DEFAULT_SERVER_HOST): cv.string, @@ -66,7 +77,7 @@ vol.Inclusive(CONF_TRUSTED_PROXIES, 'proxy'): vol.All(cv.ensure_list, [ip_network]), vol.Optional(CONF_TRUSTED_NETWORKS, default=[]): - vol.All(cv.ensure_list, [ip_network]), + vol.All(cv.ensure_list, [ip_network], trusted_networks_deprecated), vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD, default=NO_LOGIN_ATTEMPT_THRESHOLD): vol.Any(cv.positive_int, NO_LOGIN_ATTEMPT_THRESHOLD), diff --git a/homeassistant/config.py b/homeassistant/config.py index 3310cd3e160d9..492db240eeecd 100644 --- a/homeassistant/config.py +++ b/homeassistant/config.py @@ -429,7 +429,7 @@ def _format_config_error(ex: vol.Invalid, domain: str, config: Dict) -> str: async def async_process_ha_core_config( hass: HomeAssistant, config: Dict, has_api_password: bool = False, - has_trusted_networks: bool = False) -> None: + trusted_networks: Optional[Any] = None) -> None: """Process the [homeassistant] section from the configuration. This method is a coroutine. @@ -446,8 +446,11 @@ async def async_process_ha_core_config( ] if has_api_password: auth_conf.append({'type': 'legacy_api_password'}) - if has_trusted_networks: - auth_conf.append({'type': 'trusted_networks'}) + if trusted_networks: + auth_conf.append({ + 'type': 'trusted_networks', + 'trusted_networks': trusted_networks, + }) mfa_conf = config.get(CONF_AUTH_MFA_MODULES, [ {'type': 'totp', 'id': 'totp', 'name': 'Authenticator app'}, diff --git a/tests/auth/providers/test_trusted_networks.py b/tests/auth/providers/test_trusted_networks.py index 0ca302f827305..57e74e750d562 100644 --- a/tests/auth/providers/test_trusted_networks.py +++ b/tests/auth/providers/test_trusted_networks.py @@ -1,5 +1,5 @@ """Test the Trusted Networks auth provider.""" -from unittest.mock import Mock +from ipaddress import ip_address import pytest import voluptuous as vol @@ -18,9 +18,17 @@ def store(hass): @pytest.fixture def provider(hass, store): """Mock provider.""" - return tn_auth.TrustedNetworksAuthProvider(hass, store, { - 'type': 'trusted_networks' - }) + return tn_auth.TrustedNetworksAuthProvider( + hass, store, tn_auth.CONFIG_SCHEMA({ + 'type': 'trusted_networks', + 'trusted_networks': [ + '192.168.0.1', + '192.168.128.0/24', + '::1', + 'fd00::/8' + ] + }) + ) @pytest.fixture @@ -56,14 +64,17 @@ async def test_trusted_networks_credentials(manager, provider): async def test_validate_access(provider): """Test validate access from trusted networks.""" - with pytest.raises(tn_auth.InvalidAuthError): - provider.async_validate_access('192.168.0.1') - - provider.hass.http = Mock(trusted_networks=['192.168.0.1']) - provider.async_validate_access('192.168.0.1') + provider.async_validate_access(ip_address('192.168.0.1')) + provider.async_validate_access(ip_address('192.168.128.10')) + provider.async_validate_access(ip_address('::1')) + provider.async_validate_access(ip_address('fd01:db8::ff00:42:8329')) with pytest.raises(tn_auth.InvalidAuthError): - provider.async_validate_access('127.0.0.1') + provider.async_validate_access(ip_address('192.168.0.2')) + with pytest.raises(tn_auth.InvalidAuthError): + provider.async_validate_access(ip_address('127.0.0.1')) + with pytest.raises(tn_auth.InvalidAuthError): + provider.async_validate_access(ip_address('2001:db8::ff00:42:8329')) async def test_login_flow(manager, provider): @@ -71,22 +82,16 @@ async def test_login_flow(manager, provider): owner = await manager.async_create_user("test-owner") user = await manager.async_create_user("test-user") - # trusted network didn't loaded - flow = await provider.async_login_flow({'ip_address': '127.0.0.1'}) - step = await flow.async_step_init() - assert step['type'] == 'abort' - assert step['reason'] == 'not_whitelisted' - - provider.hass.http = Mock(trusted_networks=['192.168.0.1']) - # not from trusted network - flow = await provider.async_login_flow({'ip_address': '127.0.0.1'}) + flow = await provider.async_login_flow( + {'ip_address': ip_address('127.0.0.1')}) step = await flow.async_step_init() assert step['type'] == 'abort' assert step['reason'] == 'not_whitelisted' # from trusted network, list users - flow = await provider.async_login_flow({'ip_address': '192.168.0.1'}) + flow = await provider.async_login_flow( + {'ip_address': ip_address('192.168.0.1')}) step = await flow.async_step_init() assert step['step_id'] == 'init' diff --git a/tests/test_config.py b/tests/test_config.py index 212fc247eb9a8..e860ff53b3d6a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,6 +5,7 @@ import unittest import unittest.mock as mock from collections import OrderedDict +from ipaddress import ip_network import asynctest import pytest @@ -891,12 +892,14 @@ async def test_auth_provider_config_default_trusted_networks(hass): } if hasattr(hass, 'auth'): del hass.auth - await config_util.async_process_ha_core_config(hass, core_config, - has_trusted_networks=True) + await config_util.async_process_ha_core_config( + hass, core_config, trusted_networks=['192.168.0.1']) assert len(hass.auth.auth_providers) == 2 assert hass.auth.auth_providers[0].type == 'homeassistant' assert hass.auth.auth_providers[1].type == 'trusted_networks' + assert hass.auth.auth_providers[1].trusted_networks[0] == ip_network( + '192.168.0.1') async def test_disallowed_auth_provider_config(hass):