Skip to content

Commit

Permalink
Drop language parameter from async_get_pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
emontnemery committed Apr 18, 2023
1 parent f3897d8 commit 596b379
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 58 deletions.
14 changes: 1 addition & 13 deletions homeassistant/components/assist_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,16 @@ async def async_pipeline_from_audio_stream(
event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes],
language: str | None = None,
pipeline_id: str | None = None,
conversation_id: str | None = None,
context: Context | None = None,
tts_options: dict | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
if language is None and pipeline_id is None:
language = hass.config.language

# Temporary workaround for language codes
if language == "en":
language = "en-US"

if context is None:
context = Context()

pipeline = await async_get_pipeline(
hass,
pipeline_id=pipeline_id,
language=language,
)
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
Expand Down
9 changes: 4 additions & 5 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,19 @@


async def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None:
"""Get a pipeline by id or create one for a language."""
pipeline_data: PipelineData = hass.data[DOMAIN]

if pipeline_id is not None:
return pipeline_data.pipeline_store.data.get(pipeline_id)

# Construct a pipeline for the required/configured language
language = language or hass.config.language
# Construct a pipeline for the configured language
return await pipeline_data.pipeline_store.async_create_item(
{
"name": language,
"language": language,
"name": hass.config.language,
"language": hass.config.language,
"stt_engine": None, # first engine
"conversation_engine": None, # first agent
"tts_engine": None, # first engine
Expand Down
17 changes: 3 additions & 14 deletions homeassistant/components/assist_pipeline/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
# pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val),
vol.Optional("input"): dict,
vol.Optional("language"): str,
vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int),
Expand Down Expand Up @@ -82,23 +81,13 @@ async def websocket_run(
msg: dict[str, Any],
) -> None:
"""Run a pipeline."""
language = msg.get("language", hass.config.language)

# Temporary workaround for language codes
if language == "en":
language = "en-US"

pipeline_id = msg.get("pipeline")
pipeline = await async_get_pipeline(
hass,
pipeline_id=pipeline_id,
language=language,
)
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
connection.send_error(
msg["id"],
"pipeline-not-found",
f"Pipeline not found: id={pipeline_id}, language={language}",
f"Pipeline not found: id={pipeline_id}",
)
return

Expand Down Expand Up @@ -147,7 +136,7 @@ def handle_binary(

# Audio input must be raw PCM at 16Khz with 16-bit mono samples
input_args["stt_metadata"] = stt.SpeechMetadata(
language=language,
language=pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
Expand Down
8 changes: 4 additions & 4 deletions tests/components/assist_pipeline/snapshots/test_init.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
list([
dict({
'data': dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
Expand Down Expand Up @@ -47,7 +47,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
Expand All @@ -70,7 +70,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
Expand Down
44 changes: 22 additions & 22 deletions tests/components/assist_pipeline/snapshots/test_websocket.ambr
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# serializer version: 1
# name: test_audio_pipeline
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
Expand Down Expand Up @@ -45,7 +45,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
Expand All @@ -66,16 +66,16 @@
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_debug
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
Expand Down Expand Up @@ -118,7 +118,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
Expand All @@ -139,16 +139,16 @@
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
})
# ---
# name: test_intent_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
Expand All @@ -163,8 +163,8 @@
# ---
# name: test_intent_timeout
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 0.1,
Expand All @@ -185,8 +185,8 @@
# ---
# name: test_stt_provider_missing
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
Expand All @@ -201,15 +201,15 @@
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'language': 'en',
'sample_rate': 16000,
}),
})
# ---
# name: test_stt_stream_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
Expand All @@ -231,8 +231,8 @@
# ---
# name: test_text_only_pipeline
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
Expand All @@ -255,7 +255,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
Expand All @@ -275,8 +275,8 @@
# ---
# name: test_tts_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
Expand Down

0 comments on commit 596b379

Please sign in to comment.