Reconstructed unified diff (the stored copy had its line breaks collapsed).
NOTE(review): indentation and blank context lines are reconstructed, and the
`index` blob ids come from the original patch and were not recomputed --
confirm against the target tree before applying.

diff --git a/samples/snippets/snippets.py b/samples/snippets/snippets.py
index 2b754ace..30b591a4 100644
--- a/samples/snippets/snippets.py
+++ b/samples/snippets/snippets.py
@@ -253,12 +253,67 @@ def entity_sentiment_file(gcs_uri):
         print(u'Sentiment: {}\n'.format(entity.sentiment))
 
 
+# [START def_classify_text]
+def classify_text(text):
+    """Classifies content categories of the provided text."""
+    client = language.LanguageServiceClient()
+
+    # Accept byte strings as well as text: decode them as UTF-8 first.
+    if isinstance(text, six.binary_type):
+        text = text.decode('utf-8')
+
+    # NOTE(review): content is re-encoded to UTF-8 bytes here -- confirm
+    # that this is what types.Document expects on every supported Python.
+    document = types.Document(
+        content=text.encode('utf-8'),
+        type=enums.Document.Type.PLAIN_TEXT)
+
+    categories = client.classify_text(document).categories
+
+    # Print a separator, then the name and confidence of each category.
+    for category in categories:
+        print(u'=' * 20)
+        print(u'{:<16}: {}'.format('name', category.name))
+        print(u'{:<16}: {}'.format('confidence', category.confidence))
+# [END def_classify_text]
+
+
+# [START def_classify_file]
+def classify_file(gcs_uri):
+    """Classifies content categories of the text in a Google Cloud Storage
+    file.
+    """
+    client = language.LanguageServiceClient()
+
+    # The document is referenced by its Cloud Storage URI, not sent inline.
+    document = types.Document(
+        gcs_content_uri=gcs_uri,
+        type=enums.Document.Type.PLAIN_TEXT)
+
+    categories = client.classify_text(document).categories
+
+    # Print a separator, then the name and confidence of each category.
+    for category in categories:
+        print(u'=' * 20)
+        print(u'{:<16}: {}'.format('name', category.name))
+        print(u'{:<16}: {}'.format('confidence', category.confidence))
+# [END def_classify_file]
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description=__doc__,
         formatter_class=argparse.RawDescriptionHelpFormatter)
     subparsers = parser.add_subparsers(dest='command')
 
+    classify_text_parser = subparsers.add_parser(
+        'classify-text', help=classify_text.__doc__)
+    classify_text_parser.add_argument('text')
+
+    classify_file_parser = subparsers.add_parser(
+        'classify-file', help=classify_file.__doc__)
+    classify_file_parser.add_argument('gcs_uri')
+
     sentiment_entities_text_parser = subparsers.add_parser(
         'sentiment-entities-text', help=entity_sentiment_text.__doc__)
     sentiment_entities_text_parser.add_argument('text')
@@ -309,3 +364,7 @@ def entity_sentiment_file(gcs_uri):
         entity_sentiment_text(args.text)
     elif args.command == 'sentiment-entities-file':
         entity_sentiment_file(args.gcs_uri)
+    elif args.command == 'classify-text':
+        classify_text(args.text)
+    elif args.command == 'classify-file':
+        classify_file(args.gcs_uri)
diff --git a/samples/snippets/snippets_test.py b/samples/snippets/snippets_test.py
index 168701dc..27fbee24 100644
--- a/samples/snippets/snippets_test.py
+++ b/samples/snippets/snippets_test.py
@@ -19,6 +19,7 @@
 
 BUCKET = os.environ['CLOUD_STORAGE_BUCKET']
 TEST_FILE_URL = 'gs://{}/text.txt'.format(BUCKET)
+LONG_TEST_FILE_URL = 'gs://{}/android_text.txt'.format(BUCKET)
 
 
 def test_sentiment_text(capsys):
@@ -77,3 +78,22 @@ def test_sentiment_entities_utf(capsys):
         'foo→bar')
     out, _ = capsys.readouterr()
     assert 'Begin Offset : 4' in out
+
+
+def test_classify_text(capsys):
+    """classify_text prints the expected category for inline text."""
+    snippets.classify_text(
+        'Android is a mobile operating system developed by Google, '
+        'based on the Linux kernel and designed primarily for touchscreen '
+        'mobile devices such as smartphones and tablets.')
+    out, _ = capsys.readouterr()
+    assert 'name' in out
+    assert '/Computers & Electronics' in out
+
+
+def test_classify_file(capsys):
+    """classify_file prints the expected category for the GCS sample file."""
+    snippets.classify_file(LONG_TEST_FILE_URL)
+    out, _ = capsys.readouterr()
+    assert 'name' in out
+    assert '/Computers & Electronics' in out