Skip to content

Commit 885397b

Browse files
authored
Merge pull request #2918 from daspecster/vision-add-gapic-entity-annotation
Add gax support for entity annotations.
2 parents 40fd881 + 7ff033d commit 885397b

File tree

12 files changed

+455
-29
lines changed

12 files changed

+455
-29
lines changed

system_tests/vision.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def _assert_likelihood(self, likelihood):
7373
Likelihood.VERY_UNLIKELY]
7474
self.assertIn(likelihood, levels)
7575

76+
def _maybe_http_skip(self, message):
77+
if not Config.CLIENT._use_gax:
78+
self.skipTest(message)
79+
7680

7781
class TestVisionClientLogo(unittest.TestCase):
7882
def setUp(self):
@@ -190,6 +194,7 @@ def _assert_face(self, face):
190194

191195
def test_detect_faces_content(self):
192196
client = Config.CLIENT
197+
self._maybe_http_skip('gRPC is required for face detection.')
193198
with open(FACE_FILE, 'rb') as image_file:
194199
image = client.image(content=image_file.read())
195200
faces = image.detect_faces()
@@ -198,6 +203,7 @@ def test_detect_faces_content(self):
198203
self._assert_face(face)
199204

200205
def test_detect_faces_gcs(self):
206+
self._maybe_http_skip('gRPC is required for face detection.')
201207
bucket_name = Config.TEST_BUCKET.name
202208
blob_name = 'faces.jpg'
203209
blob = Config.TEST_BUCKET.blob(blob_name)
@@ -206,7 +212,6 @@ def test_detect_faces_gcs(self):
206212
blob.upload_from_file(file_obj)
207213

208214
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
209-
210215
client = Config.CLIENT
211216
image = client.image(source_uri=source_uri)
212217
faces = image.detect_faces()
@@ -215,6 +220,7 @@ def test_detect_faces_gcs(self):
215220
self._assert_face(face)
216221

217222
def test_detect_faces_filename(self):
223+
self._maybe_http_skip('gRPC is required for face detection.')
218224
client = Config.CLIENT
219225
image = client.image(filename=FACE_FILE)
220226
faces = image.detect_faces()
@@ -361,6 +367,7 @@ def _assert_safe_search(self, safe_search):
361367
self._assert_likelihood(safe_search.violence)
362368

363369
def test_detect_safe_search_content(self):
370+
self._maybe_http_skip('gRPC is required for safe search detection.')
364371
client = Config.CLIENT
365372
with open(FACE_FILE, 'rb') as image_file:
366373
image = client.image(content=image_file.read())
@@ -370,6 +377,7 @@ def test_detect_safe_search_content(self):
370377
self._assert_safe_search(safe_search)
371378

372379
def test_detect_safe_search_gcs(self):
380+
self._maybe_http_skip('gRPC is required for safe search detection.')
373381
bucket_name = Config.TEST_BUCKET.name
374382
blob_name = 'faces.jpg'
375383
blob = Config.TEST_BUCKET.blob(blob_name)
@@ -387,6 +395,7 @@ def test_detect_safe_search_gcs(self):
387395
self._assert_safe_search(safe_search)
388396

389397
def test_detect_safe_search_filename(self):
398+
self._maybe_http_skip('gRPC is required for safe search detection.')
390399
client = Config.CLIENT
391400
image = client.image(filename=FACE_FILE)
392401
safe_searches = image.detect_safe_search()
@@ -484,6 +493,7 @@ def _assert_properties(self, image_property):
484493
self.assertNotEqual(color_info.score, 0.0)
485494

486495
def test_detect_properties_content(self):
496+
self._maybe_http_skip('gRPC is required for text detection.')
487497
client = Config.CLIENT
488498
with open(FACE_FILE, 'rb') as image_file:
489499
image = client.image(content=image_file.read())
@@ -493,6 +503,8 @@ def test_detect_properties_content(self):
493503
self._assert_properties(image_property)
494504

495505
def test_detect_properties_gcs(self):
506+
self._maybe_http_skip('gRPC is required for text detection.')
507+
client = Config.CLIENT
496508
bucket_name = Config.TEST_BUCKET.name
497509
blob_name = 'faces.jpg'
498510
blob = Config.TEST_BUCKET.blob(blob_name)
@@ -502,14 +514,14 @@ def test_detect_properties_gcs(self):
502514

503515
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
504516

505-
client = Config.CLIENT
506517
image = client.image(source_uri=source_uri)
507518
properties = image.detect_properties()
508519
self.assertEqual(len(properties), 1)
509520
image_property = properties[0]
510521
self._assert_properties(image_property)
511522

512523
def test_detect_properties_filename(self):
524+
self._maybe_http_skip('gRPC is required for text detection.')
513525
client = Config.CLIENT
514526
image = client.image(filename=FACE_FILE)
515527
properties = image.detect_properties()

vision/google/cloud/vision/_gax.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from google.cloud._helpers import _to_bytes
2121

22+
from google.cloud.vision.annotations import Annotations
23+
2224

2325
class _GAPICVisionAPI(object):
2426
"""Vision API for interacting with the gRPC version of Vision.
@@ -28,7 +30,32 @@ class _GAPICVisionAPI(object):
2830
"""
2931
def __init__(self, client=None):
3032
self._client = client
31-
self._api = image_annotator_client.ImageAnnotatorClient()
33+
self._annotator_client = image_annotator_client.ImageAnnotatorClient()
34+
35+
def annotate(self, image, features):
36+
"""Annotate images through GAX.
37+
38+
:type image: :class:`~google.cloud.vision.image.Image`
39+
:param image: Instance of ``Image``.
40+
41+
:type features: list
42+
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
43+
44+
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
45+
:returns: Instance of ``Annotations`` with results or ``None``.
46+
"""
47+
gapic_features = [_to_gapic_feature(feature) for feature in features]
48+
gapic_image = _to_gapic_image(image)
49+
request = image_annotator_pb2.AnnotateImageRequest(
50+
image=gapic_image, features=gapic_features)
51+
requests = [request]
52+
annotator_client = self._annotator_client
53+
images = annotator_client.batch_annotate_images(requests)
54+
if len(images.responses) == 1:
55+
return Annotations.from_pb(images.responses[0])
56+
elif len(images.responses) > 1:
57+
raise NotImplementedError(
58+
'Multiple image processing is not yet supported.')
3259

3360

3461
def _to_gapic_feature(feature):

vision/google/cloud/vision/_http.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""HTTP Client for interacting with the Google Cloud Vision API."""
1616

17+
from google.cloud.vision.annotations import Annotations
1718
from google.cloud.vision.feature import Feature
1819

1920

@@ -48,8 +49,12 @@ def annotate(self, image, features):
4849
data = {'requests': [request]}
4950
api_response = self._connection.api_request(
5051
method='POST', path='/images:annotate', data=data)
51-
responses = api_response.get('responses')
52-
return responses[0]
52+
images = api_response.get('responses')
53+
if len(images) == 1:
54+
return Annotations.from_api_repr(images[0])
55+
elif len(images) > 1:
56+
raise NotImplementedError(
57+
'Multiple image processing is not yet supported.')
5358

5459

5560
def _make_request(image, features):

vision/google/cloud/vision/annotations.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,51 @@ def from_api_repr(cls, response):
9393
_entity_from_response_type(feature_type, annotation))
9494
return cls(**annotations)
9595

96+
@classmethod
97+
def from_pb(cls, response):
98+
"""Factory: construct an instance of ``Annotations`` from protobuf.
99+
100+
:type response: :class:`~google.cloud.grpc.vision.v1.\
101+
image_annotator_pb2.AnnotateImageResponse`
102+
:param response: ``AnnotateImageResponse`` from protobuf call.
103+
104+
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
105+
:returns: ``Annotations`` instance populated from gRPC response.
106+
"""
107+
annotations = _process_image_annotations(response)
108+
return cls(**annotations)
109+
110+
111+
def _process_image_annotations(image):
112+
"""Helper for processing annotation types from protobuf.
113+
114+
:type image: :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.\
115+
AnnotateImageResponse`
116+
:param image: ``AnnotateImageResponse`` from protobuf.
117+
118+
:rtype: dict
119+
:returns: Dictionary populated with entities from response.
120+
"""
121+
return {
122+
'labels': _make_entity_from_pb(image.label_annotations),
123+
'landmarks': _make_entity_from_pb(image.landmark_annotations),
124+
'logos': _make_entity_from_pb(image.logo_annotations),
125+
'texts': _make_entity_from_pb(image.text_annotations),
126+
}
127+
128+
129+
def _make_entity_from_pb(annotations):
130+
"""Create an entity from a gRPC response.
131+
132+
:type annotations:
133+
:class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.EntityAnnotation`
134+
:param annotations: protobuf instance of ``EntityAnnotation``.
135+
136+
:rtype: list
137+
:returns: List of ``EntityAnnotation``.
138+
"""
139+
return [EntityAnnotation.from_pb(annotation) for annotation in annotations]
140+
96141

97142
def _entity_from_response_type(feature_type, results):
98143
"""Convert a JSON result to an entity type based on the feature.

vision/google/cloud/vision/entity.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,32 @@ def from_api_repr(cls, response):
6464
description = response['description']
6565
locale = response.get('locale', None)
6666
locations = [LocationInformation.from_api_repr(location)
67-
for location in response.get('locations', [])]
67+
for location in response.get('locations', ())]
6868
mid = response.get('mid', None)
6969
score = response.get('score', None)
7070

7171
return cls(bounds, description, locale, locations, mid, score)
7272

73+
@classmethod
74+
def from_pb(cls, response):
75+
"""Factory: construct entity from Vision gRPC response.
76+
77+
:type response: :class:`~google.cloud.grpc.vision.v1.\
78+
image_annotator_pb2.AnnotateImageResponse`
79+
:param response: gRPC response from Vision API with entity data.
80+
81+
:rtype: :class:`~google.cloud.vision.entity.EntityAnnotation`
82+
:returns: Instance of ``EntityAnnotation``.
83+
"""
84+
bounds = Bounds.from_pb(response.bounding_poly)
85+
description = response.description
86+
locale = response.locale
87+
locations = [LocationInformation.from_pb(location)
88+
for location in response.locations]
89+
mid = response.mid
90+
score = response.score
91+
return cls(bounds, description, locale, locations, mid, score)
92+
7393
@property
7494
def bounds(self):
7595
"""Bounding polygon of detected image feature.

vision/google/cloud/vision/geometry.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,33 @@ def __init__(self, vertices):
2525
self._vertices = vertices
2626

2727
@classmethod
28-
def from_api_repr(cls, response_vertices):
28+
def from_api_repr(cls, vertices):
2929
"""Factory: construct BoundsBase instance from Vision API response.
3030
31-
:type response_vertices: dict
32-
:param response_vertices: List of vertices.
31+
:type vertices: dict
32+
:param vertices: List of vertices.
3333
3434
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
3535
:returns: Instance of BoundsBase with populated verticies or None.
3636
"""
37-
if not response_vertices:
37+
if vertices is None:
3838
return None
39+
return cls([Vertex(vertex.get('x', None), vertex.get('y', None))
40+
for vertex in vertices.get('vertices', ())])
3941

40-
vertices = [Vertex(vertex.get('x', None), vertex.get('y', None)) for
41-
vertex in response_vertices.get('vertices', [])]
42-
return cls(vertices)
42+
@classmethod
43+
def from_pb(cls, vertices):
44+
"""Factory: construct BoundsBase instance from Vision gRPC response.
45+
46+
:type vertices: :class:`~google.cloud.grpc.vision.v1.\
47+
geometry_pb2.BoundingPoly`
48+
:param vertices: List of vertices.
49+
50+
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
51+
:returns: Instance of ``BoundsBase`` with populated verticies.
52+
"""
53+
return cls([Vertex(vertex.x, vertex.y)
54+
for vertex in vertices.vertices])
4355

4456
@property
4557
def vertices(self):
@@ -73,20 +85,35 @@ def __init__(self, latitude, longitude):
7385
self._longitude = longitude
7486

7587
@classmethod
76-
def from_api_repr(cls, response):
88+
def from_api_repr(cls, location_info):
7789
"""Factory: construct location information from Vision API response.
7890
79-
:type response: dict
80-
:param response: Dictionary response of locations.
91+
:type location_info: dict
92+
:param location_info: Dictionary response of locations.
8193
8294
:rtype: :class:`~google.cloud.vision.geometry.LocationInformation`
8395
:returns: ``LocationInformation`` with populated latitude and
8496
longitude.
8597
"""
86-
latitude = response['latLng']['latitude']
87-
longitude = response['latLng']['longitude']
98+
lat_long = location_info.get('latLng', {})
99+
latitude = lat_long.get('latitude')
100+
longitude = lat_long.get('longitude')
88101
return cls(latitude, longitude)
89102

103+
@classmethod
104+
def from_pb(cls, location_info):
105+
"""Factory: construct location information from Vision gRPC response.
106+
107+
:type location_info: :class:`~google.cloud.vision.v1.LocationInfo`
108+
:param location_info: gRPC response of ``LocationInfo``.
109+
110+
:rtype: :class:`~google.cloud.vision.geometry.LocationInformation`
111+
:returns: ``LocationInformation`` with populated latitude and
112+
longitude.
113+
"""
114+
return cls(location_info.lat_lng.latitude,
115+
location_info.lat_lng.longitude)
116+
90117
@property
91118
def latitude(self):
92119
"""Latitude coordinate.
@@ -127,15 +154,18 @@ def __init__(self, x_coordinate, y_coordinate, z_coordinate):
127154
self._z_coordinate = z_coordinate
128155

129156
@classmethod
130-
def from_api_repr(cls, response_position):
157+
def from_api_repr(cls, position):
131158
"""Factory: construct 3D position from API response.
132159
160+
:type position: dict
161+
:param position: Dictionary with 3 axis position data.
162+
133163
:rtype: :class:`~google.cloud.vision.geometry.Position`
134164
:returns: `Position` constructed with 3D points from API response.
135165
"""
136-
x_coordinate = response_position['x']
137-
y_coordinate = response_position['y']
138-
z_coordinate = response_position['z']
166+
x_coordinate = position['x']
167+
y_coordinate = position['y']
168+
z_coordinate = position['z']
139169
return cls(x_coordinate, y_coordinate, z_coordinate)
140170

141171
@property

vision/google/cloud/vision/image.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from google.cloud._helpers import _to_bytes
2121
from google.cloud._helpers import _bytes_to_unicode
22-
from google.cloud.vision.annotations import Annotations
2322
from google.cloud.vision.feature import Feature
2423
from google.cloud.vision.feature import FeatureTypes
2524

@@ -109,8 +108,7 @@ def _detect_annotation(self, features):
109108
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
110109
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
111110
"""
112-
results = self.client._vision_api.annotate(self, features)
113-
return Annotations.from_api_repr(results)
111+
return self.client._vision_api.annotate(self, features)
114112

115113
def detect(self, features):
116114
"""Detect multiple feature types.

0 commit comments

Comments
 (0)