forked from home-assistant/core
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Google Cloud Speech-to-Text (STT) (home-assistant#120854)
* Google Cloud * . * fix * mypy * add tests * Update .coveragerc * Update const.py * upload file, reconfigure and import flow * fixes * default to latest_short * mypy * update * Allow clearing options in the UI * update * update * update
- Loading branch information
1 parent
0817474
commit 8b28334
Showing
9 changed files
with
345 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
"""Support for the Google Cloud STT service.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import AsyncGenerator, AsyncIterable | ||
import logging | ||
|
||
from google.api_core.exceptions import GoogleAPIError, Unauthenticated | ||
from google.cloud import speech_v1 | ||
|
||
from homeassistant.components.stt import ( | ||
AudioBitRates, | ||
AudioChannels, | ||
AudioCodecs, | ||
AudioFormats, | ||
AudioSampleRates, | ||
SpeechMetadata, | ||
SpeechResult, | ||
SpeechResultState, | ||
SpeechToTextEntity, | ||
) | ||
from homeassistant.config_entries import ConfigEntry | ||
from homeassistant.core import HomeAssistant | ||
from homeassistant.helpers import device_registry as dr | ||
from homeassistant.helpers.entity_platform import AddEntitiesCallback | ||
|
||
from .const import ( | ||
CONF_SERVICE_ACCOUNT_INFO, | ||
CONF_STT_MODEL, | ||
DEFAULT_STT_MODEL, | ||
DOMAIN, | ||
STT_LANGUAGES, | ||
) | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register the Google Cloud STT entity for a config entry.

    A ``SpeechAsyncClient`` is built from the service account info stored in
    the entry data, wrapped in a single ``GoogleCloudSpeechToTextEntity`` and
    handed to ``async_add_entities``.
    """
    stt_client = speech_v1.SpeechAsyncClient.from_service_account_info(
        config_entry.data[CONF_SERVICE_ACCOUNT_INFO]
    )
    entity = GoogleCloudSpeechToTextEntity(config_entry, stt_client)
    async_add_entities([entity])
|
||
|
||
class GoogleCloudSpeechToTextEntity(SpeechToTextEntity):
    """Speech-to-text entity that streams audio to Google Cloud Speech."""

    def __init__(
        self,
        entry: ConfigEntry,
        client: speech_v1.SpeechAsyncClient,
    ) -> None:
        """Keep the entry and client, and derive the entity attributes.

        The recognition model comes from the entry options and falls back to
        ``DEFAULT_STT_MODEL`` when the option is absent.
        """
        self._entry = entry
        self._client = client
        self._model = entry.options.get(CONF_STT_MODEL, DEFAULT_STT_MODEL)
        self._attr_name = entry.title
        self._attr_unique_id = f"{entry.entry_id}-stt"
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            manufacturer="Google",
            model="Cloud",
            entry_type=dr.DeviceEntryType.SERVICE,
        )

    @property
    def supported_languages(self) -> list[str]:
        """Languages accepted by this entity (``STT_LANGUAGES``)."""
        return STT_LANGUAGES

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Container formats accepted by this entity: WAV and OGG."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Codecs accepted by this entity: PCM and Opus."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Bit rates accepted by this entity: 16 only."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Sample rates accepted by this entity: 16000 only."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Channel layouts accepted by this entity: mono only."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Send an audio stream to Google Cloud and return the transcript.

        The transcript is the concatenation of the first alternative of the
        first result of every response that carries one. On a
        ``GoogleAPIError`` the error is logged and an ``ERROR`` result with no
        text is returned; an ``Unauthenticated`` error additionally starts the
        reauth flow for the config entry.
        """
        # Opus audio is sent as OGG_OPUS; every other codec as LINEAR16.
        if metadata.codec == AudioCodecs.OPUS:
            encoding = speech_v1.RecognitionConfig.AudioEncoding.OGG_OPUS
        else:
            encoding = speech_v1.RecognitionConfig.AudioEncoding.LINEAR16

        recognition_config = speech_v1.RecognitionConfig(
            encoding=encoding,
            sample_rate_hertz=metadata.sample_rate,
            language_code=metadata.language,
            model=self._model,
        )
        streaming_config = speech_v1.StreamingRecognitionConfig(
            config=recognition_config
        )

        async def request_generator() -> (
            AsyncGenerator[speech_v1.StreamingRecognizeRequest]
        ):
            # The first request must only contain a streaming_config
            yield speech_v1.StreamingRecognizeRequest(streaming_config=streaming_config)
            # All subsequent requests must only contain audio_content
            async for chunk in stream:
                yield speech_v1.StreamingRecognizeRequest(audio_content=chunk)

        pieces: list[str] = []
        try:
            response_stream = await self._client.streaming_recognize(
                requests=request_generator(),
                timeout=10,
            )

            async for response in response_stream:
                _LOGGER.debug("response: %s", response)
                # Skip responses without a result or without any alternative.
                if not response.results or not response.results[0].alternatives:
                    continue
                pieces.append(response.results[0].alternatives[0].transcript)
        except GoogleAPIError as err:
            _LOGGER.error("Error occurred during Google Cloud STT call: %s", err)
            if isinstance(err, Unauthenticated):
                self._entry.async_start_reauth(self.hass)
            return SpeechResult(None, SpeechResultState.ERROR)

        return SpeechResult("".join(pieces), SpeechResultState.SUCCESS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.