Add text style transfer (#166) #263

Open · wants to merge 21 commits into base: master
15 changes: 15 additions & 0 deletions docs/code/utils.rst
@@ -12,6 +12,10 @@ Frequent Use
.. autoclass:: texar.torch.utils.AverageRecorder
:members:

:hidden:`collect_trainable_variables`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.collect_trainable_variables

:hidden:`compat_as_text`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.compat_as_text
@@ -20,6 +24,17 @@ Frequent Use
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.write_paired_text

Variables
=========

:hidden:`collect_trainable_variables`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.collect_trainable_variables

:hidden:`add_variable`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.add_variable


IO
===
9 changes: 9 additions & 0 deletions docs/examples.md
@@ -22,6 +22,10 @@ More examples are continuously added...
* [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation
* [xlnet](https://github.com/asyml/texar-pytorch/tree/master/examples/xlnet): Pre-trained XLNet model for text representation

### GANs / Discriminator-supervision ###

* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation

---

## Examples by Tasks
@@ -35,6 +39,11 @@ More examples are continuously added...
* [seq2seq_attn](https://github.com/asyml/texar-pytorch/tree/master/examples/seq2seq_attn): Attentional seq2seq
* [transformer](https://github.com/asyml/texar-pytorch/tree/master/examples/transformer): Transformer for machine translation

### Text Style Transfer ###

* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation


### Classification ###

* [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation
8 changes: 8 additions & 0 deletions examples/README.md
@@ -22,6 +22,10 @@ More examples are continuously added...

* [vae_text](./vae_text): VAE language model

### GANs / Discriminator-supervision ###

* [text_style_transfer](./text_style_transfer): Discriminator supervision for controlled text generation

### Classifier / Sequence Prediction ###

* [bert](./bert): Pre-trained BERT model for text representation
@@ -43,6 +47,10 @@ More examples are continuously added...
* [seq2seq_attn](./seq2seq_attn): Attentional seq2seq
* [transformer](./transformer): Transformer for machine translation

### Text Style Transfer ###

* [text_style_transfer](./text_style_transfer): Discriminator supervision for controlled text generation

### Classification ###

* [bert](./bert): Pre-trained BERT model for text representation
106 changes: 106 additions & 0 deletions examples/text_style_transfer/README.md
@@ -0,0 +1,106 @@
# Text Style Transfer #

This example implements a simplified variant of the `ctrl-gen` model from

[Toward Controlled Generation of Text](https://arxiv.org/pdf/1703.00955.pdf)
*Zhiting Hu, Zichao Yang, Xiaodan Liang, Ruslan Salakhutdinov, Eric Xing; ICML 2017*

The model roughly has an `Encoder--Decoder--Classifier` architecture. Compared to the paper, the following simplifications are made (a schematic sketch follows the list):

* The base Variational Autoencoder (VAE) is replaced with an attentional autoencoder (AE) -- a VAE is not necessary in the text style transfer setting, since we do not need to interpolate the latent space as in the paper.
* The attribute classifier (i.e., the discriminator) is trained on real data only; samples generated by the decoder are not used.
* The independence constraint is omitted.
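
For orientation, the following plain-PyTorch sketch wires these parts together with toy dimensions borrowed from [config.py](./config.py). The GRU modules and the linear stand-in for the CNN classifier are illustrative only, not the actual implementation:

```
import torch
import torch.nn as nn

vocab, dim_emb, dim_z, dim_c = 10000, 100, 500, 200

embed = nn.Embedding(vocab, dim_emb)
encoder = nn.GRU(dim_emb, dim_z, batch_first=True)
decoder = nn.GRU(dim_emb, dim_z + dim_c, batch_first=True)
proj = nn.Linear(dim_z + dim_c, vocab)
clf = nn.Linear(dim_emb, 1)           # stand-in for the CNN attribute classifier

x = torch.randint(vocab, (64, 16))    # (batch, time) token ids
_, z = encoder(embed(x))              # content representation z: (1, batch, dim_z)
c = torch.randn(1, 64, dim_c)         # attribute code c (flipped at transfer time)
out, _ = decoder(embed(x), torch.cat([z, c], dim=-1))
logits = proj(out)                    # (batch, time, vocab) reconstruction logits

# "Soft" word embeddings let the classifier's gradient reach the decoder:
soft_emb = logits.softmax(-1) @ embed.weight
attr_logit = clf(soft_emb.mean(dim=1))    # one attribute logit per sentence
```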

## Usage ##

### Dataset ###
Download the Yelp sentiment dataset with the following command:
```
python prepare_data.py
```

### Train the model ###

Train the model on the above data to perform sentiment transfer.
```
python main.py --config config
```

[config.py](./config.py) contains the data and model configurations.
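
Since the config is a plain Python module, the `--config` flag presumably resolves to a module import along these lines (a sketch; see [main.py](./main.py) for the actual flag handling):

```
import importlib

config = importlib.import_module('config')           # from `--config config`
print(config.model['decoder']['rnn_cell']['type'])   # -> GRUCell
```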

* The model will first be pre-trained for a few epochs (specified in `config.py`). During pre-training, the `Encoder-Decoder` part is trained as an autoencoder, while the `Classifier` part is trained with the classification labels.
* Full-training is then performed for another few epochs. During full-training, the `Classifier` part is fixed, and the `Encoder-Decoder` part is trained to fit the classifier while continuing to minimize the autoencoding loss (a minimal sketch of this schedule follows).
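
The sketch below walks through this two-stage schedule using the settings in `config.py`; `train_d`, `train_g`, and the empty data iterator are hypothetical stand-ins for the real update steps in `main.py`:

```
import config

def train_d(batch):
    """Placeholder: one classifier update on real (labeled) data."""

def train_g(batch, gamma, lambda_g):
    """Placeholder: one encoder-decoder update with the joint loss."""

gamma, lambda_g = 1.0, 0.0                  # pre-training: autoencoding only
for epoch in range(1, config.max_nepochs + 1):
    if epoch > config.pretrain_nepochs:     # switch to full-training
        gamma = max(0.001, gamma * config.gamma_decay)  # anneal Gumbel temperature
        lambda_g = config.lambda_g          # turn on the classification loss
    for batch in []:                        # stands in for the training iterator
        if epoch <= config.pretrain_nepochs:
            train_d(batch)                  # the classifier is fixed afterwards
        train_g(batch, gamma, lambda_g)
```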

(**Note:** When using your own dataset, make sure to set `max_decoding_length_train` and `max_decoding_length_infer` in [config.py](https://github.com/asyml/texar/blob/master/examples/text_style_transfer/config.py#L85-L86).)

The training log is printed as below:
```
gamma: 1.0, lambda_g: 0.0
step: 1, loss_d: 0.6934 accu_d: 0.4844
step: 1, loss_g_ae: 9.1392
step: 500, loss_d: 0.1488 accu_d: 0.9484
step: 500, loss_g_ae: 4.2884
step: 1000, loss_d: 0.1215 accu_d: 0.9625
step: 1000, loss_g_ae: 2.6201
...
epoch: 1, loss_d: 0.0750 accu_d: 0.9688
epoch: 1, loss_g_ae: 0.8832
val: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2949 loss_d: 0.0702 accu_d: 0.9744 accu_g: 0.3022 accu_g_gdy: 0.2732 bleu: 60.8234
test: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2359 loss_d: 0.0746 accu_d: 0.9733 accu_g: 0.3076 accu_g_gdy: 0.2791 bleu: 60.1810
...

```
where:
- `loss_d` and `accu_d` are the classification loss/accuracy of the `Classifier` part.
- `loss_g_class` is the classification loss of the generated sentences.
- `loss_g_ae` is the autoencoding loss.
- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_class`.
- `accu_g` is the classification accuracy of the generated sentences with soft representations (i.e., Gumbel-softmax; see the sketch after this list).
- `accu_g_gdy` is the classification accuracy of the generated sentences with greedy decoding.
- `bleu` is the BLEU score between the generated and input sentences.
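
To illustrate the soft representations behind `accu_g`, PyTorch's built-in `gumbel_softmax` produces differentiable samples from the decoder logits (toy shapes; `tau` plays the role of the annealed temperature `gamma` above):

```
import torch
import torch.nn.functional as F

logits = torch.randn(64, 20, 10000)   # (batch, time, vocab) decoder logits
soft_onehot = F.gumbel_softmax(logits, tau=1.0, hard=False)
# Multiplying by the embedding matrix yields "soft" word embeddings, so the
# classifier's loss can back-propagate through sampling to the decoder.
```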

## Results ##

Text style transfer has two primary goals:
1. The generated sentence should have the desired attribute (e.g., positive/negative sentiment).
2. The generated sentence should preserve the content of the original sentence.

We use automatic metrics to evaluate both:
* For (1), we can use a pre-trained classifier to classify the generated sentences and evaluate the accuracy (the higher the better). This example does not implement a stand-alone classifier for evaluation, though adding one would be straightforward; the `Classifier` part of the model gives a reasonably good estimate of that accuracy (i.e., `accu_g_gdy` above).
* For (2), we evaluate the BLEU score between the generated sentences and the original sentences, i.e., `bleu` above (the higher the better; see [Yang et al., 2018](https://arxiv.org/pdf/1805.11749.pdf) for more details). A minimal stand-alone BLEU computation is sketched below.
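
For instance, such a computation with NLTK could look like the following (file names are illustrative; one tokenized sentence per line is assumed):

```
from nltk.translate.bleu_score import corpus_bleu

with open('original.txt') as ref, open('generated.txt') as hyp:
    refs = [[line.split()] for line in ref]  # one reference per hypothesis
    hyps = [line.split() for line in hyp]

print('BLEU: %.2f' % (100 * corpus_bleu(refs, hyps)))
```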

The implementation here gives the following performance after 10 epochs of pre-training and 2 epochs of full-training:

| Accuracy (by the `Classifier` part) | BLEU (with the original sentence) |
| -------------------------------------| ----------------------------------|
| 0.96 | 52.0 |

Also refer to the following papers that used this code and compared to other text style transfer approaches:

* [Unsupervised Text Style Transfer using Language Models as Discriminators](https://papers.nips.cc/paper/7959-unsupervised-text-style-transfer-using-language-models-as-discriminators.pdf). Zichao Yang, Zhiting Hu, Chris Dyer, Eric Xing, Taylor Berg-Kirkpatrick. NeurIPS 2018
* [Structured Content Preservation for Unsupervised Text Style Transfer](https://arxiv.org/pdf/1810.06526.pdf). Youzhi Tian, Zhiting Hu, Zhou Yu. 2018

### Samples ###
Here are some randomly picked samples. In each pair, the first sentence is the original and the second is the generated one.
```
love , love love .
poor , poor poor .

good atmosphere .
disgusted atmosphere .

the donuts are good sized and very well priced .
the donuts are disgusted sized and very _num_ priced .

it is always clean and the staff is super friendly .
it is nasty overpriced and the staff is super cold .

super sweet place .
super plain place .

highly recommended .
horrible horrible .

very good ingredients .
very disgusted ingredients .
```
108 changes: 108 additions & 0 deletions examples/text_style_transfer/config.py
@@ -0,0 +1,108 @@
"""Config
"""
# pylint: disable=invalid-name

import copy
from typing import Dict, Any

# Total number of training epochs (including pre-train and full-train)
max_nepochs = 12
pretrain_nepochs = 10 # Number of pre-train epochs (training as autoencoder)
display = 500 # Display the training results every N training steps.
# Display the dev results every N training steps (set to a
# very large value to disable it).
display_eval = 1e10

sample_path = './samples'
checkpoint_path = './checkpoints'
restore = '' # Model snapshot to restore from

lambda_g = 0.1 # Weight of the classification loss
gamma_decay = 0.5 # Gumbel-softmax temperature anneal rate

max_seq_length = 16 # Maximum sequence length in dataset w/o BOS token

train_data: Dict[str, Any] = {
'batch_size': 64,
# 'seed': 123,
'datasets': [
{
'files': './data/yelp/sentiment.train.text',
'vocab_file': './data/yelp/vocab',
'data_name': ''
},
{
'files': './data/yelp/sentiment.train.labels',
'data_type': 'int',
'data_name': 'labels'
}
],
'name': 'train'
}

val_data = copy.deepcopy(train_data)
val_data['datasets'][0]['files'] = './data/yelp/sentiment.dev.text'
val_data['datasets'][1]['files'] = './data/yelp/sentiment.dev.labels'

test_data = copy.deepcopy(train_data)
test_data['datasets'][0]['files'] = './data/yelp/sentiment.test.text'
test_data['datasets'][1]['files'] = './data/yelp/sentiment.test.labels'

model = {
'dim_c': 200,
'dim_z': 500,
'embedder': {
'dim': 100,
},
'max_seq_length': max_seq_length,
'encoder': {
'rnn_cell': {
'type': 'GRUCell',
'kwargs': {
'num_units': 700
},
'dropout': {
'input_keep_prob': 0.5
}
}
},
'decoder': {
'rnn_cell': {
'type': 'GRUCell',
'kwargs': {
'num_units': 700,
},
'dropout': {
'input_keep_prob': 0.5,
'output_keep_prob': 0.5
},
},
'attention': {
'type': 'BahdanauAttention',
'kwargs': {
'num_units': 700,
},
'attention_layer_size': 700,
},
'max_decoding_length_train': 21,
'max_decoding_length_infer': 20,
},
'classifier': {
'kernel_size': [3, 4, 5],
'out_channels': 128,
'data_format': 'channels_last',
'other_conv_kwargs': [[{'padding': 1}, {'padding': 2}, {'padding': 2}]],
'dropout_conv': [1],
'dropout_rate': 0.5,
'num_dense_layers': 0,
'num_classes': 1
},
'opt': {
'optimizer': {
'type': 'Adam',
'kwargs': {
'lr': 3e-4,
},
},
},
}