From 6718cce203fbfb2566bca1c5ee7c894cf727502b Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 3 Nov 2024 20:45:09 -0800 Subject: [PATCH] Fix nest streams broken due to CameraCapabilities change (#129711) * Fix nest streams broken due to CameraCapabilities change * Fix stream cleanup * Apply suggestions from code review Co-authored-by: Paulus Schoutsen * Update homeassistant/components/nest/camera.py --------- Co-authored-by: Paulus Schoutsen --- homeassistant/components/nest/camera.py | 230 +++++++++++---------- tests/components/nest/test_camera.py | 79 ++++--- tests/components/nest/test_media_source.py | 7 +- 3 files changed, 181 insertions(+), 135 deletions(-) diff --git a/homeassistant/components/nest/camera.py b/homeassistant/components/nest/camera.py index 737c0a77bede17..30f96f819c1999 100644 --- a/homeassistant/components/nest/camera.py +++ b/homeassistant/components/nest/camera.py @@ -2,19 +2,17 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio from collections.abc import Callable import datetime import functools import logging from pathlib import Path -from typing import cast from google_nest_sdm.camera_traits import ( - CameraImageTrait, CameraLiveStreamTrait, RtspStream, - Stream, StreamingProtocol, WebRtcStream, ) @@ -57,19 +55,25 @@ async def async_setup_entry( device_manager: DeviceManager = hass.data[DOMAIN][entry.entry_id][ DATA_DEVICE_MANAGER ] - async_add_entities( - NestCamera(device) - for device in device_manager.devices.values() - if CameraImageTrait.NAME in device.traits - or CameraLiveStreamTrait.NAME in device.traits - ) + entities: list[NestCameraBaseEntity] = [] + for device in device_manager.devices.values(): + if (live_stream := device.traits.get(CameraLiveStreamTrait.NAME)) is None: + continue + if StreamingProtocol.WEB_RTC in live_stream.supported_protocols: + entities.append(NestWebRTCEntity(device)) + elif StreamingProtocol.RTSP in live_stream.supported_protocols: + entities.append(NestRTSPEntity(device)) + async_add_entities(entities) -class NestCamera(Camera): + +class NestCameraBaseEntity(Camera, ABC): """Devices that support cameras.""" _attr_has_entity_name = True _attr_name = None + _attr_is_streaming = True + _attr_supported_features = CameraEntityFeature.STREAM def __init__(self, device: Device) -> None: """Initialize the camera.""" @@ -79,39 +83,74 @@ def __init__(self, device: Device) -> None: self._attr_device_info = nest_device_info.device_info self._attr_brand = nest_device_info.device_brand self._attr_model = nest_device_info.device_model - self._rtsp_stream: RtspStream | None = None - self._webrtc_sessions: dict[str, WebRtcStream] = {} - self._create_stream_url_lock = asyncio.Lock() - self._stream_refresh_unsub: Callable[[], None] | None = None - self._attr_is_streaming = False - self._attr_supported_features = CameraEntityFeature(0) - self._rtsp_live_stream_trait: CameraLiveStreamTrait | None = None - if CameraLiveStreamTrait.NAME in self._device.traits: - self._attr_is_streaming = True - self._attr_supported_features |= CameraEntityFeature.STREAM - trait = cast( - CameraLiveStreamTrait, self._device.traits[CameraLiveStreamTrait.NAME] - ) - if StreamingProtocol.RTSP in trait.supported_protocols: - self._rtsp_live_stream_trait = trait self.stream_options[CONF_EXTRA_PART_WAIT_TIME] = 3 # The API "name" field is a unique device identifier. self._attr_unique_id = f"{self._device.name}-camera" + self._stream_refresh_unsub: Callable[[], None] | None = None - @property - def use_stream_for_stills(self) -> bool: - """Whether or not to use stream to generate stills.""" - return self._rtsp_live_stream_trait is not None + @abstractmethod + def _stream_expires_at(self) -> datetime.datetime | None: + """Next time when a stream expires.""" + + @abstractmethod + async def _async_refresh_stream(self) -> None: + """Refresh any stream to extend expiration time.""" + + def _schedule_stream_refresh(self) -> None: + """Schedules an alarm to refresh any streams before expiration.""" + if self._stream_refresh_unsub is not None: + self._stream_refresh_unsub() + + expiration_time = self._stream_expires_at() + if not expiration_time: + return + refresh_time = expiration_time - STREAM_EXPIRATION_BUFFER + _LOGGER.debug("Scheduled next stream refresh for %s", refresh_time) + + self._stream_refresh_unsub = async_track_point_in_utc_time( + self.hass, + self._handle_stream_refresh, + refresh_time, + ) + + async def _handle_stream_refresh(self, _: datetime.datetime) -> None: + """Alarm that fires to check if the stream should be refreshed.""" + _LOGGER.debug("Examining streams to refresh") + self._stream_refresh_unsub = None + try: + await self._async_refresh_stream() + finally: + self._schedule_stream_refresh() + + async def async_added_to_hass(self) -> None: + """Run when entity is added to register update signal handler.""" + self.async_on_remove( + self._device.add_update_listener(self.async_write_ha_state) + ) + + async def async_will_remove_from_hass(self) -> None: + """Invalidates the RTSP token when unloaded.""" + await super().async_will_remove_from_hass() + if self._stream_refresh_unsub: + self._stream_refresh_unsub() + + +class NestRTSPEntity(NestCameraBaseEntity): + """Nest cameras that use RTSP.""" + + _rtsp_stream: RtspStream | None = None + _rtsp_live_stream_trait: CameraLiveStreamTrait + + def __init__(self, device: Device) -> None: + """Initialize the camera.""" + super().__init__(device) + self._create_stream_url_lock = asyncio.Lock() + self._rtsp_live_stream_trait = device.traits[CameraLiveStreamTrait.NAME] @property - def frontend_stream_type(self) -> StreamType | None: - """Return the type of stream supported by this camera.""" - if CameraLiveStreamTrait.NAME not in self._device.traits: - return None - trait = self._device.traits[CameraLiveStreamTrait.NAME] - if StreamingProtocol.WEB_RTC in trait.supported_protocols: - return StreamType.WEB_RTC - return super().frontend_stream_type + def use_stream_for_stills(self) -> bool: + """Always use the RTSP stream to generate snapshots.""" + return True @property def available(self) -> bool: @@ -125,8 +164,6 @@ def available(self) -> bool: async def stream_source(self) -> str | None: """Return the source of the stream.""" - if not self._rtsp_live_stream_trait: - return None async with self._create_stream_url_lock: if not self._rtsp_stream: _LOGGER.debug("Fetching stream url") @@ -142,50 +179,14 @@ async def stream_source(self) -> str | None: _LOGGER.warning("Stream already expired") return self._rtsp_stream.rtsp_stream_url - def _all_streams(self) -> list[Stream]: - """Return the current list of active streams.""" - streams: list[Stream] = [] - if self._rtsp_stream: - streams.append(self._rtsp_stream) - streams.extend(list(self._webrtc_sessions.values())) - return streams + def _stream_expires_at(self) -> datetime.datetime | None: + """Next time when a stream expires.""" + return self._rtsp_stream.expires_at if self._rtsp_stream else None - def _schedule_stream_refresh(self) -> None: - """Schedules an alarm to refresh any streams before expiration.""" - # Schedule an alarm to extend the stream - if self._stream_refresh_unsub is not None: - self._stream_refresh_unsub() - - _LOGGER.debug("Scheduling next stream refresh") - expiration_times = [stream.expires_at for stream in self._all_streams()] - if not expiration_times: - _LOGGER.debug("No streams to refresh") - return - - refresh_time = min(expiration_times) - STREAM_EXPIRATION_BUFFER - _LOGGER.debug("Scheduled next stream refresh for %s", refresh_time) - - self._stream_refresh_unsub = async_track_point_in_utc_time( - self.hass, - self._handle_stream_refresh, - refresh_time, - ) - - async def _handle_stream_refresh(self, _: datetime.datetime) -> None: - """Alarm that fires to check if the stream should be refreshed.""" - _LOGGER.debug("Examining streams to refresh") - await self._handle_rtsp_stream_refresh() - await self._handle_webrtc_stream_refresh() - self._schedule_stream_refresh() - - async def _handle_rtsp_stream_refresh(self) -> None: - """Alarm that fires to check if the stream should be refreshed.""" + async def _async_refresh_stream(self) -> None: + """Refresh stream to extend expiration time.""" if not self._rtsp_stream: return - now = utcnow() - refresh_time = self._rtsp_stream.expires_at - STREAM_EXPIRATION_BUFFER - if now < refresh_time: - return _LOGGER.debug("Extending RTSP stream") try: self._rtsp_stream = await self._rtsp_stream.extend_rtsp_stream() @@ -201,8 +202,38 @@ async def _handle_rtsp_stream_refresh(self) -> None: if self.stream: self.stream.update_source(self._rtsp_stream.rtsp_stream_url) - async def _handle_webrtc_stream_refresh(self) -> None: - """Alarm that fires to check if the stream should be refreshed.""" + async def async_will_remove_from_hass(self) -> None: + """Invalidates the RTSP token when unloaded.""" + await super().async_will_remove_from_hass() + if self._rtsp_stream: + try: + await self._rtsp_stream.stop_stream() + except ApiException as err: + _LOGGER.debug("Error stopping stream: %s", err) + self._rtsp_stream = None + + +class NestWebRTCEntity(NestCameraBaseEntity): + """Nest cameras that use WebRTC.""" + + def __init__(self, device: Device) -> None: + """Initialize the camera.""" + super().__init__(device) + self._webrtc_sessions: dict[str, WebRtcStream] = {} + + @property + def frontend_stream_type(self) -> StreamType | None: + """Return the type of stream supported by this camera.""" + return StreamType.WEB_RTC + + def _stream_expires_at(self) -> datetime.datetime | None: + """Next time when a stream expires.""" + if not self._webrtc_sessions: + return None + return min(stream.expires_at for stream in self._webrtc_sessions.values()) + + async def _async_refresh_stream(self) -> None: + """Refresh stream to extend expiration time.""" now = utcnow() for webrtc_stream in list(self._webrtc_sessions.values()): if now < (webrtc_stream.expires_at - STREAM_EXPIRATION_BUFFER): @@ -218,32 +249,10 @@ async def _handle_webrtc_stream_refresh(self) -> None: else: self._webrtc_sessions[webrtc_stream.media_session_id] = webrtc_stream - async def async_will_remove_from_hass(self) -> None: - """Invalidates the RTSP token when unloaded.""" - for stream in self._all_streams(): - _LOGGER.debug("Invalidating stream") - try: - await stream.stop_stream() - except ApiException as err: - _LOGGER.debug("Error stopping stream: %s", err) - self._rtsp_stream = None - self._webrtc_sessions.clear() - - if self._stream_refresh_unsub: - self._stream_refresh_unsub() - - async def async_added_to_hass(self) -> None: - """Run when entity is added to register update signal handler.""" - self.async_on_remove( - self._device.add_update_listener(self.async_write_ha_state) - ) - async def async_camera_image( self, width: int | None = None, height: int | None = None ) -> bytes | None: - """Return bytes of camera image.""" - # Use the thumbnail from RTSP stream, or a placeholder if stream is - # not supported (e.g. WebRTC) as a fallback when 'use_stream_for_stills' if False + """Return a placeholder image for WebRTC cameras that don't support snapshots.""" return await self.hass.async_add_executor_job(self.placeholder_image) @classmethod @@ -257,11 +266,6 @@ async def async_handle_async_webrtc_offer( ) -> None: """Return the source of the stream.""" trait: CameraLiveStreamTrait = self._device.traits[CameraLiveStreamTrait.NAME] - if StreamingProtocol.WEB_RTC not in trait.supported_protocols: - await super().async_handle_async_webrtc_offer( - offer_sdp, session_id, send_message - ) - return try: stream = await trait.generate_web_rtc_stream(offer_sdp) except ApiException as err: @@ -294,3 +298,9 @@ async def stop_stream() -> None: def _async_get_webrtc_client_configuration(self) -> WebRTCClientConfiguration: """Return the WebRTC client configuration adjustable per integration.""" return WebRTCClientConfiguration(data_channel="dataSendChannel") + + async def async_will_remove_from_hass(self) -> None: + """Invalidates the RTSP token when unloaded.""" + await super().async_will_remove_from_hass() + for session_id in list(self._webrtc_sessions.keys()): + self.close_webrtc_session(session_id) diff --git a/tests/components/nest/test_camera.py b/tests/components/nest/test_camera.py index 6417fa4ebe9132..500dbc0f46f059 100644 --- a/tests/components/nest/test_camera.py +++ b/tests/components/nest/test_camera.py @@ -28,7 +28,7 @@ from .conftest import FakeAuth from tests.common import async_fire_time_changed -from tests.typing import WebSocketGenerator +from tests.typing import MockHAClientWebSocket, WebSocketGenerator PLATFORM = "camera" CAMERA_DEVICE_TYPE = "sdm.devices.types.CAMERA" @@ -176,6 +176,30 @@ async def async_get_image( return image.content +def get_frontend_stream_type_attribute( + hass: HomeAssistant, entity_id: str +) -> StreamType: + """Get the frontend_stream_type camera attribute.""" + cam = hass.states.get(entity_id) + assert cam is not None + assert cam.state == CameraState.STREAMING + return cam.attributes.get("frontend_stream_type") + + +async def async_frontend_stream_types( + client: MockHAClientWebSocket, entity_id: str +) -> list[str] | None: + """Get the frontend stream types supported.""" + await client.send_json_auto_id( + {"type": "camera/capabilities", "entity_id": entity_id} + ) + msg = await client.receive_json() + assert msg.get("type") == TYPE_RESULT + assert msg.get("success") + assert msg.get("result") + return msg["result"].get("frontend_stream_types") + + async def fire_alarm(hass: HomeAssistant, point_in_time: datetime.datetime) -> None: """Fire an alarm and wait for callbacks to run.""" with freeze_time(point_in_time): @@ -237,16 +261,21 @@ async def test_camera_stream( camera_device: None, auth: FakeAuth, mock_create_stream: Mock, + hass_ws_client: WebSocketGenerator, ) -> None: """Test a basic camera and fetch its live stream.""" auth.responses = [make_stream_url_response()] await setup_platform() assert len(hass.states.async_all()) == 1 - cam = hass.states.get("camera.my_camera") - assert cam is not None - assert cam.state == CameraState.STREAMING - assert cam.attributes["frontend_stream_type"] == StreamType.HLS + assert ( + get_frontend_stream_type_attribute(hass, "camera.my_camera") == StreamType.HLS + ) + client = await hass_ws_client(hass) + frontend_stream_types = await async_frontend_stream_types( + client, "camera.my_camera" + ) + assert frontend_stream_types == [StreamType.HLS] stream_source = await camera.async_get_stream_source(hass, "camera.my_camera") assert stream_source == "rtsp://some/url?auth=g.0.streamingToken" @@ -265,12 +294,16 @@ async def test_camera_ws_stream( await setup_platform() assert len(hass.states.async_all()) == 1 - cam = hass.states.get("camera.my_camera") - assert cam is not None - assert cam.state == CameraState.STREAMING - assert cam.attributes["frontend_stream_type"] == StreamType.HLS + assert ( + get_frontend_stream_type_attribute(hass, "camera.my_camera") == StreamType.HLS + ) client = await hass_ws_client(hass) + frontend_stream_types = await async_frontend_stream_types( + client, "camera.my_camera" + ) + assert frontend_stream_types == [StreamType.HLS] + await client.send_json( { "id": 2, @@ -322,7 +355,7 @@ async def test_camera_ws_stream_failure( async def test_camera_stream_missing_trait( hass: HomeAssistant, setup_platform, create_device ) -> None: - """Test fetching a video stream when not supported by the API.""" + """Test that cameras missing a live stream are not supported.""" create_device.create( { "sdm.devices.traits.Info": { @@ -338,16 +371,7 @@ async def test_camera_stream_missing_trait( ) await setup_platform() - assert len(hass.states.async_all()) == 1 - cam = hass.states.get("camera.my_camera") - assert cam is not None - assert cam.state == CameraState.IDLE - - stream_source = await camera.async_get_stream_source(hass, "camera.my_camera") - assert stream_source is None - - # Fallback to placeholder image - await async_get_image(hass) + assert len(hass.states.async_all()) == 0 async def test_refresh_expired_stream_token( @@ -655,6 +679,15 @@ async def test_camera_web_rtc_unsupported( assert cam.attributes["frontend_stream_type"] == StreamType.HLS client = await hass_ws_client(hass) + await client.send_json_auto_id( + {"type": "camera/capabilities", "entity_id": "camera.my_camera"} + ) + msg = await client.receive_json() + + assert msg["type"] == TYPE_RESULT + assert msg["success"] + assert msg["result"] == {"frontend_stream_types": ["hls"]} + await client.send_json_auto_id( { "type": "camera/webrtc/offer", @@ -732,8 +765,6 @@ async def test_camera_multiple_streams( """Test a camera supporting multiple stream types.""" expiration = utcnow() + datetime.timedelta(seconds=100) auth.responses = [ - # RTSP response - make_stream_url_response(), # WebRTC response aiohttp.web.json_response( { @@ -770,9 +801,9 @@ async def test_camera_multiple_streams( # Prefer WebRTC over RTSP/HLS assert cam.attributes["frontend_stream_type"] == StreamType.WEB_RTC - # RTSP stream + # RTSP stream is not supported stream_source = await camera.async_get_stream_source(hass, "camera.my_camera") - assert stream_source == "rtsp://some/url?auth=g.0.streamingToken" + assert not stream_source # WebRTC stream client = await hass_ws_client(hass) diff --git a/tests/components/nest/test_media_source.py b/tests/components/nest/test_media_source.py index 101bfae089d3e7..2526bfdf975f6e 100644 --- a/tests/components/nest/test_media_source.py +++ b/tests/components/nest/test_media_source.py @@ -48,6 +48,9 @@ "customName": DEVICE_NAME, }, "sdm.devices.traits.CameraImage": {}, + "sdm.devices.traits.CameraLiveStream": { + "supportedProtocols": ["RTSP"], + }, "sdm.devices.traits.CameraEventImage": {}, "sdm.devices.traits.CameraPerson": {}, "sdm.devices.traits.CameraMotion": {}, @@ -57,7 +60,9 @@ "customName": DEVICE_NAME, }, "sdm.devices.traits.CameraClipPreview": {}, - "sdm.devices.traits.CameraLiveStream": {}, + "sdm.devices.traits.CameraLiveStream": { + "supportedProtocols": ["WEB_RTC"], + }, "sdm.devices.traits.CameraPerson": {}, "sdm.devices.traits.CameraMotion": {}, }