Skip to content

Commit

Permalink
Fix nest streams broken due to CameraCapabilities change (home-assist…
Browse files Browse the repository at this point in the history
…ant#129711)

* Fix nest streams broken due to CameraCapabilities change

* Fix stream cleanup

* Apply suggestions from code review

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Update homeassistant/components/nest/camera.py

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
  • Loading branch information
allenporter and balloob authored Nov 4, 2024
1 parent 49f0bb6 commit 6718cce
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 135 deletions.
230 changes: 120 additions & 110 deletions homeassistant/components/nest/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@

from __future__ import annotations

from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
import datetime
import functools
import logging
from pathlib import Path
from typing import cast

from google_nest_sdm.camera_traits import (
CameraImageTrait,
CameraLiveStreamTrait,
RtspStream,
Stream,
StreamingProtocol,
WebRtcStream,
)
Expand Down Expand Up @@ -57,19 +55,25 @@ async def async_setup_entry(
device_manager: DeviceManager = hass.data[DOMAIN][entry.entry_id][
DATA_DEVICE_MANAGER
]
async_add_entities(
NestCamera(device)
for device in device_manager.devices.values()
if CameraImageTrait.NAME in device.traits
or CameraLiveStreamTrait.NAME in device.traits
)
entities: list[NestCameraBaseEntity] = []
for device in device_manager.devices.values():
if (live_stream := device.traits.get(CameraLiveStreamTrait.NAME)) is None:
continue
if StreamingProtocol.WEB_RTC in live_stream.supported_protocols:
entities.append(NestWebRTCEntity(device))
elif StreamingProtocol.RTSP in live_stream.supported_protocols:
entities.append(NestRTSPEntity(device))

async_add_entities(entities)

class NestCamera(Camera):

class NestCameraBaseEntity(Camera, ABC):
"""Devices that support cameras."""

_attr_has_entity_name = True
_attr_name = None
_attr_is_streaming = True
_attr_supported_features = CameraEntityFeature.STREAM

def __init__(self, device: Device) -> None:
"""Initialize the camera."""
Expand All @@ -79,39 +83,74 @@ def __init__(self, device: Device) -> None:
self._attr_device_info = nest_device_info.device_info
self._attr_brand = nest_device_info.device_brand
self._attr_model = nest_device_info.device_model
self._rtsp_stream: RtspStream | None = None
self._webrtc_sessions: dict[str, WebRtcStream] = {}
self._create_stream_url_lock = asyncio.Lock()
self._stream_refresh_unsub: Callable[[], None] | None = None
self._attr_is_streaming = False
self._attr_supported_features = CameraEntityFeature(0)
self._rtsp_live_stream_trait: CameraLiveStreamTrait | None = None
if CameraLiveStreamTrait.NAME in self._device.traits:
self._attr_is_streaming = True
self._attr_supported_features |= CameraEntityFeature.STREAM
trait = cast(
CameraLiveStreamTrait, self._device.traits[CameraLiveStreamTrait.NAME]
)
if StreamingProtocol.RTSP in trait.supported_protocols:
self._rtsp_live_stream_trait = trait
self.stream_options[CONF_EXTRA_PART_WAIT_TIME] = 3
# The API "name" field is a unique device identifier.
self._attr_unique_id = f"{self._device.name}-camera"
self._stream_refresh_unsub: Callable[[], None] | None = None

@property
def use_stream_for_stills(self) -> bool:
"""Whether or not to use stream to generate stills."""
return self._rtsp_live_stream_trait is not None
@abstractmethod
def _stream_expires_at(self) -> datetime.datetime | None:
"""Next time when a stream expires."""

@abstractmethod
async def _async_refresh_stream(self) -> None:
"""Refresh any stream to extend expiration time."""

def _schedule_stream_refresh(self) -> None:
"""Schedules an alarm to refresh any streams before expiration."""
if self._stream_refresh_unsub is not None:
self._stream_refresh_unsub()

expiration_time = self._stream_expires_at()
if not expiration_time:
return
refresh_time = expiration_time - STREAM_EXPIRATION_BUFFER
_LOGGER.debug("Scheduled next stream refresh for %s", refresh_time)

self._stream_refresh_unsub = async_track_point_in_utc_time(
self.hass,
self._handle_stream_refresh,
refresh_time,
)

async def _handle_stream_refresh(self, _: datetime.datetime) -> None:
"""Alarm that fires to check if the stream should be refreshed."""
_LOGGER.debug("Examining streams to refresh")
self._stream_refresh_unsub = None
try:
await self._async_refresh_stream()
finally:
self._schedule_stream_refresh()

async def async_added_to_hass(self) -> None:
"""Run when entity is added to register update signal handler."""
self.async_on_remove(
self._device.add_update_listener(self.async_write_ha_state)
)

async def async_will_remove_from_hass(self) -> None:
"""Invalidates the RTSP token when unloaded."""
await super().async_will_remove_from_hass()
if self._stream_refresh_unsub:
self._stream_refresh_unsub()


class NestRTSPEntity(NestCameraBaseEntity):
"""Nest cameras that use RTSP."""

_rtsp_stream: RtspStream | None = None
_rtsp_live_stream_trait: CameraLiveStreamTrait

def __init__(self, device: Device) -> None:
"""Initialize the camera."""
super().__init__(device)
self._create_stream_url_lock = asyncio.Lock()
self._rtsp_live_stream_trait = device.traits[CameraLiveStreamTrait.NAME]

@property
def frontend_stream_type(self) -> StreamType | None:
"""Return the type of stream supported by this camera."""
if CameraLiveStreamTrait.NAME not in self._device.traits:
return None
trait = self._device.traits[CameraLiveStreamTrait.NAME]
if StreamingProtocol.WEB_RTC in trait.supported_protocols:
return StreamType.WEB_RTC
return super().frontend_stream_type
def use_stream_for_stills(self) -> bool:
"""Always use the RTSP stream to generate snapshots."""
return True

@property
def available(self) -> bool:
Expand All @@ -125,8 +164,6 @@ def available(self) -> bool:

async def stream_source(self) -> str | None:
"""Return the source of the stream."""
if not self._rtsp_live_stream_trait:
return None
async with self._create_stream_url_lock:
if not self._rtsp_stream:
_LOGGER.debug("Fetching stream url")
Expand All @@ -142,50 +179,14 @@ async def stream_source(self) -> str | None:
_LOGGER.warning("Stream already expired")
return self._rtsp_stream.rtsp_stream_url

def _all_streams(self) -> list[Stream]:
"""Return the current list of active streams."""
streams: list[Stream] = []
if self._rtsp_stream:
streams.append(self._rtsp_stream)
streams.extend(list(self._webrtc_sessions.values()))
return streams
def _stream_expires_at(self) -> datetime.datetime | None:
"""Next time when a stream expires."""
return self._rtsp_stream.expires_at if self._rtsp_stream else None

def _schedule_stream_refresh(self) -> None:
"""Schedules an alarm to refresh any streams before expiration."""
# Schedule an alarm to extend the stream
if self._stream_refresh_unsub is not None:
self._stream_refresh_unsub()

_LOGGER.debug("Scheduling next stream refresh")
expiration_times = [stream.expires_at for stream in self._all_streams()]
if not expiration_times:
_LOGGER.debug("No streams to refresh")
return

refresh_time = min(expiration_times) - STREAM_EXPIRATION_BUFFER
_LOGGER.debug("Scheduled next stream refresh for %s", refresh_time)

self._stream_refresh_unsub = async_track_point_in_utc_time(
self.hass,
self._handle_stream_refresh,
refresh_time,
)

async def _handle_stream_refresh(self, _: datetime.datetime) -> None:
"""Alarm that fires to check if the stream should be refreshed."""
_LOGGER.debug("Examining streams to refresh")
await self._handle_rtsp_stream_refresh()
await self._handle_webrtc_stream_refresh()
self._schedule_stream_refresh()

async def _handle_rtsp_stream_refresh(self) -> None:
"""Alarm that fires to check if the stream should be refreshed."""
async def _async_refresh_stream(self) -> None:
"""Refresh stream to extend expiration time."""
if not self._rtsp_stream:
return
now = utcnow()
refresh_time = self._rtsp_stream.expires_at - STREAM_EXPIRATION_BUFFER
if now < refresh_time:
return
_LOGGER.debug("Extending RTSP stream")
try:
self._rtsp_stream = await self._rtsp_stream.extend_rtsp_stream()
Expand All @@ -201,8 +202,38 @@ async def _handle_rtsp_stream_refresh(self) -> None:
if self.stream:
self.stream.update_source(self._rtsp_stream.rtsp_stream_url)

async def _handle_webrtc_stream_refresh(self) -> None:
"""Alarm that fires to check if the stream should be refreshed."""
async def async_will_remove_from_hass(self) -> None:
"""Invalidates the RTSP token when unloaded."""
await super().async_will_remove_from_hass()
if self._rtsp_stream:
try:
await self._rtsp_stream.stop_stream()
except ApiException as err:
_LOGGER.debug("Error stopping stream: %s", err)
self._rtsp_stream = None


class NestWebRTCEntity(NestCameraBaseEntity):
"""Nest cameras that use WebRTC."""

def __init__(self, device: Device) -> None:
"""Initialize the camera."""
super().__init__(device)
self._webrtc_sessions: dict[str, WebRtcStream] = {}

@property
def frontend_stream_type(self) -> StreamType | None:
"""Return the type of stream supported by this camera."""
return StreamType.WEB_RTC

def _stream_expires_at(self) -> datetime.datetime | None:
"""Next time when a stream expires."""
if not self._webrtc_sessions:
return None
return min(stream.expires_at for stream in self._webrtc_sessions.values())

async def _async_refresh_stream(self) -> None:
"""Refresh stream to extend expiration time."""
now = utcnow()
for webrtc_stream in list(self._webrtc_sessions.values()):
if now < (webrtc_stream.expires_at - STREAM_EXPIRATION_BUFFER):
Expand All @@ -218,32 +249,10 @@ async def _handle_webrtc_stream_refresh(self) -> None:
else:
self._webrtc_sessions[webrtc_stream.media_session_id] = webrtc_stream

async def async_will_remove_from_hass(self) -> None:
"""Invalidates the RTSP token when unloaded."""
for stream in self._all_streams():
_LOGGER.debug("Invalidating stream")
try:
await stream.stop_stream()
except ApiException as err:
_LOGGER.debug("Error stopping stream: %s", err)
self._rtsp_stream = None
self._webrtc_sessions.clear()

if self._stream_refresh_unsub:
self._stream_refresh_unsub()

async def async_added_to_hass(self) -> None:
"""Run when entity is added to register update signal handler."""
self.async_on_remove(
self._device.add_update_listener(self.async_write_ha_state)
)

async def async_camera_image(
self, width: int | None = None, height: int | None = None
) -> bytes | None:
"""Return bytes of camera image."""
# Use the thumbnail from RTSP stream, or a placeholder if stream is
# not supported (e.g. WebRTC) as a fallback when 'use_stream_for_stills' if False
"""Return a placeholder image for WebRTC cameras that don't support snapshots."""
return await self.hass.async_add_executor_job(self.placeholder_image)

@classmethod
Expand All @@ -257,11 +266,6 @@ async def async_handle_async_webrtc_offer(
) -> None:
"""Return the source of the stream."""
trait: CameraLiveStreamTrait = self._device.traits[CameraLiveStreamTrait.NAME]
if StreamingProtocol.WEB_RTC not in trait.supported_protocols:
await super().async_handle_async_webrtc_offer(
offer_sdp, session_id, send_message
)
return
try:
stream = await trait.generate_web_rtc_stream(offer_sdp)
except ApiException as err:
Expand Down Expand Up @@ -294,3 +298,9 @@ async def stop_stream() -> None:
def _async_get_webrtc_client_configuration(self) -> WebRTCClientConfiguration:
"""Return the WebRTC client configuration adjustable per integration."""
return WebRTCClientConfiguration(data_channel="dataSendChannel")

async def async_will_remove_from_hass(self) -> None:
"""Invalidates the RTSP token when unloaded."""
await super().async_will_remove_from_hass()
for session_id in list(self._webrtc_sessions.keys()):
self.close_webrtc_session(session_id)
Loading

0 comments on commit 6718cce

Please sign in to comment.