Skip to content

Commit

Permalink
Re-factoring usage of mocks in language.
Browse files Browse the repository at this point in the history
This way repeated code is reduced to a single call site.
  • Loading branch information
dhermes committed Nov 14, 2016
1 parent 8d45b39 commit c75e05c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 53 deletions.
46 changes: 16 additions & 30 deletions language/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
import unittest


def make_mock_credentials():
import mock
from oauth2client.client import GoogleCredentials

credentials = mock.Mock(spec=GoogleCredentials)
credentials.create_scoped_required.return_value = False
return credentials


class TestClient(unittest.TestCase):

@staticmethod
Expand All @@ -26,25 +35,19 @@ def _make_one(self, *args, **kw):
return self._get_target_class()(*args, **kw)

def test_ctor(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.connection import Connection

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
creds = make_mock_credentials()
http = object()
client = self._make_one(credentials=creds, http=http)
self.assertIsInstance(client._connection, Connection)
self.assertIs(client._connection.credentials, creds)
self.assertIs(client._connection.http, http)

def test_document_from_text_factory(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

content = 'abc'
Expand All @@ -59,23 +62,16 @@ def test_document_from_text_factory(self):
self.assertEqual(document.language, language)

def test_document_from_text_factory_failure(self):
import mock
from oauth2client.client import GoogleCredentials

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

with self.assertRaises(TypeError):
client.document_from_text('abc', doc_type='foo')

def test_document_from_html_factory(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

content = '<html>abc</html>'
Expand All @@ -90,23 +86,16 @@ def test_document_from_html_factory(self):
self.assertEqual(document.language, language)

def test_document_from_html_factory_failure(self):
import mock
from oauth2client.client import GoogleCredentials

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

with self.assertRaises(TypeError):
client.document_from_html('abc', doc_type='foo')

def test_document_from_url_factory(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

gcs_url = 'gs://my-text-bucket/sentiment-me.txt'
Expand All @@ -118,13 +107,10 @@ def test_document_from_url_factory(self):
self.assertEqual(document.doc_type, Document.PLAIN_TEXT)

def test_document_from_url_factory_explicit(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document
from google.cloud.language.document import Encoding

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

encoding = Encoding.UTF32
Expand Down
39 changes: 16 additions & 23 deletions language/unit_tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ def _get_entities(include_entities):
return entities


def make_mock_client(response):
import mock
from google.cloud.language.connection import Connection
from google.cloud.language.client import Client

connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
return mock.Mock(_connection=connection, spec=Client)


class TestDocument(unittest.TestCase):

@staticmethod
Expand Down Expand Up @@ -216,9 +226,6 @@ def _expected_data(content, encoding_type=None,
return expected

def test_analyze_entities(self):
import mock
from google.cloud.language.connection import Connection
from google.cloud.language.client import Client
from google.cloud.language.document import Encoding
from google.cloud.language.entity import EntityType

Expand Down Expand Up @@ -261,9 +268,7 @@ def test_analyze_entities(self):
],
'language': 'en-US',
}
connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
client = mock.Mock(_connection=connection, spec=Client)
client = make_mock_client(response)
document = self._make_one(client, content)

entities = document.analyze_entities()
Expand All @@ -278,7 +283,7 @@ def test_analyze_entities(self):
# Verify the request.
expected = self._expected_data(
content, encoding_type=Encoding.UTF8)
connection.api_request.assert_called_once_with(
client._connection.api_request.assert_called_once_with(
path='analyzeEntities', method='POST', data=expected)

def _verify_sentiment(self, sentiment, polarity, magnitude):
Expand All @@ -289,10 +294,6 @@ def _verify_sentiment(self, sentiment, polarity, magnitude):
self.assertEqual(sentiment.magnitude, magnitude)

def test_analyze_sentiment(self):
import mock
from google.cloud.language.connection import Connection
from google.cloud.language.client import Client

content = 'All the pretty horses.'
polarity = 1
magnitude = 0.6
Expand All @@ -303,17 +304,15 @@ def test_analyze_sentiment(self):
},
'language': 'en-US',
}
connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
client = mock.Mock(_connection=connection, spec=Client)
client = make_mock_client(response)
document = self._make_one(client, content)

sentiment = document.analyze_sentiment()
self._verify_sentiment(sentiment, polarity, magnitude)

# Verify the request.
expected = self._expected_data(content)
connection.api_request.assert_called_once_with(
client._connection.api_request.assert_called_once_with(
path='analyzeSentiment', method='POST', data=expected)

def _verify_sentences(self, include_syntax, annotations):
Expand Down Expand Up @@ -343,10 +342,6 @@ def _verify_tokens(self, annotations, token_info):

def _annotate_text_helper(self, include_sentiment,
include_entities, include_syntax):
import mock

from google.cloud.language.connection import Connection
from google.cloud.language.client import Client
from google.cloud.language.document import Annotations
from google.cloud.language.document import Encoding
from google.cloud.language.entity import EntityType
Expand All @@ -366,9 +361,7 @@ def _annotate_text_helper(self, include_sentiment,
'magnitude': ANNOTATE_MAGNITUDE,
}

connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
client = mock.Mock(_connection=connection, spec=Client)
client = make_mock_client(response)
document = self._make_one(client, ANNOTATE_CONTENT)

annotations = document.annotate_text(
Expand Down Expand Up @@ -400,7 +393,7 @@ def _annotate_text_helper(self, include_sentiment,
extract_sentiment=include_sentiment,
extract_entities=include_entities,
extract_syntax=include_syntax)
connection.api_request.assert_called_once_with(
client._connection.api_request.assert_called_once_with(
path='annotateText', method='POST', data=expected)

def test_annotate_text(self):
Expand Down

0 comments on commit c75e05c

Please sign in to comment.