
Fix bAbi data generator and readme #1235

Merged 5 commits on Nov 21, 2018
11 changes: 11 additions & 0 deletions README.md
@@ -47,6 +47,7 @@ pip install tensor2tensor && t2t-trainer \
### Contents

* [Suggested Datasets and Models](#suggested-datasets-and-models)
* [Story, Question and Answer](#story-question-and-answer)
* [Image Classification](#image-classification)
* [Image Generation](#image-generation)
* [Language Modeling](#language-modeling)
@@ -78,6 +79,16 @@ hyperparameters that we know works well in our setup. We usually
run either on Cloud TPUs or on 8-GPU machines; you might need
to modify the hyperparameters if you run on a different setup.

### Story, Question and Answer

For answering questions based on a story, use

* the [bAbi][1] data-set: `--problem=babi_qa_concat_task1_1k`

You can choose the bAbi task from the range [1, 20] and the subset from 1k or 10k. To combine the test data from all tasks into a single test set, use `--problem=babi_qa_concat_all_tasks_10k`.

[1]: https://research.fb.com/downloads/babi/
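
For illustration only, the problem names follow a simple task/subset pattern; `babi_problem_name` below is a hypothetical helper for building them, not part of tensor2tensor:

```python
# Hypothetical helper, shown only to illustrate the problem-name pattern:
# e.g. babi_problem_name(3, "10k") -> "babi_qa_concat_task3_10k".
def babi_problem_name(task_id, subset="1k"):
  assert 1 <= task_id <= 20, "bAbi defines tasks 1 through 20"
  assert subset in ("1k", "10k"), "subsets come in 1k or 10k sizes"
  return "babi_qa_concat_task%d_%s" % (task_id, subset)
```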

### Image Classification

For image classification, we have a number of standard data-sets:
19 changes: 13 additions & 6 deletions tensor2tensor/data_generators/babi_qa.py
@@ -34,10 +34,9 @@
import os
import shutil
import tarfile

import requests
import six

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
@@ -109,8 +108,12 @@ def _prepare_babi_data(tmp_dir, data_dir):
  if not tf.gfile.Exists(data_dir):
    tf.gfile.MakeDirs(data_dir)

  # TODO(dehghani@): find a solution for blocking user-agent (download)
  file_path = generator_utils.maybe_download(tmp_dir, _TAR, _URL)
  file_path = os.path.join(tmp_dir, _TAR)
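  # The bAbI download host blocks the default Python user agent, so fetch the
  # archive with requests and a browser-style User-Agent header.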
  headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36'}
  resp = requests.get(_URL, headers=headers)
  with open(file_path, 'wb') as f:
    f.write(resp.content)

  tar = tarfile.open(file_path)
  tar.extractall(tmp_dir)
  tar.close()
@@ -449,8 +452,12 @@ def preprocess_example(self, example, unused_mode, unused_model_hparams):
  def hparams(self, defaults, unused_model_hparams):
    super(BabiQaConcat, self).hparams(defaults, unused_model_hparams)
    p = defaults
    del p.modality['context']
    del p.vocab_size['context']

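    # Guard the deletions: 'context' may already be absent from the defaults,
    # and deleting a missing key would raise a KeyError.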
    if 'context' in p.modality:
      del p.modality['context']

    if 'context' in p.vocab_size:
      del p.vocab_size['context']


def _problems_to_register():