Commit e0a694e

cwbeitelkpe authored and committed

First draft of multi-problem docs (tensorflow#1399)

* first draft of multi-problem docs
* simplification of tid lookup docs
* update multi-problem inference from ckpt docs
* minor command fixes; sp.
* polish

1 parent 990ed59

docs/multi_problem.md

# Multi-problem training

Multi-problem training is possible by defining [MultiProblem](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py) sub-classes that specify a list of [Problem](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py) objects to include in training. In some cases, multi-problem training can be used to improve performance compared to training on individual problems.

In the following sections we'll discuss MultiProblem first from a usage perspective and then from the perspective of someone wishing to build upon it.

Please note that the [T2T Walkthrough](https://github.com/tensorflow/tensor2tensor/blob/master/docs/walkthrough.md) documentation is a good place to start for understanding the component concepts we'll build on here.

## Usage

### Problem definition and datagen

In this discussion we'll consider the following (large) multi-problem, which includes ten sub-problems:

1. A [language modeling](https://en.wikipedia.org/wiki/Language_model) [problem](https://github.com/tensorflow/tensor2tensor/blob/0dff89d64c3406d42717280cb9135a5ce7af793c/tensor2tensor/data_generators/wiki_lm.py#L223) operating on a corpus of German, English, French, and Romanian Wikipedia articles.
2. Multiple compatible pairwise language translation problems (En -> De, En -> Fr, En -> Ro, De -> En, Fr -> En, Ro -> En).
3. A compatible [version](https://github.com/tensorflow/tensor2tensor/blob/ef12bee72270b322165d073c39a650a189de39aa/tensor2tensor/data_generators/cnn_dailymail.py#L267) of the combined CNN/DailyMail news article summarization problem.
4. A compatible [version](https://github.com/tensorflow/tensor2tensor/blob/ef12bee72270b322165d073c39a650a189de39aa/tensor2tensor/data_generators/multinli.py#L155) of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) textual entailment classification problem.
5. A compatible [version](https://github.com/tensorflow/tensor2tensor/blob/1de13dbebccb415d89b0658e18a57e9607bafd32/tensor2tensor/data_generators/squad.py#L126) of the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) question/answer problem.

```python
@registry.register_problem
class LanguagemodelMultiWikiTranslate(multi_problem.MultiProblem):
  """Wiki multi-lingual LM and multiple translations."""

  def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelMultiWikiTranslate, self).__init__(
        was_reversed, was_copy)
    self.task_list.append(wiki_lm.LanguagemodelDeEnFrRoWiki64k())
    self.task_list.append(translate_ende.TranslateEndeWmtMulti64k())
    self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k())
    self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k())
    self.task_list.append(translate_ende.TranslateEndeWmtMulti64k(
        was_reversed=True))
    self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k(
        was_reversed=True))
    self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k(
        was_reversed=True))
    self.task_list.append(
        cnn_dailymail.SummarizeCnnDailymailWikiLMMultiVocab64k())
    self.task_list.append(multinli.MultiNLIWikiLMMultiVocab64k())
    self.task_list.append(squad.SquadConcatMulti64k())

  @property
  def vocab_type(self):
    return text_problems.VocabType.SUBWORD
```

The word "compatible" was used a lot above! That's because each of these problems has been modified to use the vocabulary produced by the Wikipedia-based language modeling problem, e.g. via the following:

```python
@registry.register_problem
class SummarizeCnnDailymailWikiLMMultiVocab64k(SummarizeCnnDailymail32k):
  """Summarize CNN and Daily Mail articles using multi-lingual 64k vocab."""

  @property
  def vocab_filename(self):
    return wiki_lm.LanguagemodelDeEnFrRoWiki64k().vocab_filename
```

**Important note:** It's easy to miss the key point that, as currently implemented, the first task in the task list must be a language modeling problem, and each included task must be modified to use the vocabulary it produces.

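As a quick check of this requirement when defining your own multi-problem, something like the following sketch can help. This is not part of the T2T API; it assumes the registry lookup by problem name succeeds once the relevant problem modules have been imported.

```python
# A hedged sanity check: confirm every task reports the same vocab filename
# as the leading language modeling task.
from tensor2tensor import problems  # importing this registers built-in problems

multi_problem = problems.problem("languagemodel_multi_wiki_translate")
lm_vocab = multi_problem.task_list[0].vocab_filename
for task in multi_problem.task_list:
  # Each task should resolve to the shared subword vocab file.
  assert task.vocab_filename == lm_vocab, task.name
print("All %d tasks share vocab file %s"
      % (len(multi_problem.task_list), lm_vocab))
```
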
With a properly defined and registered multi-problem we can now run datagen as follows:

```bash
t2t-datagen --problem=languagemodel_multi_wiki_translate
```

This will take several hours and approximately the following amount of disk space:

```bash
(t2t) username@instance-2:~$ du -sh /tmp
99G     /tmp
(t2t) username@instance-2:~$ du -sh /tmp/t2t_datagen
81G     /tmp/t2t_datagen
```

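If you'd like to spot-check what datagen produced, a minimal sketch along these lines can help. It is not part of the T2T API, and the shard filename pattern below is an assumption; adjust the glob to whatever files actually landed in your data directory.

```python
# A hedged sketch: peek at one generated TFRecord shard and print its feature keys.
import glob

import tensorflow as tf

# Assumed shard naming: "<problem_name>-train-*"; adjust to your data_dir contents.
paths = sorted(glob.glob("/tmp/*-train-*"))

for raw_record in tf.python_io.tf_record_iterator(paths[0]):
  example = tf.train.Example()
  example.ParseFromString(raw_record)
  print(sorted(example.features.feature.keys()))
  break
```
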
### Training

Next we're ready to try training a model on this MultiProblem. Note that because we did not specify `--data_dir` above, the TFExamples were generated into /tmp by default, so that's what we explicitly provide here (the tiny `--train_steps`/`--eval_steps` values are just to verify that everything runs):

```bash
t2t-trainer --problem=languagemodel_multi_wiki_translate \
  --model=transformer \
  --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \
  --output_dir ~/t2t_train/transformer_multi_2jan19 \
  --data_dir=/tmp \
  --train_steps=1 \
  --eval_steps=1
```

The `hparams_set` we provided above was [transformer_tall_pretrain_lm_tpu_adafactor_large](https://github.com/tensorflow/tensor2tensor/blob/08e83030acf3ef13d15ad6eaefaa0a67fb20b59d/tensor2tensor/models/transformer.py#L1721), shown below:

```python
@registry.register_hparams
def transformer_tall_pretrain_lm_tpu_adafactor_large():
  """Hparams for transformer on LM pretraining on TPU, large model."""
  hparams = transformer_tall_pretrain_lm_tpu_adafactor()
  hparams.hidden_size = 1024
  hparams.num_heads = 16
  hparams.filter_size = 32768  # max fitting in 16G memory is 49152, batch 2
  hparams.batch_size = 4
  hparams.multiproblem_mixing_schedule = "constant"
  # Task order: lm/en-de/en-fr/en-ro/de-en/fr-en/ro-en/cnndm/mnli/squad.
  hparams.multiproblem_per_task_threshold = "320,80,160,2,80,160,2,20,5,5"
  return hparams
```

A couple of things are worth noting here. First, we have specified a `multiproblem_mixing_schedule` (which is required), consumed by [MultiProblem.mix_data](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py#L280). When set to "constant", the example-sampling strategy is not a function of the training step and is proportional only to the per-task "thresholds", which are equal by default (i.e. examples are sampled from each problem with equal probability).

Second, we have also specified the optional `multiproblem_per_task_threshold` parameter, also consumed by `mix_data` and used by [sample_task](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py#L340), which defines non-uniform thresholds for weighted random sampling. For example, with two problems whose thresholds are 1 and 9, the first would be sampled 1/10 of the time and the second 9/10 of the time.

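To make the arithmetic concrete, here is a small illustrative sketch (not the library's actual sampling code) of how the per-task thresholds above translate into a fixed sampling distribution under the "constant" schedule:

```python
# Illustrative only: normalize per-task thresholds into sampling probabilities.
import numpy as np

# Task order: lm/en-de/en-fr/en-ro/de-en/fr-en/ro-en/cnndm/mnli/squad.
thresholds = [320, 80, 160, 2, 80, 160, 2, 20, 5, 5]
probs = np.array(thresholds, dtype=np.float64) / sum(thresholds)

# With the "constant" schedule, each example's task is drawn from this fixed
# distribution, independent of the current training step.
sampled_task_index = np.random.choice(len(thresholds), p=probs)
print(probs.round(3), sampled_task_index)
```
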
### Inference

You can try translating from English to German using a model previously trained on `LanguagemodelMultiWikiTranslate` (the one shown above), available at [gs://tensor2tensor-checkpoints/transformer_multi_2jan19/](https://console.cloud.google.com/storage/browser/tensor2tensor-checkpoints/transformer_multi_2jan19/). Just copy the checkpoint down to a local directory such as the one given via `--output_dir` below:

```bash
t2t-decoder --problem=languagemodel_multi_wiki_translate \
  --model=transformer \
  --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \
  --decode_hparams='batch_size=1,multiproblem_task_id=64510' \
  --hparams="" \
  --output_dir=~/t2t_train/transformer_multi_2jan19 \
  --decode_from_file ~/newstest2014.en \
  --data_dir=~/t2t_train/transformer_multi_2jan19
```

Here we point `--data_dir` at the checkpoint directory, which includes the vocab file `vocab.languagemodel_de_en_fr_ro_wiki64k.64000.subwords`; typically `--data_dir` would point to the directory containing your TFRecord example dataset(s).

The file passed to `--decode_from_file` is simply a file with one sentence to translate on each line (in its original text form, not vocabulary-encoded).

A key requirement for multi-problem inference is that we specify the ID of the task for which we want to perform inference. But wait, why is the task ID 64510? We can see from the code for [`MultiProblem.update_task_ids`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py#L386) that task IDs are assigned positions immediately after the end of the vocabulary:

```python
class MultiProblem(problem.Problem):
  """MultiProblem base class."""

  ...

  def update_task_ids(self, encoder_vocab_size):
    """Generate task_ids for each problem.

    These ids correspond to the index of the task in the task_list.

    Args:
      encoder_vocab_size: the size of the vocab which is used to compute
        the index offset.
    """
    for idx, task in enumerate(self.task_list):
      task.set_task_id(idx + encoder_vocab_size)
      tf.logging.info("Task %d (%s) has id %d." %
                      (idx, task.name, task.task_id))
```

We can look up the task_id assigned to the task we want to use for inference by instantiating the MultiProblem subclass and reading the value, in this case via the following:

```python
task_index = 1  # The second task in the list is En -> De.
LanguagemodelMultiWikiTranslate().task_list[task_index].task_id
```

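If you'd like to see the IDs for all tasks at once, the following sketch prints them. It assumes the shared vocab file has been copied into the checkpoint directory as described above and that the registry lookup by problem name works once the T2T problem modules are imported; the exact ID values depend on the actual vocab size, so avoid hard-coding them.

```python
# A hedged sketch: enumerate task names and task IDs for the multi-problem.
import os

from tensor2tensor import problems  # importing this registers built-in problems
from tensor2tensor.data_generators import text_encoder

data_dir = os.path.expanduser("~/t2t_train/transformer_multi_2jan19")
vocab_path = os.path.join(
    data_dir, "vocab.languagemodel_de_en_fr_ro_wiki64k.64000.subwords")

multi_problem = problems.problem("languagemodel_multi_wiki_translate")
encoder = text_encoder.SubwordTextEncoder(vocab_path)

# Task IDs are offsets just past the end of the encoder vocabulary.
multi_problem.update_task_ids(encoder.vocab_size)
for idx, task in enumerate(multi_problem.task_list):
  print(idx, task.name, task.task_id)
```
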
For me, running the `t2t-decoder` command provided above gave the following output:

```bash
...
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference results INPUT: hello world was the news of the day
INFO:tensorflow:Inference results OUTPUT: Hallo Welt war die Nachricht des Tages
INFO:tensorflow:Elapsed Time: 37.15079
INFO:tensorflow:Averaged Single Token Generation Time: 3.3009222 (time 36.3101439 count 11)
...
```
