Add text style transfer (#166) #263

Status: Open. swapnull7 wants to merge 21 commits into asyml:master from swapnull7:master.
Commits (21):

- 3d2db27 Add text style transfer (#1)
- e401929 Add text style transfer with improvements (#2)
- 7bb76b7 restore optimizers
- 28e930a Merge branch 'master' of github.com-personal:swapnull7/texar-pytorch
- 5999f69 Update ctrl_gen_model.py
- 9f0ac5d remove tensorflow import
- c62e3b7 Add text style transfer (#3)
- 6c5b81f Add text style transfer (#4)
- 9966731 Merge remote-tracking branch 'upstream/master'
- 9ce07e5 Add text style transfer (#5)
- adca7fa Merge branch 'master' of github.com-personal:swapnull7/texar-pytorch
- 6055075 Add text style transfer (#6)
- 6d7a0cd Fix docs build issue
- 460d7a5 Merge remote-tracking branch 'upstream/master'
- a6a0472 Fix typo
- 6f63f28 Make sure all variables are appended only once
- 4d2b2a0 Merge remote-tracking branch 'upstream/master'
- fa07314 Merge remote-tracking branch 'upstream/master'
- 8927039 Update main.py
- ca88d14 fix docstrings
- 3776f4b Merge remote-tracking branch 'upstream/master'
@@ -0,0 +1,106 @@
# Text Style Transfer #

This example implements a simplified variant of the `ctrl-gen` model from

[Toward Controlled Generation of Text](https://arxiv.org/pdf/1703.00955.pdf)
*Zhiting Hu, Zichao Yang, Xiaodan Liang, Ruslan Salakhutdinov, Eric Xing; ICML 2017*

The model roughly has an `Encoder--Decoder--Classifier` architecture. Compared to the paper, the following simplifications are made:

* The base Variational Autoencoder (VAE) is replaced with an attentional Autoencoder (AE) -- a VAE is not necessary in the text style transfer setting, since we do not need to interpolate the latent space as in the paper.
* The attribute classifier (i.e., discriminator) is trained on real data only; samples generated by the decoder are not used.
* The independence constraint is omitted.
## Usage ##

### Dataset ###

Download the Yelp sentiment dataset with the following command:
```
python prepare_data.py
```
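After the download, the data layout expected by the configuration can be checked with a small helper. This is an illustrative sketch (the helper names are not part of the example's code); the paths are the ones `config.py` points at:

```python
from pathlib import Path

SPLITS = ("train", "dev", "test")

def expected_files(root="./data/yelp"):
    """Files that config.py expects after prepare_data.py has run."""
    files = [f"{root}/vocab"]
    for split in SPLITS:
        files.append(f"{root}/sentiment.{split}.text")
        files.append(f"{root}/sentiment.{split}.labels")
    return files

def missing_files(root="./data/yelp"):
    """Return any expected files that are not present on disk."""
    return [f for f in expected_files(root) if not Path(f).exists()]
```

Running `missing_files()` before training gives a quick sanity check that the download succeeded.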
### Train the model ###

Train the model on the above data to perform sentiment transfer:
```
python main.py --config config
```

[config.py](./config.py) contains the data and model configurations.

* The model is first pre-trained for a few epochs (specified in `config.py`). During pre-training, the `Encoder-Decoder` part is trained as an autoencoder, while the `Classifier` part is trained with the classification labels.
* Full-training is then performed for another few epochs. During full-training, the `Classifier` part is fixed, and the `Encoder-Decoder` part is trained to fit the classifier, while continuing to minimize the autoencoding loss.

(**Note:** When using your own dataset, make sure to set `max_decoding_length_train` and `max_decoding_length_infer` in [config.py](https://github.com/asyml/texar/blob/master/examples/text_style_transfer/config.py#L85-L86).)
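The two-phase schedule can be sketched as follows. This is a minimal illustration of the logic, not the actual `main.py` API (the function names here are hypothetical); the weight `lambda_g` and the pre-train epoch count come from `config.py`:

```python
def training_phase(epoch, pretrain_nepochs=10):
    """Pre-train as an autoencoder first, then switch to full-training."""
    return "pretrain" if epoch <= pretrain_nepochs else "full"

def lambda_g_schedule(epoch, lambda_g=0.1, pretrain_nepochs=10):
    """The classification loss weight is 0 during pre-training."""
    return 0.0 if training_phase(epoch, pretrain_nepochs) == "pretrain" else lambda_g

def joint_loss(loss_g_ae, loss_g_class, lambda_g):
    """Joint generator objective: loss_g = loss_g_ae + lambda_g * loss_g_class."""
    return loss_g_ae + lambda_g * loss_g_class
```

This matches the training log below, where the first epochs print `lambda_g: 0.0`.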
The training log is printed as below:
```
gamma: 1.0, lambda_g: 0.0
step: 1, loss_d: 0.6934 accu_d: 0.4844
step: 1, loss_g_ae: 9.1392
step: 500, loss_d: 0.1488 accu_d: 0.9484
step: 500, loss_g_ae: 4.2884
step: 1000, loss_d: 0.1215 accu_d: 0.9625
step: 1000, loss_g_ae: 2.6201
...
epoch: 1, loss_d: 0.0750 accu_d: 0.9688
epoch: 1, loss_g_ae: 0.8832
val: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2949 loss_d: 0.0702 accu_d: 0.9744 accu_g: 0.3022 accu_g_gdy: 0.2732 bleu: 60.8234
test: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2359 loss_d: 0.0746 accu_d: 0.9733 accu_g: 0.3076 accu_g_gdy: 0.2791 bleu: 60.1810
...
```
where:
- `loss_d` and `accu_d` are the classification loss/accuracy of the `Classifier` part.
- `loss_g_class` is the classification loss of the generated sentences.
- `loss_g_ae` is the autoencoding loss.
- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_class`.
- `accu_g` is the classification accuracy of the generated sentences with soft representations (i.e., Gumbel-softmax).
- `accu_g_gdy` is the classification accuracy of the generated sentences with greedy decoding.
- `bleu` is the BLEU score between the generated and input sentences.
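The soft representations behind `accu_g` come from Gumbel-softmax sampling, which gives a differentiable relaxation of discrete token choices. A minimal stdlib-only sketch of the relaxation (the function name is illustrative, not the Texar API):

```python
import math
import random

def gumbel_softmax(logits, tau, rnd):
    """Soft sample over a vocabulary; lower tau pushes the output toward one-hot."""
    # Gumbel(0, 1) noise via the inverse-CDF trick
    gumbels = [-math.log(-math.log(rnd.random())) for _ in logits]
    ys = [(l + g) / tau for l, g in zip(logits, gumbels)]
    m = max(ys)                      # subtract max for numerical stability
    exps = [math.exp(y - m) for y in ys]
    z = sum(exps)
    return [e / z for e in exps]

rnd = random.Random(0)
soft = gumbel_softmax([2.0, 1.0, 0.1], tau=1.0, rnd=rnd)   # a valid distribution
hard = gumbel_softmax([2.0, 1.0, 0.1], tau=0.01, rnd=random.Random(0))  # near one-hot
```

The classifier consumes the soft distribution directly, so gradients flow back into the decoder during full-training.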
## Results ##

Text style transfer has two primary goals:

1. The generated sentence should have the desired attribute (e.g., positive/negative sentiment).
2. The generated sentence should keep the content of the original one.

We use automatic metrics to evaluate both:

* For (1), we can use a pre-trained classifier to classify the generated sentences and evaluate the accuracy (the higher the better). This code does not include a stand-alone classifier for evaluation, though one would be easy to add; the `Classifier` part of the model gives a reasonably good estimate of the accuracy (i.e., `accu_g_gdy` above).
* For (2), we evaluate the BLEU score between the generated sentences and the original sentences, i.e., `bleu` above (the higher the better). (See [Yang et al., 2018](https://arxiv.org/pdf/1805.11749.pdf) for more details.)
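For intuition on metric (2), here is a minimal self-contained sentence-level BLEU sketch with crude smoothing. It is an assumption-laden stand-in, not the exact BLEU implementation the example uses:

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """All contiguous n-grams of a token list."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def sentence_bleu(reference, hypothesis, max_n=4):
    """Geometric mean of modified n-gram precisions times a brevity penalty."""
    ref, hyp = reference.split(), hypothesis.split()
    log_precisions = []
    for n in range(1, max_n + 1):
        hyp_counts = Counter(ngrams(hyp, n))
        ref_counts = Counter(ngrams(ref, n))
        overlap = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        total = max(sum(hyp_counts.values()), 1)
        # crude smoothing so a zero match does not zero the whole score
        log_precisions.append(math.log(max(overlap, 1e-9) / total))
    bp = 1.0 if len(hyp) > len(ref) else math.exp(1 - len(ref) / max(len(hyp), 1))
    return bp * math.exp(sum(log_precisions) / max_n)
```

An identical pair scores 1.0; a transferred sentence that keeps most content words (like the samples below) scores strictly between 0 and 1.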
The implementation here gives the following performance after 10 epochs of pre-training and 2 epochs of full-training:

| Accuracy (by the `Classifier` part) | BLEU (with the original sentence) |
| ----------------------------------- | --------------------------------- |
| 0.96                                | 52.0                              |
Also refer to the following papers, which used this code and compared to other text style transfer approaches:

* [Unsupervised Text Style Transfer using Language Models as Discriminators](https://papers.nips.cc/paper/7959-unsupervised-text-style-transfer-using-language-models-as-discriminators.pdf). Zichao Yang, Zhiting Hu, Chris Dyer, Eric Xing, Taylor Berg-Kirkpatrick. NeurIPS 2018
* [Structured Content Preservation for Unsupervised Text Style Transfer](https://arxiv.org/pdf/1810.06526.pdf). Youzhi Tian, Zhiting Hu, Zhou Yu. 2018
### Samples ###

Here are some randomly-picked samples. In each pair, the first sentence is the original and the second is the generated:
```
love , love love .
poor , poor poor .

good atmosphere .
disgusted atmosphere .

the donuts are good sized and very well priced .
the donuts are disgusted sized and very _num_ priced .

it is always clean and the staff is super friendly .
it is nasty overpriced and the staff is super cold .

super sweet place .
super plain place .

highly recommended .
horrible horrible .

very good ingredients .
very disgusted ingredients .
```
@@ -0,0 +1,108 @@
"""Config | ||
""" | ||
# pylint: disable=invalid-name | ||
|
||
import copy | ||
from typing import Dict, Any | ||
|
||
# Total number of training epochs (including pre-train and full-train) | ||
max_nepochs = 12 | ||
pretrain_nepochs = 10 # Number of pre-train epochs (training as autoencoder) | ||
display = 500 # Display the training results every N training steps. | ||
# Display the dev results every N training steps (set to a | ||
# very large value to disable it). | ||
display_eval = 1e10 | ||
|
||
sample_path = './samples' | ||
checkpoint_path = './checkpoints' | ||
restore = '' # Model snapshot to restore from | ||
|
||
lambda_g = 0.1 # Weight of the classification loss | ||
gamma_decay = 0.5 # Gumbel-softmax temperature anneal rate | ||
|
||
max_seq_length = 16 # Maximum sequence length in dataset w/o BOS token | ||
|
||
train_data: Dict[str, Any] = { | ||
'batch_size': 64, | ||
# 'seed': 123, | ||
'datasets': [ | ||
{ | ||
'files': './data/yelp/sentiment.train.text', | ||
'vocab_file': './data/yelp/vocab', | ||
'data_name': '' | ||
}, | ||
{ | ||
'files': './data/yelp/sentiment.train.labels', | ||
'data_type': 'int', | ||
'data_name': 'labels' | ||
} | ||
], | ||
'name': 'train' | ||
} | ||
|
||
val_data = copy.deepcopy(train_data) | ||
val_data['datasets'][0]['files'] = './data/yelp/sentiment.dev.text' | ||
val_data['datasets'][1]['files'] = './data/yelp/sentiment.dev.labels' | ||
|
||
test_data = copy.deepcopy(train_data) | ||
test_data['datasets'][0]['files'] = './data/yelp/sentiment.test.text' | ||
test_data['datasets'][1]['files'] = './data/yelp/sentiment.test.labels' | ||
|
||
model = { | ||
'dim_c': 200, | ||
'dim_z': 500, | ||
'embedder': { | ||
'dim': 100, | ||
}, | ||
'max_seq_length': max_seq_length, | ||
'encoder': { | ||
'rnn_cell': { | ||
'type': 'GRUCell', | ||
'kwargs': { | ||
'num_units': 700 | ||
}, | ||
'dropout': { | ||
'input_keep_prob': 0.5 | ||
} | ||
} | ||
}, | ||
'decoder': { | ||
'rnn_cell': { | ||
'type': 'GRUCell', | ||
'kwargs': { | ||
'num_units': 700, | ||
}, | ||
'dropout': { | ||
'input_keep_prob': 0.5, | ||
'output_keep_prob': 0.5 | ||
}, | ||
}, | ||
'attention': { | ||
'type': 'BahdanauAttention', | ||
'kwargs': { | ||
'num_units': 700, | ||
}, | ||
'attention_layer_size': 700, | ||
}, | ||
'max_decoding_length_train': 21, | ||
'max_decoding_length_infer': 20, | ||
}, | ||
'classifier': { | ||
'kernel_size': [3, 4, 5], | ||
'out_channels': 128, | ||
'data_format': 'channels_last', | ||
'other_conv_kwargs': [[{'padding': 1}, {'padding': 2}, {'padding': 2}]], | ||
'dropout_conv': [1], | ||
'dropout_rate': 0.5, | ||
'num_dense_layers': 0, | ||
'num_classes': 1 | ||
}, | ||
'opt': { | ||
'optimizer': { | ||
'type': 'Adam', | ||
'kwargs': { | ||
'lr': 3e-4, | ||
}, | ||
}, | ||
}, | ||
} |
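The `gamma_decay` value suggests an exponential anneal of the Gumbel-softmax temperature once full-training starts. A minimal sketch of such a schedule, using the constants from the config above (the function name and exact switch point are assumptions, not read from `main.py`):

```python
max_nepochs = 12
pretrain_nepochs = 10
gamma_decay = 0.5  # Gumbel-softmax temperature anneal rate (from config.py)

def gamma_schedule(epoch, gamma_init=1.0):
    """Exponentially anneal the temperature after pre-training ends."""
    steps = max(0, epoch - pretrain_nepochs)
    return gamma_init * gamma_decay ** steps

# Temperature over the whole run: 1.0 through pre-training, then halved each epoch
temps = [gamma_schedule(e) for e in range(1, max_nepochs + 1)]
```

This is consistent with the training log above, which starts at `gamma: 1.0`; a decaying temperature sharpens the soft samples fed to the classifier as training progresses.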