Transformer model sentiment analysis example #36
Merged
47 commits
eb77e7b mnist basic t2t model (1vn)
b2448b8 merge master (1vn)
9917cd7 add newline (1vn)
8edcb58 fix prediction time shaping (1vn)
a6e0f47 clean reviews example (1vn)
df24557 if undefined shape, take the length (1vn)
6202511 Merge branch 'master' into t2t-example (1vn)
b5c60a3 add numpy to api image (1vn)
7145df8 remove numpy dep, dont restrict unspecified python pkgs (1vn)
e4a01c9 add TODO comment to address later (1vn)
f8c128d clean up (1vn)
17144a2 clean up example and transform tensor api (1vn)
e73d95c transform_tensors -> transform_tensorflow (1vn)
99e0b2a add back dnn (1vn)
ad0be81 add back dnn (1vn)
e9e7c92 fix example (1vn)
9be820a remove TODO (1vn)
5b5263b add docs (1vn)
4b27fa3 checkin (1vn)
43b1a6e checkin (1vn)
f204061 checkin (1vn)
a9fed6b merge master (1vn)
dbec6bf transformer model (1vn)
51b8854 remove extraenous changes (1vn)
b6e1271 clean up (1vn)
73728c8 remove unused transformer (1vn)
91f6c07 remove unused transformed column (1vn)
e3b9315 clean up (1vn)
aa950fe Merge branch 'master' into t2t-blog (1vn)
96cd0e5 address comments, updates to reflect blog post (1vn)
346a224 clean up code, remove constants (1vn)
1bcd2a3 add back gpu (1vn)
729e613 fix code (1vn)
c41716f fix implementation (1vn)
765b46f Merge branch 'master' into t2t-blog (1vn)
2c26ecd address comments (1vn)
c6f8c6d dont use set (1vn)
d08b5f1 remove stopwords (1vn)
6140ca3 tune model (1vn)
2e2c5b1 rearrange yaml (1vn)
bc91201 fix YAML (1vn)
a81730e remove whitespace (1vn)
52a7360 Merge branch 'master' into t2t-blog (1vn)
9cc5fc8 remove extra new line (1vn)
5b7875d bump spark (1vn)
03eb2f6 bump spark (1vn)
8c87774 address comments (1vn)
examples/reviews/implementations/models/t2t_transformer.py (65 changes: 65 additions & 0 deletions)
@@ -0,0 +1,65 @@
import tensorflow as tf
from tensor2tensor.utils import trainer_lib
from tensor2tensor import models  # pylint: disable=unused-import
from tensor2tensor import problems  # pylint: disable=unused-import
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import metrics
from tensor2tensor.data_generators import imdb
from tensor2tensor.data_generators import text_encoder


def create_estimator(run_config, model_config):
    # t2t expects these keys in run_config
    run_config.data_parallelism = None
    run_config.t2t_device_info = {"num_async_replicas": 1}

    hparams = trainer_lib.create_hparams("transformer_base_single_gpu")

    problem = SentimentIMDBCortex(list(model_config["aggregates"]["reviews_vocab"]))
    p_hparams = problem.get_hparams(hparams)
    hparams.problem = problem
    hparams.problem_hparams = p_hparams

    problem.eval_metrics = lambda: [
        metrics.Metrics.ACC_TOP5,
        metrics.Metrics.ACC_PER_SEQ,
        metrics.Metrics.NEG_LOG_PERPLEXITY,
    ]

    # t2t expects this key
    hparams.warm_start_from = None

    # reduce memory load
    hparams.num_hidden_layers = 2
    hparams.hidden_size = 32
    hparams.filter_size = 32
    hparams.num_heads = 2

    estimator = trainer_lib.create_estimator("transformer", hparams, run_config)
    return estimator


def transform_tensorflow(features, labels, model_config):
    max_length = model_config["aggregates"]["max_review_length"]

    features["inputs"] = tf.expand_dims(tf.reshape(features["embedding_input"], [max_length]), -1)
    features["targets"] = tf.expand_dims(tf.expand_dims(labels, -1), -1)

    return features, labels


class SentimentIMDBCortex(imdb.SentimentIMDB):
    """IMDB sentiment classification, with an in-memory vocab"""

    def __init__(self, vocab_list):
        super().__init__()
        self.vocab = vocab_list

    def feature_encoders(self, data_dir):
        encoder = text_encoder.TokenTextEncoder(vocab_filename=None, vocab_list=self.vocab)

        return {
            "inputs": encoder,
            "targets": text_encoder.ClassLabelEncoder(self.class_labels(data_dir)),
        }
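The `SentimentIMDBCortex` class above builds its input encoder from a vocabulary list held in memory rather than from a vocab file. As a rough illustration of that idea only, here is a minimal pure-Python sketch. `InMemoryTokenEncoder`, its `oov_token` parameter, and the choice to reserve ids 0 and 1 for padding and end-of-sequence are all assumptions made for this example; it is not tensor2tensor's `TokenTextEncoder` and makes no claim about that class's internals.

```python
# Hypothetical sketch: builds a token <-> id mapping from an in-memory
# vocabulary list. This is NOT tensor2tensor's TokenTextEncoder.

# Assumption for illustration: ids 0 and 1 are reserved for padding and EOS.
RESERVED_TOKENS = ["<pad>", "<EOS>"]


class InMemoryTokenEncoder:
    """Whitespace tokenizer backed by a vocabulary list kept in memory."""

    def __init__(self, vocab_list, oov_token=None):
        # Reserved tokens come first; dict.fromkeys removes duplicates
        # while preserving the original order.
        tokens = list(dict.fromkeys(RESERVED_TOKENS + list(vocab_list)))
        if oov_token is not None and oov_token not in tokens:
            raise ValueError("oov_token must be part of the vocabulary")
        self._id_to_token = tokens
        self._token_to_id = {token: i for i, token in enumerate(tokens)}
        self._oov_token = oov_token

    @property
    def vocab_size(self):
        return len(self._id_to_token)

    def encode(self, sentence):
        """Map a whitespace-separated sentence to a list of integer ids."""
        ids = []
        for token in sentence.strip().split():
            if token not in self._token_to_id:
                if self._oov_token is None:
                    raise KeyError("token not in vocabulary: %r" % token)
                # Fall back to the out-of-vocabulary token.
                token = self._oov_token
            ids.append(self._token_to_id[token])
        return ids

    def decode(self, ids):
        """Map a list of integer ids back to a space-joined string."""
        return " ".join(self._id_to_token[i] for i in ids)


encoder = InMemoryTokenEncoder(
    ["great", "movie", "terrible", "<unk>"], oov_token="<unk>"
)
known_ids = encoder.encode("great movie")
unknown_ids = encoder.encode("boring movie")
```

In the PR itself the vocabulary comes from `model_config["aggregates"]["reviews_vocab"]`; the short list passed to the sketch stands in for that aggregate.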
@@ -0,0 +1 @@
tensor2tensor==1.10.0