|
| 1 | +# Multi-problem training |
| 2 | + |
| 3 | +Multi-problem training is possible by defining [MultiProblem](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py) sub-classes that specify a list of [Problem](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py) objects to include in training. In some cases, multi-problem training can be used to improve performance compared to training on individual problems. |
| 4 | + |
| 5 | +In the following sections we'll discuss MultiProblem from a usage perspective followed by that of someone wishing to build upon it. |
| 6 | + |
| 7 | +Please note the [T2T Walkthrough](https://github.com/tensorflow/tensor2tensor/blob/master/docs/walkthrough.md) documentation is a good place to start to understand the variety of component concepts we'll build on here. |
| 8 | + |
| 9 | +## Usage |
| 10 | + |
| 11 | +### Problem definition and datagen |
| 12 | + |
| 13 | +In this discussion we'll consider the following (large) multi-problem that includes ten different sub-problems. These include: |
| 14 | + |
| 15 | +1. A [language modeling](https://en.wikipedia.org/wiki/Language_model) [problem](https://github.com/tensorflow/tensor2tensor/blob/0dff89d64c3406d42717280cb9135a5ce7af793c/tensor2tensor/data_generators/wiki_lm.py#L223) operating on a corpus of German, English, French, and Romanian language wikipedia articles. |
| 16 | +2. Multiple compatible pairwise language translation problems (En -> De, En -> Fr, En -> Ro, De -> En, Fr -> En, Ro -> En) |
| 17 | +3. A compatible [version](https://github.com/tensorflow/tensor2tensor/blob/ef12bee72270b322165d073c39a650a189de39aa/tensor2tensor/data_generators/cnn_dailymail.py#L267) of the combined CNN/DailyMail news article summarization problem. |
| 18 | +4. A compatible [version](https://github.com/tensorflow/tensor2tensor/blob/ef12bee72270b322165d073c39a650a189de39aa/tensor2tensor/data_generators/multinli.py#L155) of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) textual entailment classification problem. |
| 19 | +5. A compatible [version](https://github.com/tensorflow/tensor2tensor/blob/1de13dbebccb415d89b0658e18a57e9607bafd32/tensor2tensor/data_generators/squad.py#L126) of the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) question/answer problem. |
| 20 | + |
| 21 | +```python |
| 22 | + |
| 23 | +@registry.register_problem |
| 24 | +class LanguagemodelMultiWikiTranslate(multi_problem.MultiProblem): |
| 25 | + """Wiki multi-lingual LM and multiple translations.""" |
| 26 | + |
| 27 | + def __init__(self, was_reversed=False, was_copy=False): |
| 28 | + super(LanguagemodelMultiWikiTranslate, self).__init__( |
| 29 | + was_reversed, was_copy) |
| 30 | + self.task_list.append(wiki_lm.LanguagemodelDeEnFrRoWiki64k()) |
| 31 | + self.task_list.append(translate_ende.TranslateEndeWmtMulti64k()) |
| 32 | + self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k()) |
| 33 | + self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k()) |
| 34 | + self.task_list.append(translate_ende.TranslateEndeWmtMulti64k( |
| 35 | + was_reversed=True)) |
| 36 | + self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k( |
| 37 | + was_reversed=True)) |
| 38 | + self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k( |
| 39 | + was_reversed=True)) |
| 40 | + self.task_list.append( |
| 41 | + cnn_dailymail.SummarizeCnnDailymailWikiLMMultiVocab64k()) |
| 42 | + self.task_list.append(multinli.MultiNLIWikiLMMultiVocab64k()) |
| 43 | + self.task_list.append(squad.SquadConcatMulti64k()) |
| 44 | + |
| 45 | + @property |
| 46 | + def vocab_type(self): |
| 47 | + return text_problems.VocabType.SUBWORD |
| 48 | + |
| 49 | +``` |
| 50 | + |
| 51 | +The word "compatible" was used a lot above! That's because each of these problems have been modified to use the vocabulary produced by the Wikipedia-based language modeling problem, e.g. the following |
| 52 | + |
| 53 | +```python |
| 54 | +@registry.register_problem |
| 55 | +class SummarizeCnnDailymailWikiLMMultiVocab64k(SummarizeCnnDailymail32k): |
| 56 | + """Summarize CNN and Daily Mail articles using multi-lingual 64k vocab.""" |
| 57 | + |
| 58 | + @property |
| 59 | + def vocab_filename(self): |
| 60 | + return wiki_lm.LanguagemodelDeEnFrRoWiki64k().vocab_filename |
| 61 | +``` |
| 62 | + |
| 63 | +**Important note:** It's easy to miss the key point that, as implemented currently, the first task in the task list must be a language modelling problem and each included task must be modified to use the resulting vocabulary. |
| 64 | + |
| 65 | +With a properly defined and registered multi-problem we can now run datagen as follows: |
| 66 | + |
| 67 | +```bash |
| 68 | + |
| 69 | +t2t-datagen --problem=languagemodel_multi_wiki_translate |
| 70 | + |
| 71 | +``` |
| 72 | + |
| 73 | +This will take approximately the following amount of space (and several hours): |
| 74 | + |
| 75 | +```bash |
| 76 | +(t2t) username@instance-2:~$ du -sh /tmp |
| 77 | +99G /tmp |
| 78 | +(t2t) username@instance-2:~$ du -sh /tmp/t2t_datagen |
| 79 | +81G /tmp/t2t_datagen |
| 80 | +``` |
| 81 | + |
| 82 | +### Training |
| 83 | + |
| 84 | +Next we're ready to try training a model on this MultiProblem. Note that by not specifying `--data_dir` above TFExample's were by default generated into /tmp so that's what we'll explicitly provide here. |
| 85 | + |
| 86 | +```bash |
| 87 | + |
| 88 | +t2t-trainer --problem=languagemodel_multi_wiki_translate \ |
| 89 | + --model=transformer \ |
| 90 | + --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \ |
| 91 | + --output_dir ~/t2t_train/transformer_multi_2jan19 \ |
| 92 | + --data_dir=/tmp \ |
| 93 | + --train_steps=1 \ |
| 94 | + --eval_steps=1 |
| 95 | + |
| 96 | +``` |
| 97 | + |
| 98 | +The `hparams_set` parameter we provided above was [transformer_tall_pretrain_lm_tpu_adafactor_large](https://github.com/tensorflow/tensor2tensor/blob/08e83030acf3ef13d15ad6eaefaa0a67fb20b59d/tensor2tensor/models/transformer.py#L1721), also provided below: |
| 99 | + |
| 100 | +```python |
| 101 | + |
| 102 | +@registry.register_hparams |
| 103 | +def transformer_tall_pretrain_lm_tpu_adafactor_large(): |
| 104 | + """Hparams for transformer on LM pretraining on TPU, large model.""" |
| 105 | + hparams = transformer_tall_pretrain_lm_tpu_adafactor() |
| 106 | + hparams.hidden_size = 1024 |
| 107 | + hparams.num_heads = 16 |
| 108 | + hparams.filter_size = 32768 # max fitting in 16G memory is 49152, batch 2 |
| 109 | + hparams.batch_size = 4 |
| 110 | + hparams.multiproblem_mixing_schedule = "constant" |
| 111 | + # Task order: lm/en-de/en-fr/en-ro/de-en/fr-en/ro-en/cnndm/mnli/squad. |
| 112 | + hparams.multiproblem_per_task_threshold = "320,80,160,2,80,160,2,20,5,5" |
| 113 | + return hparams |
| 114 | + |
| 115 | +``` |
| 116 | + |
| 117 | +Here it's worth noting a couple things, one that we have specified a `multi_problem_mixing_schedule` (which is required), consumed by [MultiProblem.mix_data](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py#L280). When set to "constant" the strategy for sampling examples is not a function of step and is proportional only to the per-task "thresholds" which are by default equal (sample examples from each problem with equal probability). |
| 118 | + |
| 119 | +But notice we have also specified the (non-required) `multiproblem_per_task_threshold` parameter, also consumed by mix_data, and specifically used by [sample_task](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py#L340) which defines non-uniform thresholds to inform a weighted random sampling. E.g. for two problems with weights 1 and 9 the first would be sampled 1/10 of the time and the other 9/10. |
| 120 | + |
| 121 | +### Inference |
| 122 | + |
| 123 | +You can try translating from English to German using a model previously trained on `LanguagemodelMultiWikiTranslate` (the one shown above) ([gs://tensor2tensor-checkpoints/transformer_multi_2jan19/](https://console.cloud.google.com/storage/browser/tensor2tensor-checkpoints/transformer_multi_2jan19/)). Just copy the checkpoint down to a local directory such as the one given via `--output_dir` below: |
| 124 | + |
| 125 | +```bash |
| 126 | + |
| 127 | +t2t-decoder --problem=languagemodel_multi_wiki_translate \ |
| 128 | + --model=transformer \ |
| 129 | + --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \ |
| 130 | + --decode_hparams='batch_size=1,multiproblem_task_id=64510' \ |
| 131 | + --hparams="" \ |
| 132 | + --output_dir=~/t2t_train/transformer_multi_2jan19 \ |
| 133 | + --decode_from_file ~/newstest2014.en \ |
| 134 | + --data_dir=~/t2t_train/transformer_multi_2jan19 |
| 135 | + |
| 136 | +``` |
| 137 | + |
| 138 | +Here we'll point `--data_dir` to the checkpoint directory which includes the vocab file `vocab.languagemodel_de_en_fr_ro_wiki64k.64000.subwords`; typically data_dir would point to the directory containing your TFRecord example dataset(s). |
| 139 | + |
| 140 | +The file passed to `--decode_from_file` is simply a file with one sentence to translate on each line (in its original form, not post-vocabulary-encoded). |
| 141 | + |
| 142 | +A key requirement for multi-problem inference is that we specify the ID of the problem for which we want to perform inference. But wait, why is the task ID 64510? We can see from the code for [`MultiProblem.update_task_ids`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/multi_problem.py#L386) that TID's have a place at the end of the vocabulary. |
| 143 | + |
| 144 | +```python |
| 145 | + |
| 146 | +class MultiProblem(problem.Problem): |
| 147 | + """MultiProblem base class.""" |
| 148 | + |
| 149 | + ... |
| 150 | + |
| 151 | + def update_task_ids(self, encoder_vocab_size): |
| 152 | + """Generate task_ids for each problem. |
| 153 | + These ids correspond to the index of the task in the task_list. |
| 154 | + Args: |
| 155 | + encoder_vocab_size: the size of the vocab which is used to compute |
| 156 | + the index offset. |
| 157 | + """ |
| 158 | + for idx, task in enumerate(self.task_list): |
| 159 | + task.set_task_id(idx + encoder_vocab_size) |
| 160 | + tf.logging.info("Task %d (%s) has id %d." % |
| 161 | + (idx, task.name, task.task_id)) |
| 162 | + |
| 163 | +``` |
| 164 | + |
| 165 | +We can look up the task_id that is assigned to each task we may want to use for inference by instantiating the MultiProblem subclass and obtaining the value, in this case via the following: |
| 166 | + |
| 167 | +```python |
| 168 | + |
| 169 | +task_index = 1 # The second task in the list is En -> De |
| 170 | +LanguagemodelMultiWikiTranslate().task_list[task_index].task_id |
| 171 | + |
| 172 | +``` |
| 173 | + |
| 174 | +For me running the `t2t-decode` command provided above gave the following output: |
| 175 | + |
| 176 | +```bash |
| 177 | +... |
| 178 | + |
| 179 | +INFO:tensorflow:Running local_init_op. |
| 180 | +INFO:tensorflow:Done running local_init_op. |
| 181 | +INFO:tensorflow:Inference results INPUT: hello world was the news of the day |
| 182 | +INFO:tensorflow:Inference results OUTPUT: Hallo Welt war die Nachricht des Tages |
| 183 | +INFO:tensorflow:Elapsed Time: 37.15079 |
| 184 | +INFO:tensorflow:Averaged Single Token Generation Time: 3.3009222 (time 36.3101439 count 11) |
| 185 | + |
| 186 | +... |
| 187 | + |
| 188 | +``` |
0 commit comments