Skip to content

Commit

Permalink
Remove cloud assist pipeline setup from cloud client (home-assistant#…
Browse files Browse the repository at this point in the history
  • Loading branch information
emontnemery authored Apr 26, 2023
1 parent 6b931b2 commit ed737f3
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 133 deletions.
41 changes: 23 additions & 18 deletions homeassistant/components/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,25 +238,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await prefs.async_initialize()

# Initialize Cloud
loaded = False

async def _discover_platforms():
"""Discover platforms."""
nonlocal loaded

# Prevent multiple discovery
if loaded:
return
loaded = True

await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, {}, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config)

websession = async_get_clientsession(hass)
client = CloudClient(
hass, prefs, websession, alexa_conf, google_conf, _discover_platforms
)
client = CloudClient(hass, prefs, websession, alexa_conf, google_conf)
cloud = hass.data[DOMAIN] = Cloud(client, **kwargs)
cloud.iot.register_on_connect(client.on_cloud_connected)

Expand Down Expand Up @@ -288,6 +271,27 @@ async def async_startup_repairs(_=None) -> None:
if subscription_info := await async_subscription_info(cloud):
async_manage_legacy_subscription_issue(hass, subscription_info)

loaded = False

async def _on_start():
"""Discover platforms."""
nonlocal loaded

# Prevent multiple discovery
if loaded:
return
loaded = True

stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
stt_info = {"platform_loaded": stt_platform_loaded}
tts_info = {"platform_loaded": tts_platform_loaded}

await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, stt_info, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config)
await asyncio.gather(stt_platform_loaded.wait(), tts_platform_loaded.wait())

async def _on_connect():
"""Handle cloud connect."""
async_dispatcher_send(
Expand All @@ -304,6 +308,7 @@ async def _on_initialized():
"""Update preferences."""
await prefs.async_update(remote_domain=cloud.remote.instance_domain)

cloud.register_on_start(_on_start)
cloud.iot.register_on_connect(_on_connect)
cloud.iot.register_on_disconnect(_on_disconnect)
cloud.register_on_initialized(_on_initialized)
Expand Down
35 changes: 1 addition & 34 deletions homeassistant/components/cloud/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine
from http import HTTPStatus
import logging
from pathlib import Path
Expand All @@ -11,13 +10,7 @@
import aiohttp
from hass_nabucasa.client import CloudClient as Interface

from homeassistant.components import (
assist_pipeline,
conversation,
google_assistant,
persistent_notification,
webhook,
)
from homeassistant.components import google_assistant, persistent_notification, webhook
from homeassistant.components.alexa import (
errors as alexa_errors,
smart_home as alexa_smart_home,
Expand All @@ -43,7 +36,6 @@ def __init__(
websession: aiohttp.ClientSession,
alexa_user_config: dict[str, Any],
google_user_config: dict[str, Any],
on_started_cb: Callable[[], Coroutine[Any, Any, None]],
) -> None:
"""Initialize client interface to Cloud."""
self._hass = hass
Expand All @@ -56,10 +48,6 @@ def __init__(
self._alexa_config_init_lock = asyncio.Lock()
self._google_config_init_lock = asyncio.Lock()
self._relayer_region: str | None = None
self._on_started_cb = on_started_cb
self.cloud_pipeline = self._cloud_assist_pipeline()
self.stt_platform_loaded = asyncio.Event()
self.tts_platform_loaded = asyncio.Event()

@property
def base_path(self) -> Path:
Expand Down Expand Up @@ -148,22 +136,6 @@ async def get_google_config(self) -> google_config.CloudGoogleConfig:

return self._google_config

def _cloud_assist_pipeline(self) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None."""
for pipeline in assist_pipeline.async_get_pipelines(self._hass):
if (
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
return None

async def create_cloud_assist_pipeline(self) -> None:
"""Create a cloud-enabled assist pipeline."""
await assist_pipeline.async_create_default_pipeline(self._hass, DOMAIN, DOMAIN)
self.cloud_pipeline = self._cloud_assist_pipeline()

async def on_cloud_connected(self) -> None:
"""When cloud is connected."""
is_new_user = await self.prefs.async_set_username(self.cloud.username)
Expand Down Expand Up @@ -211,11 +183,6 @@ async def enable_google(_):

async def cloud_started(self) -> None:
"""When cloud is started."""
await self._on_started_cb()
await asyncio.gather(
self.stt_platform_loaded.wait(),
self.tts_platform_loaded.wait(),
)

async def cloud_stopped(self) -> None:
"""When the cloud is stopped."""
Expand Down
25 changes: 19 additions & 6 deletions homeassistant/components/cloud/http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from hass_nabucasa.voice import MAP_VOICE
import voluptuous as vol

from homeassistant.components import websocket_api
from homeassistant.components import assist_pipeline, conversation, websocket_api
from homeassistant.components.alexa import (
entities as alexa_entities,
errors as alexa_errors,
Expand Down Expand Up @@ -182,15 +182,28 @@ class CloudLoginView(HomeAssistantView):
)
async def post(self, request, data):
"""Handle login request."""

def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None."""
for pipeline in assist_pipeline.async_get_pipelines(hass):
if (
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
return None

hass = request.app["hass"]
cloud = hass.data[DOMAIN]
await cloud.login(data["email"], data["password"])

if cloud.client.cloud_pipeline is None:
await cloud.client.create_cloud_assist_pipeline()
return self.json(
{"success": True, "cloud_pipeline": cloud.client.cloud_pipeline}
)
if (cloud_pipeline_id := cloud_assist_pipeline(hass)) is None:
if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
hass, DOMAIN, DOMAIN
):
cloud_pipeline_id = cloud_pipeline.id
return self.json({"success": True, "cloud_pipeline": cloud_pipeline_id})


class CloudLogoutView(HomeAssistantView):
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/cloud/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ async def async_get_engine(hass, config, discovery_info=None):
cloud: Cloud = hass.data[DOMAIN]

cloud_provider = CloudProvider(cloud)
cloud.client.stt_platform_loaded.set()
if discovery_info is not None:
discovery_info["platform_loaded"].set()
return cloud_provider


Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/cloud/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ async def async_get_engine(hass, config, discovery_info=None):
gender = config[ATTR_GENDER]

cloud_provider = CloudProvider(cloud, language, gender)
cloud.client.tts_platform_loaded.set()
if discovery_info is not None:
discovery_info["platform_loaded"].set()
return cloud_provider


Expand Down
47 changes: 4 additions & 43 deletions tests/components/cloud/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
from aiohttp import web
import pytest

from homeassistant.components.assist_pipeline import (
Pipeline,
async_get_pipeline,
async_get_pipelines,
)
from homeassistant.components.cloud import DOMAIN
from homeassistant.components.cloud.client import CloudClient
from homeassistant.components.cloud.const import (
Expand Down Expand Up @@ -303,39 +298,31 @@ async def test_google_config_should_2fa(
assert not gconf.should_2fa(state)


@patch(
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
return_value=[],
)
async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None:
async def test_set_username(hass: HomeAssistant) -> None:
"""Test we set username during login."""
prefs = MagicMock(
alexa_enabled=False,
google_enabled=False,
async_set_username=AsyncMock(return_value=None),
)
client = CloudClient(hass, prefs, None, {}, {}, AsyncMock())
client = CloudClient(hass, prefs, None, {}, {})
client.cloud = MagicMock(is_logged_in=True, username="mock-username")
await client.on_cloud_connected()

assert len(prefs.async_set_username.mock_calls) == 1
assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username"


@patch(
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
return_value=[],
)
async def test_login_recovers_bad_internet(
async_get_pipelines, hass: HomeAssistant, caplog: pytest.LogCaptureFixture
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test Alexa can recover bad auth."""
prefs = Mock(
alexa_enabled=True,
google_enabled=False,
async_set_username=AsyncMock(return_value=None),
)
client = CloudClient(hass, prefs, None, {}, {}, AsyncMock())
client = CloudClient(hass, prefs, None, {}, {})
client.cloud = Mock()
client._alexa_config = Mock(
async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError)
Expand Down Expand Up @@ -367,29 +354,3 @@ async def test_system_msg(hass: HomeAssistant) -> None:

assert response is None
assert cloud.client.relayer_region == "xx-earth-616"


async def test_create_cloud_assist_pipeline(
hass: HomeAssistant, mock_cloud_setup, mock_cloud_login
) -> None:
"""Test creating a cloud enabled assist pipeline."""
cloud_client: CloudClient = hass.data[DOMAIN].client
await cloud_client.cloud_started()
assert cloud_client.cloud_pipeline is None
assert len(async_get_pipelines(hass)) == 1

await cloud_client.create_cloud_assist_pipeline()
assert cloud_client.cloud_pipeline is not None
assert len(async_get_pipelines(hass)) == 2
assert async_get_pipeline(hass, cloud_client.cloud_pipeline) == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=cloud_client.cloud_pipeline,
language="en",
name="Home Assistant Cloud",
stt_engine="cloud",
stt_language="en-US",
tts_engine="cloud",
tts_language="en-US",
tts_voice="JennyNeural",
)
73 changes: 48 additions & 25 deletions tests/components/cloud/test_http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,45 +104,68 @@ async def test_google_actions_sync_fails(


async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
"""Test logging in."""
create_cloud_assist_pipeline_mock = AsyncMock()
hass.data["cloud"] = MagicMock(
login=AsyncMock(),
client=Mock(
cloud_pipeline="12345",
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock,
),
)
"""Test logging in when an assist pipeline is available."""
hass.data["cloud"] = MagicMock(login=AsyncMock())

req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_get_pipelines",
return_value=[
Mock(
conversation_engine="homeassistant",
id="12345",
stt_engine=DOMAIN,
tts_engine=DOMAIN,
)
],
), patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
) as create_pipeline_mock:
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)

assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": "12345"}
create_cloud_assist_pipeline_mock.assert_not_awaited()
create_pipeline_mock.assert_not_awaited()


async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) -> None:
"""Test logging in when no assist pipeline is available."""
create_cloud_assist_pipeline_mock = AsyncMock()
hass.data["cloud"] = MagicMock(
login=AsyncMock(),
client=Mock(
cloud_pipeline=None,
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock,
),
)
hass.data["cloud"] = MagicMock(login=AsyncMock())

req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
return_value=AsyncMock(id="12345"),
) as create_pipeline_mock:
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)

assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": "12345"}
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")


async def test_login_view_create_pipeline_fail(
hass: HomeAssistant, cloud_client
) -> None:
"""Test logging in when no assist pipeline is available."""
hass.data["cloud"] = MagicMock(login=AsyncMock())

with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
return_value=None,
) as create_pipeline_mock:
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)

assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": None}
create_cloud_assist_pipeline_mock.assert_awaited_once()
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")


async def test_login_view_random_exception(cloud_client) -> None:
Expand Down
Loading

0 comments on commit ed737f3

Please sign in to comment.