Add text style transfer #5

Merged: 28 commits, Dec 9, 2019

Commits (28)
9e45a3c
initial commit
swapnull7 Nov 27, 2019
277aa84
Merge branch 'master' into add-text-style-transfer
swapnull7 Nov 27, 2019
5ad14e3
bug fixes and adjusting conv inputs
swapnull7 Dec 3, 2019
b99aacc
separate forward function for Discriminator and Generator and disable…
swapnull7 Dec 4, 2019
258753c
remove debugger statement
swapnull7 Dec 4, 2019
44db04a
bug fix
swapnull7 Dec 4, 2019
b3711a8
detaching stuff before accumulating
swapnull7 Dec 4, 2019
d42929e
refactor and add component as optional parameter
swapnull7 Dec 4, 2019
f849833
Add optimizer for and backprop against encoder
swapnull7 Dec 4, 2019
28f2fec
Add in README
swapnull7 Dec 4, 2019
37cee30
Merge remote-tracking branch 'upstream/master' into add-text-style-tr…
swapnull7 Dec 4, 2019
9eb4c97
more fixes to eval mode
swapnull7 Dec 5, 2019
6dcc7fb
create optimizers so that they can be saved
swapnull7 Dec 5, 2019
4d38ff4
fix typo
swapnull7 Dec 5, 2019
c5f9ac6
Merge branch 'master' into add-text-style-transfer
swapnull7 Dec 5, 2019
20b6aa5
linting issues
swapnull7 Dec 5, 2019
fde2163
Merge branch 'master' into add-text-style-transfer
swapnull7 Dec 5, 2019
1b16c48
add type annottation for encoder
swapnull7 Dec 5, 2019
a344e52
fix linting
swapnull7 Dec 5, 2019
f5df906
Disable codecov/patch (#265)
gpengzhi Dec 6, 2019
6a7a2b4
Isolate AE in training
swapnull7 Dec 6, 2019
9c3b7d3
works after changing the learning rate
swapnull7 Dec 8, 2019
8256ed7
remove debugger
swapnull7 Dec 8, 2019
ef36abd
Merge branch 'master' into add-text-style-transfer
swapnull7 Dec 8, 2019
9966731
Merge remote-tracking branch 'upstream/master'
swapnull7 Dec 9, 2019
59fc4f7
Reviewed changes
swapnull7 Dec 9, 2019
2b65b9e
linting
swapnull7 Dec 9, 2019
5cf695a
Merge branch 'master' into add-text-style-transfer
swapnull7 Dec 9, 2019
27 changes: 0 additions & 27 deletions .codecov.yml

This file was deleted.

7 changes: 7 additions & 0 deletions codecov.yml
@@ -0,0 +1,7 @@
coverage:
  status:
    project:
      default:
        threshold: 1%

    patch: false
15 changes: 15 additions & 0 deletions docs/code/utils.rst
@@ -12,6 +12,10 @@ Frequent Use
.. autoclass:: texar.torch.utils.AverageRecorder
:members:

:hidden:`collect_trainable_variables`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.collect_trainable_variables

:hidden:`compat_as_text`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.compat_as_text
@@ -20,6 +24,17 @@ Frequent Use
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.write_paired_text

Variables
=========

:hidden:`collect_trainable_variables`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.collect_trainable_variables

:hidden:`add_variable`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: texar.torch.utils.add_variable


IO
===
9 changes: 9 additions & 0 deletions docs/examples.md
@@ -22,6 +22,10 @@ More examples are continuously added...
* [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation
* [xlnet](https://github.com/asyml/texar-pytorch/tree/master/examples/xlnet): Pre-trained XLNet model for text representation

### GANs / Discriminator-supervision ###

* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation

---

## Examples by Tasks
@@ -35,6 +39,11 @@ More examples are continuously added...
* [seq2seq_attn](https://github.com/asyml/texar-pytorch/tree/master/examples/seq2seq_attn): Attentional seq2seq
* [transformer](https://github.com/asyml/texar-pytorch/tree/master/examples/transformer): Transformer for machine translation

### Text Style Transfer ###

* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation


### Classification ###

* [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation
4 changes: 4 additions & 0 deletions examples/README.md
@@ -47,6 +47,10 @@ More examples are continuously added...
* [seq2seq_attn](./seq2seq_attn): Attentional seq2seq
* [transformer](./transformer): Transformer for machine translation

### Text Style Transfer ###

* [text_style_transfer](./text_style_transfer): Discriminator supervision for controlled text generation

### Classification ###

* [bert](./bert): Pre-trained BERT model for text representation
58 changes: 28 additions & 30 deletions examples/text_style_transfer/README.md
@@ -14,7 +14,7 @@ The model roughly has an architecture of `Encoder--Decoder--Classifier`. Compare
## Usage ##

### Dataset ###
Download the yelp sentiment dataset with the following cmd:
Download the yelp sentiment dataset with the following command:
```
python prepare_data.py
```
@@ -36,24 +36,25 @@ python main.py --config config
Training log is printed as below:
```
gamma: 1.0, lambda_g: 0.0
step: 1, loss_d: 0.6903 accu_d: 0.5625
step: 1, loss_g_clas: 0.6991 loss_g: 9.1452 accu_g: 0.2812 loss_g_ae: 9.1452 accu_g_gdy: 0.2969
step: 500, loss_d: 0.0989 accu_d: 0.9688
step: 500, loss_g_clas: 0.2985 loss_g: 3.9696 accu_g: 0.8891 loss_g_ae: 3.9696 accu_g_gdy: 0.7734
step: 1, loss_d: 0.6934 accu_d: 0.4844
step: 1, loss_g_ae: 9.1392
step: 500, loss_d: 0.1488 accu_d: 0.9484
step: 500, loss_g_ae: 4.2884
step: 1000, loss_d: 0.1215 accu_d: 0.9625
step: 1000, loss_g_ae: 2.6201
...
step: 6500, loss_d: 0.0806 accu_d: 0.9703
step: 6500, loss_g_clas: 5.7137 loss_g: 0.2887 accu_g: 0.0844 loss_g_ae: 0.2887 accu_g_gdy: 0.0625
epoch: 1, loss_d: 0.0876 accu_d: 0.9719
epoch: 1, loss_g_clas: 6.7360 loss_g: 0.2195 accu_g: 0.0627 loss_g_ae: 0.2195 accu_g_gdy: 0.0642
val: accu_g: 0.0445 loss_g_ae: 0.1302 accu_d: 0.9774 bleu: 90.7896 loss_g: 0.1302 loss_d: 0.0666 loss_g_clas: 7.0310 accu_g_gdy: 0.0482
epoch: 1, loss_d: 0.0750 accu_d: 0.9688
epoch: 1, loss_g_ae: 0.8832
val: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2949 loss_d: 0.0702 accu_d: 0.9744 accu_g: 0.3022 accu_g_gdy: 0.2732 bleu: 60.8234
test: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2359 loss_d: 0.0746 accu_d: 0.9733 accu_g: 0.3076 accu_g_gdy: 0.2791 bleu: 60.1810
...

```
where:
- `loss_d` and `accu_d` are the classification loss/accuracy of the `Classifier` part.
- `loss_g_clas` is the classification loss of the generated sentences.
- `loss_g_class` is the classification loss of the generated sentences.
- `loss_g_ae` is the autoencoding loss.
- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_clas`.
- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_class`.
- `accu_g` is the classification accuracy of the generated sentences with soft representations (i.e., Gumbel-softmax).
- `accu_g_gdy` is the classification accuracy of the generated sentences with greedy decoding.
- `bleu` is the BLEU score between the generated and input sentences.
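
For reference, a minimal sketch of how these quantities combine (the tensor names `soft_logits` and `f_labels` follow `ctrl_gen_model.py` below; the wrapper function itself is hypothetical):

```
import torch
from torch.nn import functional as F

def joint_generator_loss(loss_g_ae: torch.Tensor,
                         soft_logits: torch.Tensor,
                         f_labels: torch.Tensor,
                         lambda_g: float) -> torch.Tensor:
    # Classifier loss on soft (Gumbel-softmax) samples, computed
    # against the flipped style labels, as in forward_G below.
    loss_g_class = F.binary_cross_entropy_with_logits(soft_logits,
                                                      1 - f_labels)
    # Joint loss as defined above.
    return loss_g_ae + lambda_g * loss_g_class
```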
@@ -72,7 +73,7 @@ The implementation here gives the following performance after 10 epochs of pre-t

| Accuracy (by the `Classifier` part) | BLEU (with the original sentence) |
| -------------------------------------| ----------------------------------|
| 0.92 | 54.0 |
| 0.96 | 52.0 |

Also refer to the following papers that used this code and compared to other text style transfer approaches:

@@ -82,27 +83,24 @@ Also refer to the following papers that used this code and compared to other tex
### Samples ###
Here are some randomly picked samples. In each pair, the first sentence is the original and the second is the generated one.
```
go to place for client visits with gorgeous views .
go to place for client visits with lacking views .
love , love love .
poor , poor poor .

there was lots of people but they still managed to provide great service .
there was lots of people but they still managed to provide careless service .
good atmosphere .
disgusted atmosphere .

this was the best dining experience i have ever had .
this was the worst dining experience i have ever had .
the donuts are good sized and very well priced .
the donuts are disgusted sized and very _num_ priced .

needless to say , we skipped desert .
gentle to say , we edgy desert .
it is always clean and the staff is super friendly .
it is nasty overpriced and the staff is super cold .

the first time i was missing an entire sandwich and a side of fries .
the first time i was beautifully an entire sandwich and a side of fries .
super sweet place .
super plain place .

her boutique has a fabulous selection of designer brands !
her annoying has a sketchy selection of bland warned !
highly recommended .
horrible horrible .

service is pretty good .
service is trashy rude .

ok nothing new .
exceptional impressed new .
very good ingredients .
very disgusted ingredients .
```
13 changes: 6 additions & 7 deletions examples/text_style_transfer/ctrl_gen_model.py
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@

import torch
import torch.nn as nn
from torch.nn import functional as F


import texar.torch as tx
from texar.torch.modules import WordEmbedder, UnidirectionalRNNEncoder, \
@@ -84,9 +86,7 @@ def forward_D(self, inputs, f_labels):
            input=class_inputs,
            sequence_length=inputs['length'] - 1)

        sig_ce_logits_loss = nn.BCEWithLogitsLoss()

        loss_d = sig_ce_logits_loss(class_logits, f_labels)
        loss_d = F.binary_cross_entropy_with_logits(class_logits, f_labels)
        accu_d = tx.evals.accuracy(labels=f_labels,
                                   preds=class_preds)
        return {
@@ -168,9 +168,8 @@ def forward_G(self, inputs, f_labels, gamma, lambda_g, mode):
            input=soft_inputs,
            sequence_length=soft_length_)

        sig_ce_logits_loss = nn.BCEWithLogitsLoss()

        loss_g_class = sig_ce_logits_loss(soft_logits, (1 - f_labels))
        loss_g_class = F.binary_cross_entropy_with_logits(soft_logits,
                                                          (1 - f_labels))

        # Accuracy on greedy-decoded samples, for training progress monitoring
        greedy_inputs = self.class_embedder(ids=outputs_.sample_id)
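The hunks above swap the `nn.BCEWithLogitsLoss` module for its functional equivalent. A quick sketch (illustrative values only, not from this PR) showing the two forms agree:

```
import torch
import torch.nn as nn
from torch.nn import functional as F

logits = torch.randn(4)                     # raw classifier outputs
labels = torch.randint(0, 2, (4,)).float()  # binary style labels

loss_module = nn.BCEWithLogitsLoss()(logits, labels)
loss_functional = F.binary_cross_entropy_with_logits(logits, labels)
assert torch.allclose(loss_module, loss_functional)  # same mean-reduced loss
```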
11 changes: 5 additions & 6 deletions examples/text_style_transfer/main.py
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _main():
def main():
    # Data
    train_data = tx.data.MultiAlignedData(hparams=config.train_data,
                                          device=device)
@@ -64,7 +64,7 @@ def _main():
    # once for updating the discriminator. Feedable data iterator is used for
    # such case.
    iterator = tx.data.DataIterator(
        {'train_g': train_data, 'train_d': train_data,
        {'train': train_data,
         'val': val_data, 'test': test_data})

    # Model
@@ -95,7 +95,7 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True):
        model.train()
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)
        iterator.switch_to_dataset("train_g")
        iterator.switch_to_dataset("train")
        step = 0
        for batch in iterator:
            train_op_d.zero_grad()
@@ -133,7 +133,6 @@ def _train_epoch(gamma_, lambda_g_, epoch, verbose=True):
                print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))

            if verbose and step % config.display_eval == 0:
                iterator.switch_to_dataset("val")
                _eval_epoch(gamma_, lambda_g_, epoch)

        print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
@@ -207,4 +206,4 @@ def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'):


if __name__ == '__main__':
    _main()
    main()
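
The hunks above replace the separate `train_g`/`train_d` splits with a single `train` split that drives both updates on each batch. A condensed sketch of the resulting loop shape (model and optimizer construction elided; names follow `main.py`):

```
import texar.torch as tx

def build_iterator(train_data, val_data, test_data):
    # Single "train" split replaces the former train_g/train_d pair.
    return tx.data.DataIterator(
        {'train': train_data, 'val': val_data, 'test': test_data})

def train_epoch(iterator, train_op_d, train_op_g):
    iterator.switch_to_dataset("train")
    for batch in iterator:
        # Discriminator step on this batch
        train_op_d.zero_grad()
        # ... model.forward_D(...), loss_d.backward(), train_op_d.step() ...
        # Generator step on the same batch
        train_op_g.zero_grad()
        # ... model.forward_G(...), loss_g.backward(), train_op_g.step() ...
```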
4 changes: 1 addition & 3 deletions examples/text_style_transfer/prepare_data.py
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +15,6 @@
"""
import texar.torch as tx

# pylint: disable=invalid-name


def prepare_data():
"""Downloads data.
22 changes: 12 additions & 10 deletions texar/torch/utils/variables.py
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,16 +15,18 @@
Utility functions related to variables.
"""

# pylint: disable=invalid-name
from typing import Any, List, Tuple, Union

__all__ = [
    "add_variable",
    "collect_trainable_variables"
]


def add_variable(variable, var_list):
    """Adds variable to a given list.
def add_variable(
        variable: Union[List[Any], Tuple[Any]],
        var_list: List[Any]):
    r"""Adds variable to a given list.

    Args:
        variable: A (list of) variable(s).
@@ -34,28 +36,28 @@ def add_variable(variable, var_list):
        for var in variable:
            add_variable(var, var_list)
    else:
        # Checking uniqueness gives error
        # if variable in var_list:
        var_list.append(variable)


def collect_trainable_variables(modules):
    """Collects all trainable variables of modules.
def collect_trainable_variables(
        modules: Union[Any, List[Any]]
):
    r"""Collects all trainable variables of modules.

    Trainable variables included in multiple modules occur only once in the
    returned list.

    Args:
        modules: A (list of) instance of the subclasses of
            :class:`~texar.tf.modules.ModuleBase`.
            :class:`~texar.torch.modules.ModuleBase`.

    Returns:
        A list of trainable variables in the modules.
    """
    if not isinstance(modules, (list, tuple)):
        modules = [modules]

    var_list = []
    var_list: List[Any] = []
    for mod in modules:
        add_variable(mod.trainable_variables, var_list)

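A hypothetical usage sketch of these two utilities (the module choices are illustrative, not from this PR): collect deduplicated parameters across several modules and hand them to one optimizer, as the example's generator training does.

```
import torch
import texar.torch as tx
from texar.torch.utils import collect_trainable_variables

embedder = tx.modules.WordEmbedder(vocab_size=100, hparams={'dim': 16})
encoder = tx.modules.UnidirectionalRNNEncoder(input_size=16)

# Parameters shared between modules appear only once in the list.
params = collect_trainable_variables([embedder, encoder])
optimizer = torch.optim.Adam(params, lr=1e-3)
```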