This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2167370

artitw authored and afrozenator committed
Fix bAbi data generator and readme (#1235)
* Fix bAbi data generator and readme
* Fix bAbi hparams deletion
* Fix bAbi hparams: delete unnecessary keys
* Fix bAbi hparams: clean keys
* bAbi hparams: delete keys
1 parent abf63bf commit 2167370

File tree

* README.md
* tensor2tensor/data_generators/babi_qa.py

2 files changed: +24 -6 lines changed


README.md

Lines changed: 11 additions & 0 deletions
@@ -47,6 +47,7 @@ pip install tensor2tensor && t2t-trainer \
 ### Contents
 
 * [Suggested Datasets and Models](#suggested-datasets-and-models)
+  * [Story, Question and Answer](#story-question-and-answer)
   * [Image Classification](#image-classification)
   * [Image Generation](#image-generation)
   * [Language Modeling](#language-modeling)
@@ -78,6 +79,16 @@ hyperparameters that we know works well in our setup. We usually
 run either on Cloud TPUs or on 8-GPU machines; you might need
 to modify the hyperparameters if you run on a different setup.
 
+### Story, Question and Answer
+
+For answering questions based on a story, use
+
+* the [bAbi][1] data-set: `--problem=babi_qa_concat_task1_1k`
+
+You can choose the bAbi task from the range [1,20] and the subset from 1k or 10k. To combine test data from all tasks into a single test set, use `--problem=babi_qa_concat_all_tasks_10k`
+
+[1] https://research.fb.com/downloads/babi/
+
 ### Image Classification
 
 For image classification, we have a number of standard data-sets:
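
For reference, a minimal sketch of how the problem names described in the README text are assembled from a task number and subset size. The `babi_problem_name` helper is hypothetical (it is not part of tensor2tensor); it only mirrors the `babi_qa_concat_task1_1k` / `babi_qa_concat_all_tasks_10k` naming shown above.

```python
# Hypothetical helper, not part of tensor2tensor: builds the --problem name
# described in the README from a bAbi task number (1-20) and subset (1k/10k).
def babi_problem_name(task, subset="1k"):
  assert 1 <= task <= 20, "bAbi tasks are numbered 1 through 20"
  assert subset in ("1k", "10k"), "subsets are 1k or 10k"
  return "babi_qa_concat_task%d_%s" % (task, subset)

print(babi_problem_name(1))          # babi_qa_concat_task1_1k
print(babi_problem_name(17, "10k"))  # babi_qa_concat_task17_10k
```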

tensor2tensor/data_generators/babi_qa.py

Lines changed: 13 additions & 6 deletions
@@ -34,10 +34,9 @@
 import os
 import shutil
 import tarfile
-
+import requests
 import six
 
-from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators import text_problems
@@ -109,8 +108,12 @@ def _prepare_babi_data(tmp_dir, data_dir):
   if not tf.gfile.Exists(data_dir):
     tf.gfile.MakeDirs(data_dir)
 
-  # TODO(dehghani@): find a solution for blocking user-agent (download)
-  file_path = generator_utils.maybe_download(tmp_dir, _TAR, _URL)
+  file_path = os.path.join(tmp_dir, _TAR)
+  headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36'}
+  resp = requests.get(_URL, headers=headers)
+  with open(file_path, 'wb') as f:
+    f.write(resp.content)
+
   tar = tarfile.open(file_path)
   tar.extractall(tmp_dir)
   tar.close()
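
As a standalone illustration of the pattern this hunk adopts (fetching the tarball with an explicit browser-like User-Agent via `requests` instead of `generator_utils.maybe_download`), here is a minimal sketch. The `fetch_and_extract` name and the shortened header string are assumptions for the example, not code from the repository.

```python
import os
import tarfile

import requests


def fetch_and_extract(url, tar_name, tmp_dir):
  """Download a tarball with a browser-like User-Agent and unpack it."""
  file_path = os.path.join(tmp_dir, tar_name)
  # Some hosts reject the default python-requests User-Agent, so send a
  # browser-like one instead (any common browser string works here).
  headers = {"User-Agent": "Mozilla/5.0"}
  resp = requests.get(url, headers=headers)
  resp.raise_for_status()
  with open(file_path, "wb") as f:
    f.write(resp.content)
  with tarfile.open(file_path) as tar:
    tar.extractall(tmp_dir)
  return file_path
```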
@@ -449,8 +452,12 @@ def preprocess_example(self, example, unused_mode, unused_model_hparams):
   def hparams(self, defaults, unused_model_hparams):
     super(BabiQaConcat, self).hparams(defaults, unused_model_hparams)
     p = defaults
-    del p.modality['context']
-    del p.vocab_size['context']
+
+    if 'context' in p.modality:
+      del p.modality['context']
+
+    if 'context' in p.vocab_size:
+      del p.vocab_size['context']
 
 
 def _problems_to_register():
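
The hparams fix above boils down to guarding each `del` so that a missing 'context' key no longer raises. A plain-dict sketch of the same pattern; only the `modality` and `vocab_size` names mirror the diff, nothing else from tensor2tensor is assumed.

```python
def drop_context(modality, vocab_size):
  # An unconditional `del d['context']` raises KeyError when the key was
  # never set, which is what the two original statements did.
  if "context" in modality:
    del modality["context"]
  if "context" in vocab_size:
    del vocab_size["context"]
  # An equivalent one-liner per field: modality.pop("context", None)


drop_context({"inputs": "symbol"}, {"inputs": 32000})  # missing key: no error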
