Skip to content

Commit

Permalink
Enable strict typing of assist_pipeline (home-assistant#91529)
Browse files Browse the repository at this point in the history
  • Loading branch information
emontnemery authored Apr 17, 2023
1 parent 9985516 commit 3367e86
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 13 deletions.
1 change: 1 addition & 0 deletions .strict-typing
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ homeassistant.components.anthemav.*
homeassistant.components.apcupsd.*
homeassistant.components.aqualogic.*
homeassistant.components.aseko_pool_live.*
homeassistant.components.assist_pipeline.*
homeassistant.components.asuswrt.*
homeassistant.components.auth.*
homeassistant.components.automation.*
Expand Down
19 changes: 11 additions & 8 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class PipelineRun:
tts_engine: str | None = None
tts_options: dict | None = None

def __post_init__(self):
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language

Expand All @@ -189,7 +189,7 @@ def __post_init__(self):
):
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)

def start(self):
def start(self) -> None:
"""Emit run start event."""
data = {
"pipeline": self.pipeline.name,
Expand All @@ -200,7 +200,7 @@ def start(self):

self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))

def end(self):
def end(self) -> None:
"""Emit run end event."""
self.event_callback(
PipelineEvent(
Expand Down Expand Up @@ -349,7 +349,9 @@ async def recognize_intent(
)
)

speech = conversation_result.response.speech.get("plain", {}).get("speech", "")
speech: str = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)

return speech

Expand Down Expand Up @@ -453,7 +455,7 @@ class PipelineInput:

conversation_id: str | None = None

async def execute(self):
async def execute(self) -> None:
"""Run pipeline."""
self.run.start()
current_stage = self.run.start_stage
Expand Down Expand Up @@ -496,7 +498,7 @@ async def execute(self):

self.run.end()

async def validate(self):
async def validate(self) -> None:
"""Validate pipeline input against start stage."""
if self.run.start_stage == PipelineStage.STT:
if self.stt_metadata is None:
Expand Down Expand Up @@ -524,7 +526,8 @@ async def validate(self):
prepare_tasks = []

if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))
# self.stt_metadata can't be None or we'd raise above
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata)) # type: ignore[arg-type]

if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
prepare_tasks.append(self.run.prepare_recognize_intent())
Expand Down Expand Up @@ -696,7 +699,7 @@ async def ws_set_preferred_item(
connection.send_result(msg["id"])


async def async_setup_pipeline_store(hass):
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
"""Set up the pipeline storage collection."""
pipeline_store = PipelineStorageCollection(
Store(hass, STORAGE_VERSION, STORAGE_KEY)
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/assist_pipeline/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ class VoiceCommandSegmenter:
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
_seconds_per_chunk: float = 0.03 # 30 ms

def __post_init__(self):
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_frames * 2
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
self.reset()

def reset(self):
def reset(self) -> None:
"""Reset all counters and state."""
self._audio_buffer = b""
self._speech_seconds_left = self.speech_seconds
Expand Down
10 changes: 7 additions & 3 deletions homeassistant/components/assist_pipeline/websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Assist pipeline Websocket API."""
import asyncio
import audioop # pylint: disable=deprecated-module
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
import logging
from typing import Any

Expand Down Expand Up @@ -114,7 +114,7 @@ async def websocket_run(
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
incoming_sample_rate = msg["input"]["sample_rate"]

async def stt_stream():
async def stt_stream() -> AsyncGenerator[bytes, None]:
state = None
segmenter = VoiceCommandSegmenter()

Expand All @@ -129,7 +129,11 @@ async def stt_stream():

yield chunk

def handle_binary(_hass, _connection, data: bytes):
def handle_binary(
_hass: HomeAssistant,
_connection: websocket_api.ActiveConnection,
data: bytes,
) -> None:
# Forward to STT audio stream
audio_queue.put_nowait(data)

Expand Down
10 changes: 10 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true

[mypy-homeassistant.components.assist_pipeline.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true

[mypy-homeassistant.components.asuswrt.*]
check_untyped_defs = true
disallow_incomplete_defs = true
Expand Down

0 comments on commit 3367e86

Please sign in to comment.