Skip to content

Commit d337f84

Browse files
committed
Add GAPIC support for safe search.
1 parent 6eed70a commit d337f84

File tree

9 files changed

+144
-45
lines changed

9 files changed

+144
-45
lines changed

docs/vision-usage.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,7 @@ categorize the entire contents of the image under four categories.
224224
>>> client = vision.Client()
225225
>>> with open('./image.jpg', 'rb') as image_file:
226226
... image = client.image(content=image_file.read())
227-
>>> safe_search_results = image.detect_safe_search()
228-
>>> safe_search = safe_search_results[0]
227+
>>> safe_search = image.detect_safe_search()
229228
>>> safe_search.adult
230229
<Likelihood.VERY_UNLIKELY: 'VERY_UNLIKELY'>
231230
>>> safe_search.spoof

system_tests/vision.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -364,19 +364,13 @@ def _assert_safe_search(self, safe_search):
364364
self._assert_likelihood(safe_search.violence)
365365

366366
def test_detect_safe_search_content(self):
367-
self._pb_not_implemented_skip(
368-
'gRPC not implemented for safe search detection.')
369367
client = Config.CLIENT
370368
with open(FACE_FILE, 'rb') as image_file:
371369
image = client.image(content=image_file.read())
372-
safe_searches = image.detect_safe_search()
373-
self.assertEqual(len(safe_searches), 1)
374-
safe_search = safe_searches[0]
370+
safe_search = image.detect_safe_search()
375371
self._assert_safe_search(safe_search)
376372

377373
def test_detect_safe_search_gcs(self):
378-
self._pb_not_implemented_skip(
379-
'gRPC not implemented for safe search detection.')
380374
bucket_name = Config.TEST_BUCKET.name
381375
blob_name = 'faces.jpg'
382376
blob = Config.TEST_BUCKET.blob(blob_name)
@@ -388,19 +382,13 @@ def test_detect_safe_search_gcs(self):
388382

389383
client = Config.CLIENT
390384
image = client.image(source_uri=source_uri)
391-
safe_searches = image.detect_safe_search()
392-
self.assertEqual(len(safe_searches), 1)
393-
safe_search = safe_searches[0]
385+
safe_search = image.detect_safe_search()
394386
self._assert_safe_search(safe_search)
395387

396388
def test_detect_safe_search_filename(self):
397-
self._pb_not_implemented_skip(
398-
'gRPC not implemented for safe search detection.')
399389
client = Config.CLIENT
400390
image = client.image(filename=FACE_FILE)
401-
safe_searches = image.detect_safe_search()
402-
self.assertEqual(len(safe_searches), 1)
403-
safe_search = safe_searches[0]
391+
safe_search = image.detect_safe_search()
404392
self._assert_safe_search(safe_search)
405393

406394

vision/google/cloud/vision/annotations.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def _process_image_annotations(image):
126126
'logos': _make_entity_from_pb(image.logo_annotations),
127127
'properties': _make_image_properties_from_pb(
128128
image.image_properties_annotation),
129+
'safe_searches': _make_safe_search_from_pb(
130+
image.safe_search_annotation),
129131
'texts': _make_entity_from_pb(image.text_annotations),
130132
}
131133

@@ -170,6 +172,19 @@ def _make_image_properties_from_pb(image_properties):
170172
return ImagePropertiesAnnotation.from_pb(image_properties)
171173

172174

175+
def _make_safe_search_from_pb(safe_search):
176+
"""Create ``SafeSearchAnnotation`` object from a protobuf response.
177+
178+
:type safe_search: :class:`~google.cloud.grpc.vision.v1.\
179+
image_annotator_pb2.SafeSearchAnnotation`
180+
:param safe_search: Protobuf instance of ``SafeSearchAnnotation``.
181+
182+
:rtype: :class: `~google.cloud.vision.safe.SafeSearchAnnotation`
183+
:returns: Instance of ``SafeSearchAnnotation``.
184+
"""
185+
return SafeSearchAnnotation.from_pb(safe_search)
186+
187+
173188
def _entity_from_response_type(feature_type, results):
174189
"""Convert a JSON result to an entity type based on the feature.
175190

vision/google/cloud/vision/face.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,12 @@
1717

1818
from enum import Enum
1919

20-
from google.cloud.grpc.vision.v1 import image_annotator_pb2
21-
2220
from google.cloud.vision.geometry import BoundsBase
21+
from google.cloud.vision.likelihood import get_pb_likelihood
2322
from google.cloud.vision.likelihood import Likelihood
2423
from google.cloud.vision.geometry import Position
2524

2625

27-
def _get_pb_likelihood(likelihood):
28-
"""Convert protobuf Likelihood integer value to Likelihood instance.
29-
30-
:type likelihood: int
31-
:param likelihood: Protobuf integer representing ``Likelihood``.
32-
33-
:rtype: :class:`~google.cloud.vision.likelihood.Likelihood`
34-
:returns: Instance of ``Likelihood`` converted from protobuf value.
35-
"""
36-
likelihood_pb = image_annotator_pb2.Likelihood.Name(likelihood)
37-
return Likelihood[likelihood_pb]
38-
39-
4026
class Angles(object):
4127
"""Angles representing the positions of a face."""
4228
def __init__(self, roll, pan, tilt):
@@ -147,10 +133,10 @@ def from_pb(cls, emotions):
147133
:rtype: :class:`~google.cloud.vision.face.Emotions`
148134
:returns: Populated instance of ``Emotions``.
149135
"""
150-
joy_likelihood = _get_pb_likelihood(emotions.joy_likelihood)
151-
sorrow_likelihood = _get_pb_likelihood(emotions.sorrow_likelihood)
152-
surprise_likelihood = _get_pb_likelihood(emotions.surprise_likelihood)
153-
anger_likelihood = _get_pb_likelihood(emotions.anger_likelihood)
136+
joy_likelihood = get_pb_likelihood(emotions.joy_likelihood)
137+
sorrow_likelihood = get_pb_likelihood(emotions.sorrow_likelihood)
138+
surprise_likelihood = get_pb_likelihood(emotions.surprise_likelihood)
139+
anger_likelihood = get_pb_likelihood(emotions.anger_likelihood)
154140

155141
return cls(joy_likelihood, sorrow_likelihood, surprise_likelihood,
156142
anger_likelihood)
@@ -252,7 +238,7 @@ def from_pb(cls, face):
252238
'detection_confidence': face.detection_confidence,
253239
'emotions': Emotions.from_pb(face),
254240
'fd_bounds': FDBounds.from_pb(face.fd_bounding_poly),
255-
'headwear_likelihood': _get_pb_likelihood(
241+
'headwear_likelihood': get_pb_likelihood(
256242
face.headwear_likelihood),
257243
'image_properties': FaceImageProperties.from_pb(face),
258244
'landmarks': Landmarks.from_pb(face.landmarks),
@@ -418,8 +404,8 @@ def from_pb(cls, face):
418404
:rtype: :class:`~google.cloud.vision.face.FaceImageProperties`
419405
:returns: Instance populated with image property data.
420406
"""
421-
blurred = _get_pb_likelihood(face.blurred_likelihood)
422-
underexposed = _get_pb_likelihood(face.under_exposed_likelihood)
407+
blurred = get_pb_likelihood(face.blurred_likelihood)
408+
underexposed = get_pb_likelihood(face.under_exposed_likelihood)
423409

424410
return cls(blurred, underexposed)
425411

vision/google/cloud/vision/likelihood.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,21 @@
1717

1818
from enum import Enum
1919

20+
from google.cloud.grpc.vision.v1 import image_annotator_pb2
21+
22+
23+
def get_pb_likelihood(likelihood):
24+
"""Convert protobuf Likelihood integer value to Likelihood enum.
25+
26+
:type likelihood: int
27+
:param likelihood: Protobuf integer representing ``Likelihood``.
28+
29+
:rtype: :class:`~google.cloud.vision.likelihood.Likelihood`
30+
:returns: Enum ``Likelihood`` converted from protobuf value.
31+
"""
32+
likelihood_pb = image_annotator_pb2.Likelihood.Name(likelihood)
33+
return Likelihood[likelihood_pb]
34+
2035

2136
class Likelihood(Enum):
2237
"""A representation of likelihood to give stable results across upgrades.

vision/google/cloud/vision/safe.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Safe search class for information returned from annotating an image."""
1616

17-
17+
from google.cloud.vision.likelihood import get_pb_likelihood
1818
from google.cloud.vision.likelihood import Likelihood
1919

2020

@@ -54,14 +54,30 @@ def from_api_repr(cls, response):
5454
:rtype: :class:`~google.cloud.vision.safe.SafeSearchAnnotation`
5555
:returns: Instance of ``SafeSearchAnnotation``.
5656
"""
57-
adult_likelihood = getattr(Likelihood, response['adult'])
58-
spoof_likelihood = getattr(Likelihood, response['spoof'])
59-
medical_likelihood = getattr(Likelihood, response['medical'])
60-
violence_likelihood = getattr(Likelihood, response['violence'])
57+
adult_likelihood = Likelihood[response['adult']]
58+
spoof_likelihood = Likelihood[response['spoof']]
59+
medical_likelihood = Likelihood[response['medical']]
60+
violence_likelihood = Likelihood[response['violence']]
6161

6262
return cls(adult_likelihood, spoof_likelihood, medical_likelihood,
6363
violence_likelihood)
6464

65+
@classmethod
66+
def from_pb(cls, image):
67+
"""Factory: construct SafeSearchAnnotation from Vision API response.
68+
69+
:type image: :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.\
70+
SafeSearchAnnotation`
71+
:param image: Protobuf response from Vision API with safe search data.
72+
73+
:rtype: :class:`~google.cloud.vision.safe.SafeSearchAnnotation`
74+
:returns: Instance of ``SafeSearchAnnotation``.
75+
"""
76+
classifications = map(get_pb_likelihood, [image.adult, image.spoof,
77+
image.medical,
78+
image.violence])
79+
return cls(*classifications)
80+
6581
@property
6682
def adult(self):
6783
"""Represents the adult contents likelihood for the image.

vision/unit_tests/test_annotations.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def test_ctor(self):
6767
self.assertEqual(annotations.texts, [True])
6868

6969
def test_from_pb(self):
70+
from google.cloud.vision.likelihood import Likelihood
71+
from google.cloud.vision.safe import SafeSearchAnnotation
7072
from google.cloud.grpc.vision.v1 import image_annotator_pb2
7173

7274
image_response = image_annotator_pb2.AnnotateImageResponse()
@@ -76,9 +78,16 @@ def test_from_pb(self):
7678
self.assertEqual(annotations.faces, [])
7779
self.assertEqual(annotations.landmarks, [])
7880
self.assertEqual(annotations.texts, [])
79-
self.assertEqual(annotations.safe_searches, ())
8081
self.assertIsNone(annotations.properties)
8182

83+
self.assertIsInstance(annotations.safe_searches, SafeSearchAnnotation)
84+
safe_search = annotations.safe_searches
85+
unknown = Likelihood.UNKNOWN
86+
self.assertEqual(safe_search.adult, unknown)
87+
self.assertEqual(safe_search.spoof, unknown)
88+
self.assertEqual(safe_search.medical, unknown)
89+
self.assertEqual(safe_search.violence, unknown)
90+
8291

8392
class Test__make_entity_from_pb(unittest.TestCase):
8493
def _call_fut(self, annotations):

vision/unit_tests/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def test_safe_search_detection_from_source(self):
425425
image_request = client._connection._requested[0]['data']['requests'][0]
426426
self.assertEqual(IMAGE_SOURCE,
427427
image_request['image']['source']['gcs_image_uri'])
428+
428429
self.assertEqual(safe_search.adult, Likelihood.VERY_UNLIKELY)
429430
self.assertEqual(safe_search.spoof, Likelihood.UNLIKELY)
430431
self.assertEqual(safe_search.medical, Likelihood.POSSIBLE)

vision/unit_tests/test_safe.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2016 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
18+
class TestSafeSearchAnnotation(unittest.TestCase):
19+
@staticmethod
20+
def _get_target_class():
21+
from google.cloud.vision.safe import SafeSearchAnnotation
22+
return SafeSearchAnnotation
23+
24+
def test_safe_search_annotation(self):
25+
from google.cloud.vision.likelihood import Likelihood
26+
from unit_tests._fixtures import SAFE_SEARCH_DETECTION_RESPONSE
27+
28+
response = SAFE_SEARCH_DETECTION_RESPONSE['responses'][0]
29+
safe_search_response = response['safeSearchAnnotation']
30+
31+
safe_search = self._get_target_class().from_api_repr(
32+
safe_search_response)
33+
34+
self.assertEqual(safe_search.adult, Likelihood.VERY_UNLIKELY)
35+
self.assertEqual(safe_search.spoof, Likelihood.UNLIKELY)
36+
self.assertEqual(safe_search.medical, Likelihood.POSSIBLE)
37+
self.assertEqual(safe_search.violence, Likelihood.VERY_UNLIKELY)
38+
39+
def test_pb_safe_search_annotation(self):
40+
from google.cloud.vision.likelihood import Likelihood
41+
from google.cloud.grpc.vision.v1.image_annotator_pb2 import (
42+
Likelihood as LikelihoodPB)
43+
from google.cloud.grpc.vision.v1 import image_annotator_pb2
44+
45+
possible = LikelihoodPB.Value('POSSIBLE')
46+
possible_name = Likelihood.POSSIBLE
47+
safe_search_annotation = image_annotator_pb2.SafeSearchAnnotation(
48+
adult=possible, spoof=possible, medical=possible, violence=possible
49+
)
50+
51+
safe_search = self._get_target_class().from_pb(safe_search_annotation)
52+
53+
self.assertEqual(safe_search.adult, possible_name)
54+
self.assertEqual(safe_search.spoof, possible_name)
55+
self.assertEqual(safe_search.medical, possible_name)
56+
self.assertEqual(safe_search.violence, possible_name)
57+
58+
def test_empty_pb_safe_search_annotation(self):
59+
from google.cloud.vision.likelihood import Likelihood
60+
from google.cloud.grpc.vision.v1 import image_annotator_pb2
61+
62+
unknown = Likelihood.UNKNOWN
63+
safe_search_annotation = image_annotator_pb2.SafeSearchAnnotation()
64+
65+
safe_search = self._get_target_class().from_pb(safe_search_annotation)
66+
67+
self.assertEqual(safe_search.adult, unknown)
68+
self.assertEqual(safe_search.spoof, unknown)
69+
self.assertEqual(safe_search.medical, unknown)
70+
self.assertEqual(safe_search.violence, unknown)

0 commit comments

Comments
 (0)