Skip to content

Commit

Permalink
Merge pull request #2922 from daspecster/vision-add-face-from-pb
Browse files Browse the repository at this point in the history
Vision: Add gRPC support for face detection.
  • Loading branch information
daspecster authored Jan 19, 2017
2 parents 4f5f782 + 04517cd commit 309fa8d
Show file tree
Hide file tree
Showing 13 changed files with 388 additions and 135 deletions.
29 changes: 16 additions & 13 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class BaseVisionTestCase(unittest.TestCase):
def _assert_coordinate(self, coordinate):
if coordinate is None:
return
self.assertIsNotNone(coordinate)
self.assertIsInstance(coordinate, (int, float))
self.assertNotEqual(coordinate, 0.0)

def _assert_likelihood(self, likelihood):
from google.cloud.vision.likelihood import Likelihood
Expand All @@ -73,8 +73,8 @@ def _assert_likelihood(self, likelihood):
Likelihood.VERY_UNLIKELY]
self.assertIn(likelihood, levels)

def _maybe_http_skip(self, message):
if not Config.CLIENT._use_gax:
def _pb_not_implemented_skip(self, message):
if Config.CLIENT._use_gax:
self.skipTest(message)


Expand Down Expand Up @@ -150,7 +150,7 @@ def _assert_landmarks(self, landmarks):

for landmark in LandmarkTypes:
if landmark is not LandmarkTypes.UNKNOWN_LANDMARK:
feature = getattr(landmarks, landmark.value.lower())
feature = getattr(landmarks, landmark.name.lower())
self.assertIsInstance(feature, Landmark)
self.assertIsInstance(feature.position, Position)
self._assert_coordinate(feature.position.x_coordinate)
Expand Down Expand Up @@ -194,7 +194,6 @@ def _assert_face(self, face):

def test_detect_faces_content(self):
client = Config.CLIENT
self._maybe_http_skip('gRPC is required for face detection.')
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
faces = image.detect_faces()
Expand All @@ -203,7 +202,6 @@ def test_detect_faces_content(self):
self._assert_face(face)

def test_detect_faces_gcs(self):
self._maybe_http_skip('gRPC is required for face detection.')
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
Expand All @@ -220,7 +218,6 @@ def test_detect_faces_gcs(self):
self._assert_face(face)

def test_detect_faces_filename(self):
self._maybe_http_skip('gRPC is required for face detection.')
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
faces = image.detect_faces()
Expand Down Expand Up @@ -367,7 +364,8 @@ def _assert_safe_search(self, safe_search):
self._assert_likelihood(safe_search.violence)

def test_detect_safe_search_content(self):
self._maybe_http_skip('gRPC is required for safe search detection.')
self._pb_not_implemented_skip(
'gRPC not implemented for safe search detection.')
client = Config.CLIENT
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
Expand All @@ -377,7 +375,8 @@ def test_detect_safe_search_content(self):
self._assert_safe_search(safe_search)

def test_detect_safe_search_gcs(self):
self._maybe_http_skip('gRPC is required for safe search detection.')
self._pb_not_implemented_skip(
'gRPC not implemented for safe search detection.')
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
Expand All @@ -395,7 +394,8 @@ def test_detect_safe_search_gcs(self):
self._assert_safe_search(safe_search)

def test_detect_safe_search_filename(self):
self._maybe_http_skip('gRPC is required for safe search detection.')
self._pb_not_implemented_skip(
'gRPC not implemented for safe search detection.')
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
safe_searches = image.detect_safe_search()
Expand Down Expand Up @@ -493,7 +493,8 @@ def _assert_properties(self, image_property):
self.assertNotEqual(color_info.score, 0.0)

def test_detect_properties_content(self):
self._maybe_http_skip('gRPC is required for text detection.')
self._pb_not_implemented_skip(
'gRPC not implemented for image properties detection.')
client = Config.CLIENT
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
Expand All @@ -503,7 +504,8 @@ def test_detect_properties_content(self):
self._assert_properties(image_property)

def test_detect_properties_gcs(self):
self._maybe_http_skip('gRPC is required for text detection.')
self._pb_not_implemented_skip(
'gRPC not implemented for image properties detection.')
client = Config.CLIENT
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
Expand All @@ -521,7 +523,8 @@ def test_detect_properties_gcs(self):
self._assert_properties(image_property)

def test_detect_properties_filename(self):
self._maybe_http_skip('gRPC is required for text detection.')
self._pb_not_implemented_skip(
'gRPC not implemented for image properties detection.')
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
properties = image.detect_properties()
Expand Down
4 changes: 1 addition & 3 deletions vision/google/cloud/vision/_gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from google.cloud.gapic.vision.v1 import image_annotator_client
from google.cloud.grpc.vision.v1 import image_annotator_pb2

from google.cloud._helpers import _to_bytes

from google.cloud.vision.annotations import Annotations


Expand Down Expand Up @@ -85,7 +83,7 @@ def _to_gapic_image(image):
:class:`~google.cloud.vision.image.Image`.
"""
if image.content is not None:
return image_annotator_pb2.Image(content=_to_bytes(image.content))
return image_annotator_pb2.Image(content=image.content)
if image.source is not None:
return image_annotator_pb2.Image(
source=image_annotator_pb2.ImageSource(
Expand Down
14 changes: 14 additions & 0 deletions vision/google/cloud/vision/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _process_image_annotations(image):
:returns: Dictionary populated with entities from response.
"""
return {
'faces': _make_faces_from_pb(image.face_annotations),
'labels': _make_entity_from_pb(image.label_annotations),
'landmarks': _make_entity_from_pb(image.landmark_annotations),
'logos': _make_entity_from_pb(image.logo_annotations),
Expand All @@ -139,6 +140,19 @@ def _make_entity_from_pb(annotations):
return [EntityAnnotation.from_pb(annotation) for annotation in annotations]


def _make_faces_from_pb(faces):
"""Create face objects from a gRPC response.
:type faces:
:class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.FaceAnnotation`
:param faces: Protobuf instance of ``FaceAnnotation``.
:rtype: list
:returns: List of ``Face``.
"""
return [Face.from_pb(face) for face in faces]


def _entity_from_response_type(feature_type, results):
"""Convert a JSON result to an entity type based on the feature.
Expand Down
2 changes: 1 addition & 1 deletion vision/google/cloud/vision/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Client(JSONClient):
_vision_api_internal = None

def __init__(self, project=None, credentials=None, http=None,
use_gax=False):
use_gax=None):
super(Client, self).__init__(
project=project, credentials=credentials, http=http)
self._connection = Connection(
Expand Down
Loading

0 comments on commit 309fa8d

Please sign in to comment.