|
34 | 34 | import os
|
35 | 35 | import shutil
|
36 | 36 | import tarfile
|
37 |
| - |
| 37 | +import requests |
38 | 38 | import six
|
39 | 39 |
|
40 |
| -from tensor2tensor.data_generators import generator_utils |
41 | 40 | from tensor2tensor.data_generators import problem
|
42 | 41 | from tensor2tensor.data_generators import text_encoder
|
43 | 42 | from tensor2tensor.data_generators import text_problems
|
@@ -109,8 +108,12 @@ def _prepare_babi_data(tmp_dir, data_dir):
|
109 | 108 | if not tf.gfile.Exists(data_dir):
|
110 | 109 | tf.gfile.MakeDirs(data_dir)
|
111 | 110 |
|
112 |
| - # TODO(dehghani@): find a solution for blocking user-agent (download) |
113 |
| - file_path = generator_utils.maybe_download(tmp_dir, _TAR, _URL) |
| 111 | + file_path = os.path.join(tmp_dir, _TAR) |
| 112 | + headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36'} |
| 113 | + resp = requests.get(_URL, headers=headers) |
| 114 | + with open(file_path, 'wb') as f: |
| 115 | + f.write(resp.content) |
| 116 | + |
114 | 117 | tar = tarfile.open(file_path)
|
115 | 118 | tar.extractall(tmp_dir)
|
116 | 119 | tar.close()
|
@@ -449,8 +452,12 @@ def preprocess_example(self, example, unused_mode, unused_model_hparams):
|
449 | 452 | def hparams(self, defaults, unused_model_hparams):
|
450 | 453 | super(BabiQaConcat, self).hparams(defaults, unused_model_hparams)
|
451 | 454 | p = defaults
|
452 |
| - del p.modality['context'] |
453 |
| - del p.vocab_size['context'] |
| 455 | + |
| 456 | + if 'context' in p.modality: |
| 457 | + del p.modality['context'] |
| 458 | + |
| 459 | + if 'context' in p.vocab_size: |
| 460 | + del p.vocab_size['context'] |
454 | 461 |
|
455 | 462 |
|
456 | 463 | def _problems_to_register():
|
|
0 commit comments