Skip to content

Commit

Permalink
Fix Speech LRO handling and add HTTP side for multiple results. (#2965)
Browse files Browse the repository at this point in the history
* Update after #2962 to fill out http side and handle new LRO.

* Update _OperationsFuture usage.

* Mock OperationsClient.
  • Loading branch information
daspecster authored Feb 6, 2017
1 parent 7e16325 commit de65409
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 40 deletions.
28 changes: 16 additions & 12 deletions docs/speech-usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ See: `Speech Asynchronous Recognize`_
>>> operation.complete
True
>>> for result in operation.results:
... print('=' * 20)
... print(result.transcript)
... print(result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print(alternative.transcript)
... print(alternative.confidence)
====================
'how old is the Brooklyn Bridge'
0.98267895
Expand All @@ -93,9 +94,10 @@ Great Britain.
... source_uri='gs://my-bucket/recording.flac', language_code='en-GB',
... max_alternatives=2)
>>> for result in results:
... print('=' * 20)
... print('transcript: ' + result.transcript)
... print('confidence: ' + result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print('transcript: ' + alternative.transcript)
... print('confidence: ' + str(alternative.confidence))
====================
transcript: Hello, this is a test
confidence: 0.81
Expand All @@ -115,9 +117,10 @@ Example of using the profanity filter.
>>> results = sample.sync_recognize(max_alternatives=1,
... profanity_filter=True)
>>> for result in results:
... print('=' * 20)
... print('transcript: ' + result.transcript)
... print('confidence: ' + result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print('transcript: ' + alternative.transcript)
... print('confidence: ' + str(alternative.confidence))
====================
transcript: Hello, this is a f****** test
confidence: 0.81
Expand All @@ -137,9 +140,10 @@ words to the vocabulary of the recognizer.
>>> results = sample.sync_recognize(max_alternatives=2,
... speech_context=hints)
>>> for result in results:
... print('=' * 20)
... print('transcript: ' + result.transcript)
... print('confidence: ' + result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print('transcript: ' + alternative.transcript)
... print('confidence: ' + str(alternative.confidence))
====================
transcript: Hello, this is a test
confidence: 0.81
Expand Down
4 changes: 2 additions & 2 deletions speech/google/cloud/speech/_gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def async_recognize(self, sample, language_code=None,
audio = RecognitionAudio(content=sample.content,
uri=sample.source_uri)
api = self._gapic_api
response = api.async_recognize(config=config, audio=audio)
operation_future = api.async_recognize(config=config, audio=audio)

return Operation.from_pb(response, self)
return Operation.from_pb(operation_future.last_operation_data(), self)

def streaming_recognize(self, sample, language_code=None,
max_alternatives=None, profanity_filter=None,
Expand Down
11 changes: 5 additions & 6 deletions speech/google/cloud/speech/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from google.cloud.environment_vars import DISABLE_GRPC

from google.cloud.speech._gax import GAPICSpeechAPI
from google.cloud.speech.alternative import Alternative
from google.cloud.speech.result import Result
from google.cloud.speech.connection import Connection
from google.cloud.speech.operation import Operation
from google.cloud.speech.sample import Sample
Expand Down Expand Up @@ -235,12 +235,11 @@ def sync_recognize(self, sample, language_code=None, max_alternatives=None,
api_response = self._connection.api_request(
method='POST', path='speech:syncrecognize', data=data)

if len(api_response['results']) == 1:
result = api_response['results'][0]
return [Alternative.from_api_repr(alternative)
for alternative in result['alternatives']]
if len(api_response['results']) > 0:
results = api_response['results']
return [Result.from_api_repr(result) for result in results]
else:
raise ValueError('More than one result or none returned from API.')
raise ValueError('No results were returned from the API')


def _build_request_data(sample, language_code=None, max_alternatives=None,
Expand Down
27 changes: 21 additions & 6 deletions speech/google/cloud/speech/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,34 @@ def __init__(self, alternatives):

@classmethod
def from_pb(cls, result):
"""Factory: construct instance of ``SpeechRecognitionResult``.
"""Factory: construct instance of ``Result``.
:type result: :class:`~google.cloud.grpc.speech.v1beta1\
.cloud_speech_pb2.StreamingRecognizeResult`
:param result: Instance of ``StreamingRecognizeResult`` protobuf.
.cloud_speech_pb2.SpeechRecognitionResult`
:param result: Instance of ``SpeechRecognitionResult`` protobuf.
:rtype: :class:`~google.cloud.speech.result.SpeechRecognitionResult`
:returns: Instance of ``SpeechRecognitionResult``.
:rtype: :class:`~google.cloud.speech.result.Result`
:returns: Instance of ``Result``.
"""
alternatives = [Alternative.from_pb(result) for result
alternatives = [Alternative.from_pb(alternative) for alternative
in result.alternatives]
return cls(alternatives=alternatives)

@classmethod
def from_api_repr(cls, result):
    """Factory: build a ``Result`` from its JSON (HTTP) representation.

    Each entry under the ``'alternatives'`` key is converted with
    ``Alternative.from_api_repr`` and the resulting list is handed to the
    class constructor.

    :type result: dict
    :param result: Dictionary of a :class:`~google.cloud.grpc.speech.\
                   v1beta1.cloud_speech_pb2.SpeechRecognitionResult`

    :rtype: :class:`~google.cloud.speech.result.Result`
    :returns: Instance of ``Result``.
    """
    parsed = []
    for entry in result['alternatives']:
        parsed.append(Alternative.from_api_repr(entry))
    return cls(alternatives=parsed)

@property
def confidence(self):
"""Return the confidence for the most probable alternative.
Expand Down
30 changes: 22 additions & 8 deletions speech/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def test_sync_recognize_content_with_optional_params_no_gax(self):

from google.cloud import speech
from google.cloud.speech.alternative import Alternative
from google.cloud.speech.result import Result
from unit_tests._fixtures import SYNC_RECOGNIZE_RESPONSE

_B64_AUDIO_CONTENT = _bytes_to_unicode(b64encode(self.AUDIO_CONTENT))
Expand Down Expand Up @@ -174,13 +175,16 @@ def test_sync_recognize_content_with_optional_params_no_gax(self):
alternative = SYNC_RECOGNIZE_RESPONSE['results'][0]['alternatives'][0]
expected = Alternative.from_api_repr(alternative)
self.assertEqual(len(response), 1)
self.assertIsInstance(response[0], Alternative)
self.assertEqual(response[0].transcript, expected.transcript)
self.assertEqual(response[0].confidence, expected.confidence)
self.assertIsInstance(response[0], Result)
self.assertEqual(len(response[0].alternatives), 1)
alternative = response[0].alternatives[0]
self.assertEqual(alternative.transcript, expected.transcript)
self.assertEqual(alternative.confidence, expected.confidence)

def test_sync_recognize_source_uri_without_optional_params_no_gax(self):
from google.cloud import speech
from google.cloud.speech.alternative import Alternative
from google.cloud.speech.result import Result
from unit_tests._fixtures import SYNC_RECOGNIZE_RESPONSE

RETURNED = SYNC_RECOGNIZE_RESPONSE
Expand Down Expand Up @@ -214,9 +218,12 @@ def test_sync_recognize_source_uri_without_optional_params_no_gax(self):
expected = Alternative.from_api_repr(
SYNC_RECOGNIZE_RESPONSE['results'][0]['alternatives'][0])
self.assertEqual(len(response), 1)
self.assertIsInstance(response[0], Alternative)
self.assertEqual(response[0].transcript, expected.transcript)
self.assertEqual(response[0].confidence, expected.confidence)
self.assertIsInstance(response[0], Result)
self.assertEqual(len(response[0].alternatives), 1)
alternative = response[0].alternatives[0]

self.assertEqual(alternative.transcript, expected.transcript)
self.assertEqual(alternative.confidence, expected.confidence)

def test_sync_recognize_with_empty_results_no_gax(self):
from google.cloud import speech
Expand Down Expand Up @@ -710,19 +717,26 @@ class _MockGAPICSpeechAPI(object):
_requests = None
_response = None
_results = None

SERVICE_ADDRESS = 'foo.apis.invalid'

def __init__(self, response=None, channel=None):
self._response = response
self._channel = channel

def async_recognize(self, config, audio):
from google.gapic.longrunning.operations_client import OperationsClient
from google.gax import _OperationFuture
from google.longrunning.operations_pb2 import Operation
from google.cloud.proto.speech.v1beta1.cloud_speech_pb2 import (
AsyncRecognizeResponse)

self.config = config
self.audio = audio
operation = Operation()
return operation
operations_client = mock.Mock(spec=OperationsClient)
operation_future = _OperationFuture(Operation(), operations_client,
AsyncRecognizeResponse, {})
return operation_future

def sync_recognize(self, config, audio):
self.config = config
Expand Down
22 changes: 16 additions & 6 deletions system_tests/speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ def test_sync_recognize_local_file(self):

results = self._make_sync_request(content=content,
max_alternatives=2)
self._check_results(results, 2)
self.assertEqual(len(results), 1)
alternatives = results[0].alternatives
self.assertEqual(len(alternatives), 2)
self._check_results(alternatives, 2)

def test_sync_recognize_gcs_file(self):
bucket_name = Config.TEST_BUCKET.name
Expand All @@ -155,9 +158,10 @@ def test_sync_recognize_gcs_file(self):
blob.upload_from_file(file_obj)

source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
result = self._make_sync_request(source_uri=source_uri,
max_alternatives=1)
self._check_results(result)
results = self._make_sync_request(source_uri=source_uri,
max_alternatives=1)
self.assertEqual(len(results), 1)
self._check_results(results[0].alternatives)

def test_async_recognize_local_file(self):
with open(AUDIO_FILE, 'rb') as file_obj:
Expand All @@ -167,7 +171,10 @@ def test_async_recognize_local_file(self):
max_alternatives=2)

_wait_until_complete(operation)
self._check_results(operation.results, 2)
self.assertEqual(len(operation.results), 1)
alternatives = operation.results[0].alternatives
self.assertEqual(len(alternatives), 2)
self._check_results(alternatives, 2)

def test_async_recognize_gcs_file(self):
bucket_name = Config.TEST_BUCKET.name
Expand All @@ -182,7 +189,10 @@ def test_async_recognize_gcs_file(self):
max_alternatives=2)

_wait_until_complete(operation)
self._check_results(operation.results, 2)
self.assertEqual(len(operation.results), 1)
alternatives = operation.results[0].alternatives
self.assertEqual(len(alternatives), 2)
self._check_results(alternatives, 2)

def test_stream_recognize(self):
if not Config.USE_GAX:
Expand Down

0 comments on commit de65409

Please sign in to comment.