Skip to content

Commit

Permalink
Use STT/TTS languages for LLM fallback (home-assistant#135533)
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam authored Jan 13, 2025
1 parent 3e9b410 commit b897e6a
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 6 deletions.
15 changes: 12 additions & 3 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,9 +1021,18 @@ async def recognize_intent(
raise RuntimeError("Recognize intent was not prepared")

if self.pipeline.conversation_language == MATCH_ALL:
# LLMs support all languages ('*') so use pipeline language for
# intent fallback.
input_language = self.pipeline.language
# LLMs support all languages ('*') so use languages from the
# pipeline for intent fallback.
#
# We prioritize the STT and TTS languages because they may be more
# specific, such as "zh-CN" instead of just "zh". This is necessary
# for languages whose intents are split out by region when
# preferring local intent matching.
input_language = (
self.pipeline.stt_language
or self.pipeline.tts_language
or self.pipeline.language
)
else:
input_language = self.pipeline.conversation_language

Expand Down
102 changes: 102 additions & 0 deletions tests/components/assist_pipeline/snapshots/test_init.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,108 @@
}),
])
# ---
# name: test_stt_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en-US',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_tts_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': None,
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en-us',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_wake_word_detection_aborted
list([
dict({
Expand Down
154 changes: 151 additions & 3 deletions tests/components/assist_pipeline/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,13 +1102,13 @@ async def async_handle(
)


async def test_pipeline_language_used_instead_of_conversation_language(
async def test_stt_language_used_instead_of_conversation_language(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the pipeline language is used when the conversation language is '*' (all languages)."""
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
client = await hass_ws_client(hass)

events: list[assist_pipeline.PipelineEvent] = []
Expand Down Expand Up @@ -1165,7 +1165,155 @@ async def test_pipeline_language_used_instead_of_conversation_language(

assert intent_start is not None

# Pipeline language (en) should be used instead of '*'
# STT language (en-US) should be used instead of '*'
assert intent_start.data.get("language") == pipeline.stt_language

# Check input to async_converse
mock_async_converse.assert_called_once()
assert (
mock_async_converse.call_args_list[0].kwargs.get("language")
== pipeline.stt_language
)


async def test_tts_language_used_instead_of_conversation_language(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
client = await hass_ws_client(hass)

events: list[assist_pipeline.PipelineEvent] = []

await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": MATCH_ALL,
"language": "en",
"name": "test_name",
"stt_engine": None,
"stt_language": None,
"tts_engine": None,
"tts_language": "en-us",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
await pipeline_input.validate()

with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
return_value=conversation.ConversationResult(
intent.IntentResponse(pipeline.language)
),
) as mock_async_converse:
await pipeline_input.execute()

# Check intent start event
assert process_events(events) == snapshot
intent_start: assist_pipeline.PipelineEvent | None = None
for event in events:
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
intent_start = event
break

assert intent_start is not None

# STT language (en-US) should be used instead of '*'
assert intent_start.data.get("language") == pipeline.tts_language

# Check input to async_converse
mock_async_converse.assert_called_once()
assert (
mock_async_converse.call_args_list[0].kwargs.get("language")
== pipeline.tts_language
)


async def test_pipeline_language_used_instead_of_conversation_language(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
client = await hass_ws_client(hass)

events: list[assist_pipeline.PipelineEvent] = []

await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": MATCH_ALL,
"language": "en",
"name": "test_name",
"stt_engine": None,
"stt_language": None,
"tts_engine": None,
"tts_language": None,
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input="test input",
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.INTENT,
end_stage=assist_pipeline.PipelineStage.INTENT,
event_callback=events.append,
),
)
await pipeline_input.validate()

with patch(
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
return_value=conversation.ConversationResult(
intent.IntentResponse(pipeline.language)
),
) as mock_async_converse:
await pipeline_input.execute()

# Check intent start event
assert process_events(events) == snapshot
intent_start: assist_pipeline.PipelineEvent | None = None
for event in events:
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
intent_start = event
break

assert intent_start is not None

# STT language (en-US) should be used instead of '*'
assert intent_start.data.get("language") == pipeline.language

# Check input to async_converse
Expand Down

0 comments on commit b897e6a

Please sign in to comment.