Skip to content

Commit

Permalink
Provide full Mention objects, not just the content. (googleapis#3156)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukesneeringer authored Mar 16, 2017
1 parent 5921766 commit e135dba
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 16 deletions.
77 changes: 75 additions & 2 deletions language/google/cloud/language/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,80 @@ class EntityType(object):
"""Other entity type (i.e. known but not classified)."""


class MentionType(object):
"""List of possible mention types."""

TYPE_UNKNOWN = 'TYPE_UNKNOWN'
"""Unknown mention type"""

PROPER = 'PROPER'
"""Proper name"""

COMMON = 'COMMON'
"""Common noun (or noun compound)"""


class Mention(object):
"""A Google Cloud Natural Language API mention.
Represents a mention for an entity in the text. Currently, proper noun
mentions are supported.
"""
def __init__(self, text, mention_type):
self.text = text
self.mention_type = mention_type

def __str__(self):
return str(self.text)

@classmethod
def from_api_repr(cls, payload):
"""Convert a Mention from the JSON API into an :class:`Mention`.
:param payload: dict
:type payload: The value from the backend.
:rtype: :class:`Mention`
:returns: The mention parsed from the API representation.
"""
text = TextSpan.from_api_repr(payload['text'])
mention_type = payload['type']
return cls(text, mention_type)


class TextSpan(object):
"""A span of text from Google Cloud Natural Language API.
Represents a word or phrase of text, as well as its offset
from the original document.
"""
def __init__(self, content, begin_offset):
self.content = content
self.begin_offset = begin_offset

def __str__(self):
"""Return the string representation of this TextSpan.
:rtype: str
:returns: The text content
"""
return self.content

@classmethod
def from_api_repr(cls, payload):
"""Convert a TextSpan from the JSON API into an :class:`TextSpan`.
:param payload: dict
:type payload: The value from the backend.
:rtype: :class:`TextSpan`
:returns: The text span parsed from the API representation.
"""
content = payload['content']
begin_offset = payload['beginOffset']
return cls(content=content, begin_offset=begin_offset)


class Entity(object):
"""A Google Cloud Natural Language API entity.
Expand Down Expand Up @@ -101,6 +175,5 @@ def from_api_repr(cls, payload):
entity_type = payload['type']
metadata = payload['metadata']
salience = payload['salience']
mentions = [value['text']['content']
for value in payload['mentions']]
mentions = [Mention.from_api_repr(val) for val in payload['mentions']]
return cls(name, entity_type, metadata, salience, mentions)
9 changes: 7 additions & 2 deletions language/unit_tests/test_api_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

class TestEntityResponse(unittest.TestCase):
ENTITY_DICT = {
'mentions': [{'text': {'content': 'Italian'}}],
'mentions': [{
'text': {'content': 'Italian', 'beginOffset': 0},
'type': 'PROPER',
}],
'metadata': {'wikipedia_url': 'http://en.wikipedia.org/wiki/Italy'},
'name': 'Italian',
'salience': 0.15,
Expand Down Expand Up @@ -46,12 +49,14 @@ def test_api_repr_factory(self):

def _verify_entity_response(self, entity_response):
from google.cloud.language.entity import EntityType
from google.cloud.language.entity import Mention

self.assertEqual(len(entity_response.entities), 1)
entity = entity_response.entities[0]
self.assertEqual(entity.name, 'Italian')
self.assertEqual(len(entity.mentions), 1)
self.assertEqual(entity.mentions[0], 'Italian')
self.assertIsInstance(entity.mentions[0], Mention)
self.assertEqual(str(entity.mentions[0]), 'Italian')
self.assertTrue(entity.metadata['wikipedia_url'].endswith('Italy'))
self.assertAlmostEqual(entity.salience, 0.15)
self.assertEqual(entity.entity_type, EntityType.LOCATION)
Expand Down
10 changes: 7 additions & 3 deletions language/unit_tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def _get_entities(include_entities):
'text': {
'content': ANNOTATE_NAME,
'beginOffset': -1
}
},
'type': 'TYPE_UNKNOWN',
}
]
},
Expand Down Expand Up @@ -215,7 +216,8 @@ def _verify_entity(self, entity, name, entity_type, wiki_url, salience):
else:
self.assertEqual(entity.metadata, {})
self.assertEqual(entity.salience, salience)
self.assertEqual(entity.mentions, [name])
self.assertEqual(len(entity.mentions), 1)
self.assertEqual(entity.mentions[0].text.content, name)

@staticmethod
def _expected_data(content, encoding_type=None,
Expand Down Expand Up @@ -265,7 +267,8 @@ def test_analyze_entities(self):
'text': {
'content': name1,
'beginOffset': -1
}
},
'type': 'TYPE_UNKNOWN',
}
]
},
Expand All @@ -280,6 +283,7 @@ def test_analyze_entities(self):
'content': name2,
'beginOffset': -1,
},
'type': 'PROPER',
},
],
},
Expand Down
104 changes: 98 additions & 6 deletions language/unit_tests/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def _make_one(self, *args, **kw):
return self._get_target_class()(*args, **kw)

def test_constructor_defaults(self):
from google.cloud.language.entity import Mention
from google.cloud.language.entity import MentionType
from google.cloud.language.entity import TextSpan

name = 'Italian'
entity_type = 'LOCATION'
wiki_url = 'http://en.wikipedia.org/wiki/Italy'
Expand All @@ -35,7 +39,10 @@ def test_constructor_defaults(self):
'wikipedia_url': wiki_url,
}
salience = 0.19960518
mentions = ['Italian']
mentions = [Mention(
mention_type=MentionType.PROPER,
text=TextSpan(content='Italian', begin_offset=0),
)]
entity = self._make_one(name, entity_type, metadata,
salience, mentions)
self.assertEqual(entity.name, name)
Expand All @@ -45,9 +52,13 @@ def test_constructor_defaults(self):
self.assertEqual(entity.mentions, mentions)

def test_from_api_repr(self):
from google.cloud.language.entity import EntityType
from google.cloud.language.entity import Mention
from google.cloud.language.entity import MentionType

klass = self._get_target_class()
name = 'Italy'
entity_type = 'LOCATION'
entity_type = EntityType.LOCATION
salience = 0.223
wiki_url = 'http://en.wikipedia.org/wiki/Italy'
mention1 = 'Italy'
Expand All @@ -59,14 +70,95 @@ def test_from_api_repr(self):
'salience': salience,
'metadata': {'wikipedia_url': wiki_url},
'mentions': [
{'text': {'content': mention1}},
{'text': {'content': mention2}},
{'text': {'content': mention3}},
{'text': {'content': mention1, 'beginOffset': 3},
'type': 'PROPER'},
{'text': {'content': mention2, 'beginOffset': 5},
'type': 'PROPER'},
{'text': {'content': mention3, 'beginOffset': 8},
'type': 'PROPER'},
],
}
entity = klass.from_api_repr(payload)
self.assertEqual(entity.name, name)
self.assertEqual(entity.entity_type, entity_type)
self.assertEqual(entity.salience, salience)
self.assertEqual(entity.metadata, {'wikipedia_url': wiki_url})
self.assertEqual(entity.mentions, [mention1, mention2, mention3])

# Assert that we got back Mention objects for each mention.
self.assertIsInstance(entity.mentions[0], Mention)
self.assertIsInstance(entity.mentions[1], Mention)
self.assertIsInstance(entity.mentions[2], Mention)

# Assert that the text (and string coercison) are correct.
self.assertEqual([str(i) for i in entity.mentions],
[mention1, mention2, mention3])

# Assert that the begin offsets are preserved.
self.assertEqual([i.text.begin_offset for i in entity.mentions],
[3, 5, 8])

# Assert that the mention types are preserved.
for mention in entity.mentions:
self.assertEqual(mention.mention_type, MentionType.PROPER)


class TestMention(unittest.TestCase):
PAYLOAD = {
'text': {'content': 'Greece', 'beginOffset': 42},
'type': 'PROPER',
}

def test_constructor(self):
from google.cloud.language.entity import Mention
from google.cloud.language.entity import MentionType
from google.cloud.language.entity import TextSpan

mention = Mention(
text=TextSpan(content='snails', begin_offset=90),
mention_type=MentionType.COMMON,
)

self.assertIsInstance(mention.text, TextSpan)
self.assertEqual(mention.text.content, 'snails')
self.assertEqual(mention.text.begin_offset, 90)
self.assertEqual(mention.mention_type, MentionType.COMMON)

def test_from_api_repr(self):
from google.cloud.language.entity import Mention
from google.cloud.language.entity import MentionType
from google.cloud.language.entity import TextSpan

mention = Mention.from_api_repr(self.PAYLOAD)

self.assertIsInstance(mention, Mention)
self.assertIsInstance(mention.text, TextSpan)
self.assertEqual(mention.text.content, 'Greece')
self.assertEqual(mention.text.begin_offset, 42)
self.assertEqual(mention.mention_type, MentionType.PROPER)

def test_dunder_str(self):
from google.cloud.language.entity import Mention

mention = Mention.from_api_repr(self.PAYLOAD)
self.assertEqual(str(mention), 'Greece')


class TestTextSpan(unittest.TestCase):
def test_constructor(self):
from google.cloud.language.entity import TextSpan

text = TextSpan(content='Winston Churchill', begin_offset=1945)
self.assertIsInstance(text, TextSpan)
self.assertEqual(text.content, str(text), 'Winston Churchill')
self.assertEqual(text.begin_offset, 1945)

def test_from_api_repr(self):
from google.cloud.language.entity import TextSpan

text = TextSpan.from_api_repr({
'beginOffset': 1953,
'content': 'Queen Elizabeth',
})
self.assertIsInstance(text, TextSpan)
self.assertEqual(text.content, str(text), 'Queen Elizabeth')
self.assertEqual(text.begin_offset, 1953)
6 changes: 3 additions & 3 deletions system_tests/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def _check_analyze_entities_result(self, entities):
self.assertEqual(entity1.entity_type, EntityType.PERSON)
self.assertGreater(entity1.salience, 0.0)
# Other mentions may occur, e.g. "painter".
self.assertIn(entity1.name, entity1.mentions)
self.assertIn(entity1.name, [str(i) for i in entity1.mentions])
self.assertEqual(entity1.metadata['wikipedia_url'],
'http://en.wikipedia.org/wiki/Caravaggio')
self.assertIsInstance(entity1.metadata, dict)
# Verify entity 2.
self.assertEqual(entity2.name, self.NAME2)
self.assertEqual(entity2.entity_type, EntityType.LOCATION)
self.assertGreater(entity2.salience, 0.0)
self.assertEqual(entity2.mentions, [entity2.name])
self.assertEqual([str(i) for i in entity2.mentions], [entity2.name])
self.assertEqual(entity2.metadata['wikipedia_url'],
'http://en.wikipedia.org/wiki/Italy')
self.assertIsInstance(entity2.metadata, dict)
Expand All @@ -92,7 +92,7 @@ def _check_analyze_entities_result(self, entities):
choices = (EntityType.EVENT, EntityType.WORK_OF_ART)
self.assertIn(entity3.entity_type, choices)
self.assertGreater(entity3.salience, 0.0)
self.assertEqual(entity3.mentions, [entity3.name])
self.assertEqual([str(i) for i in entity3.mentions], [entity3.name])
wiki_url = ('http://en.wikipedia.org/wiki/'
'The_Calling_of_St_Matthew_(Caravaggio)')
self.assertEqual(entity3.metadata['wikipedia_url'], wiki_url)
Expand Down

0 comments on commit e135dba

Please sign in to comment.