Skip to content

Commit

Permalink
Vision: Add batch processing (#2978)
Browse files Browse the repository at this point in the history
* Add Vision batch support to the surface.
  • Loading branch information
daspecster authored Feb 9, 2017
1 parent 4d2a7d1 commit db0dc85
Show file tree
Hide file tree
Showing 13 changed files with 289 additions and 46 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@

vision-usage
vision-annotations
vision-batch
vision-client
vision-color
vision-entity
Expand Down
10 changes: 10 additions & 0 deletions docs/vision-batch.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Vision Batch
============

Batch
~~~~~

.. automodule:: google.cloud.vision.batch
:members:
:undoc-members:
:show-inheritance:
34 changes: 34 additions & 0 deletions docs/vision-usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,40 @@ image and determine the dominant colors in the image.
0.758658
*********************
Batch image detection
*********************

Multiple images can be processed with a single request by passing
:class:`~google.cloud.vision.image.Image` to
:meth:`~google.cloud.vision.client.Client.batch()`.

.. code-block:: python
>>> from google.cloud import vision
>>> from google.cloud.vision.feature import Feature
>>> from google.cloud.vision.feature import FeatureTypes
>>>
>>> client = vision.Client()
>>> batch = client.batch()
>>>
>>> image_one = client.image(source_uri='gs://my-test-bucket/image1.jpg')
>>> image_two = client.image(source_uri='gs://my-test-bucket/image2.jpg')
>>> face_feature = Feature(FeatureTypes.FACE_DETECTION, 2)
>>> logo_feature = Feature(FeatureTypes.LOGO_DETECTION, 2)
>>> batch.add_image(image_one, [face_feature, logo_feature])
>>> batch.add_image(image_two, [logo_feature])
>>> results = batch.detect()
>>> for image in results:
... for face in image.faces:
... print('=' * 40)
... print(face.joy)
========================================
<Likelihood.VERY_LIKELY: 'VERY_LIKELY'>
========================================
<Likelihood.VERY_LIKELY: 'POSSIBLE'>
****************
No results found
****************
Expand Down
52 changes: 52 additions & 0 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from google.cloud import storage
from google.cloud import vision
from google.cloud.vision.entity import EntityAnnotation
from google.cloud.vision.feature import Feature
from google.cloud.vision.feature import FeatureTypes

from system_test_utils import unique_resource_id
from retry import RetryErrors
Expand Down Expand Up @@ -507,3 +509,53 @@ def test_detect_properties_filename(self):
image = client.image(filename=FACE_FILE)
properties = image.detect_properties()
self._assert_properties(properties)


class TestVisionBatchProcessing(BaseVisionTestCase):
def setUp(self):
self.to_delete_by_case = []

def tearDown(self):
for value in self.to_delete_by_case:
value.delete()

def test_batch_detect_gcs(self):
client = Config.CLIENT
bucket_name = Config.TEST_BUCKET.name

# Logo GCS image.
blob_name = 'logos.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
self.to_delete_by_case.append(blob) # Clean-up.
with open(LOGO_FILE, 'rb') as file_obj:
blob.upload_from_file(file_obj)

logo_source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

image_one = client.image(source_uri=logo_source_uri)
logo_feature = Feature(FeatureTypes.LOGO_DETECTION, 2)

# Faces GCS image.
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
self.to_delete_by_case.append(blob) # Clean-up.
with open(FACE_FILE, 'rb') as file_obj:
blob.upload_from_file(file_obj)

face_source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

image_two = client.image(source_uri=face_source_uri)
face_feature = Feature(FeatureTypes.FACE_DETECTION, 2)

batch = client.batch()
batch.add_image(image_one, [logo_feature])
batch.add_image(image_two, [face_feature, logo_feature])
results = batch.detect()
self.assertEqual(len(results), 2)
self.assertIsInstance(results[0], vision.annotations.Annotations)
self.assertIsInstance(results[1], vision.annotations.Annotations)
self.assertEqual(len(results[0].logos), 1)
self.assertEqual(len(results[0].faces), 0)

self.assertEqual(len(results[1].logos), 0)
self.assertEqual(len(results[1].faces), 2)
26 changes: 15 additions & 11 deletions vision/google/cloud/vision/_gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,28 @@ def __init__(self, client=None):
self._client = client
self._annotator_client = image_annotator_client.ImageAnnotatorClient()

def annotate(self, image, features):
def annotate(self, images):
"""Annotate images through GAX.
:type image: :class:`~google.cloud.vision.image.Image`
:param image: Instance of ``Image``.
:type features: list
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
:type images: list
:param images: List containing pairs of
:class:`~google.cloud.vision.image.Image` and
:class:`~google.cloud.vision.feature.Feature`.
e.g. [(image, [feature_one, feature_two]),]
:rtype: list
:returns: List of
:class:`~google.cloud.vision.annotations.Annotations`.
"""
gapic_features = [_to_gapic_feature(feature) for feature in features]
gapic_image = _to_gapic_image(image)
request = image_annotator_pb2.AnnotateImageRequest(
image=gapic_image, features=gapic_features)
requests = [request]
requests = []
for image, features in images:
gapic_features = [_to_gapic_feature(feature)
for feature in features]
gapic_image = _to_gapic_image(image)
request = image_annotator_pb2.AnnotateImageRequest(
image=gapic_image, features=gapic_features)
requests.append(request)

annotator_client = self._annotator_client
responses = annotator_client.batch_annotate_images(requests).responses
return [Annotations.from_pb(response) for response in responses]
Expand Down
19 changes: 7 additions & 12 deletions vision/google/cloud/vision/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,19 @@ def __init__(self, client):
self._client = client
self._connection = client._connection

def annotate(self, image, features):
def annotate(self, images):
"""Annotate an image to discover it's attributes.
:type image: :class:`~google.cloud.vision.image.Image`
:param image: A instance of ``Image``.
:type images: list of :class:`~google.cloud.vision.image.Image`
:param images: A list of ``Image``.
:type features: list of :class:`~google.cloud.vision.feature.Feature`
:param features: The type of detection that the Vision API should
use to determine image attributes. Pricing is
based on the number of Feature Types.
See: https://cloud.google.com/vision/docs/pricing
:rtype: list
:returns: List of :class:`~googe.cloud.vision.annotations.Annotations`.
"""
request = _make_request(image, features)

data = {'requests': [request]}
requests = []
for image, features in images:
requests.append(_make_request(image, features))
data = {'requests': requests}
api_response = self._connection.api_request(
method='POST', path='/images:annotate', data=data)
responses = api_response.get('responses')
Expand Down
57 changes: 57 additions & 0 deletions vision/google/cloud/vision/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Batch multiple images into one request."""


class Batch(object):
"""Batch of images to process.
:type client: :class:`~google.cloud.vision.client.Client`
:param client: Vision client.
"""
def __init__(self, client):
self._client = client
self._images = []

def add_image(self, image, features):
"""Add image to batch request.
:type image: :class:`~google.cloud.vision.image.Image`
:param image: Istance of ``Image``.
:type features: list
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
"""
self._images.append((image, features))

@property
def images(self):
"""List of images to process.
:rtype: list
:returns: List of :class:`~google.cloud.vision.image.Image`.
"""
return self._images

def detect(self):
"""Perform batch detection of images.
:rtype: list
:returns: List of
:class:`~google.cloud.vision.annotations.Annotations`.
"""
results = self._client._vision_api.annotate(self.images)
self._images = []
return results
9 changes: 9 additions & 0 deletions vision/google/cloud/vision/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from google.cloud.environment_vars import DISABLE_GRPC

from google.cloud.vision._gax import _GAPICVisionAPI
from google.cloud.vision.batch import Batch
from google.cloud.vision.connection import Connection
from google.cloud.vision.image import Image
from google.cloud.vision._http import _HTTPVisionAPI
Expand Down Expand Up @@ -71,6 +72,14 @@ def __init__(self, project=None, credentials=None, http=None,
else:
self._use_gax = use_gax

def batch(self):
"""Batch multiple images into a single API request.
:rtype: :class:`google.cloud.vision.batch.Batch`
:returns: Instance of ``Batch``.
"""
return Batch(self)

def image(self, content=None, filename=None, source_uri=None):
"""Get instance of Image using current client.
Expand Down
31 changes: 14 additions & 17 deletions vision/google/cloud/vision/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,17 @@ def source(self):
"""
return self._source

def _detect_annotation(self, features):
def _detect_annotation(self, images):
"""Generic method for detecting annotations.
:type features: list
:param features: List of :class:`~google.cloud.vision.feature.Feature`
indicating the type of annotations to perform.
:type images: list
:param images: List of :class:`~google.cloud.vision.image.Image`.
:rtype: list
:returns: List of
:class:`~google.cloud.vision.entity.EntityAnnotation`,
:class:`~google.cloud.vision.face.Face`,
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
:class:`~google.cloud.vision.annotations.Annotations`.
"""
return self.client._vision_api.annotate(self, features)
return self.client._vision_api.annotate(images)

def detect(self, features):
"""Detect multiple feature types.
Expand All @@ -121,7 +117,8 @@ def detect(self, features):
:returns: List of
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
return self._detect_annotation(features)
images = ((self, features),)
return self._detect_annotation(images)

def detect_faces(self, limit=10):
"""Detect faces in image.
Expand All @@ -133,7 +130,7 @@ def detect_faces(self, limit=10):
:returns: List of :class:`~google.cloud.vision.face.Face`.
"""
features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
annotations = self._detect_annotation(features)
annotations = self.detect(features)
return annotations[0].faces

def detect_labels(self, limit=10):
Expand All @@ -146,7 +143,7 @@ def detect_labels(self, limit=10):
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
"""
features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
annotations = self._detect_annotation(features)
annotations = self.detect(features)
return annotations[0].labels

def detect_landmarks(self, limit=10):
Expand All @@ -160,7 +157,7 @@ def detect_landmarks(self, limit=10):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
annotations = self._detect_annotation(features)
annotations = self.detect(features)
return annotations[0].landmarks

def detect_logos(self, limit=10):
Expand All @@ -174,7 +171,7 @@ def detect_logos(self, limit=10):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
annotations = self._detect_annotation(features)
annotations = self.detect(features)
return annotations[0].logos

def detect_properties(self, limit=10):
Expand All @@ -188,7 +185,7 @@ def detect_properties(self, limit=10):
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
"""
features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
annotations = self._detect_annotation(features)
annotations = self.detect(features)
return annotations[0].properties

def detect_safe_search(self, limit=10):
Expand All @@ -202,7 +199,7 @@ def detect_safe_search(self, limit=10):
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`.
"""
features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)]
annotations = self._detect_annotation(features)
annotations = self.detect(features)
return annotations[0].safe_searches

def detect_text(self, limit=10):
Expand All @@ -216,5 +213,5 @@ def detect_text(self, limit=10):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
features = [Feature(FeatureTypes.TEXT_DETECTION, limit)]
annotations = self._detect_annotation(features)
annotations = self.detect(features)
return annotations[0].texts
Loading

0 comments on commit db0dc85

Please sign in to comment.