Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8f3a7fd

Browse files
authored
Merge pull request #228 from rsepassi/push
v1.1.9
2 parents 45a787e + f5d5405 commit 8f3a7fd

33 files changed

+1464
-320
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.1.8',
8+
version='1.1.9',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='no-reply@google.com',

tensor2tensor/bin/t2t-datagen

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ flags.DEFINE_integer("num_shards", 0, "How many shards to use. Ignored for "
6666
"registered Problems.")
6767
flags.DEFINE_integer("max_cases", 0,
6868
"Maximum number of cases to generate (unbounded if 0).")
69+
flags.DEFINE_bool("only_list", False,
70+
"If true, we only list the problems that will be generated.")
6971
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
7072
flags.DEFINE_integer("task_id", -1, "For distributed data generation.")
7173
flags.DEFINE_string("t2t_usr_dir", "",
@@ -81,33 +83,33 @@ _SUPPORTED_PROBLEM_GENERATORS = {
8183
"algorithmic_algebra_inverse": (
8284
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
8385
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
84-
"wmt_parsing_tokens_8k": (
86+
"parsing_english_ptb8k": (
8587
lambda: wmt.parsing_token_generator(
8688
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
8789
lambda: wmt.parsing_token_generator(
8890
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
89-
"wsj_parsing_tokens_16k": (
91+
"parsing_english_ptb16k": (
9092
lambda: wsj_parsing.parsing_token_generator(
9193
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
9294
lambda: wsj_parsing.parsing_token_generator(
9395
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)),
94-
"wmt_ende_bpe32k": (
96+
"translate_ende_wmt_bpe32k": (
9597
lambda: wmt.ende_bpe_token_generator(
9698
FLAGS.data_dir, FLAGS.tmp_dir, True),
9799
lambda: wmt.ende_bpe_token_generator(
98100
FLAGS.data_dir, FLAGS.tmp_dir, False)),
99-
"lm1b_32k": (
101+
"languagemodel_1b32k": (
100102
lambda: lm1b.generator(FLAGS.tmp_dir, True),
101103
lambda: lm1b.generator(FLAGS.tmp_dir, False)
102104
),
103-
"lm1b_characters": (
105+
"languagemodel_1b_characters": (
104106
lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
105107
lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
106108
),
107109
"image_celeba_tune": (
108110
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
109111
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
110-
"snli_32k": (
112+
"inference_snli32k": (
111113
lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
112114
lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
113115
),
@@ -181,7 +183,11 @@ def main(_):
181183
"Data will be written to default data_dir=%s.",
182184
FLAGS.data_dir)
183185

184-
tf.logging.info("Generating problems:\n * %s\n" % "\n * ".join(problems))
186+
tf.logging.info("Generating problems:\n%s"
187+
% registry.display_list_by_prefix(problems,
188+
starting_spaces=4))
189+
if FLAGS.only_list:
190+
return
185191
for problem in problems:
186192
set_random_seed()
187193

@@ -210,7 +216,7 @@ def generate_data_for_problem(problem):
210216

211217

212218
def generate_data_for_registered_problem(problem_name):
213-
tf.logging.info("Generating training data for %s.", problem_name)
219+
tf.logging.info("Generating data for %s.", problem_name)
214220
if FLAGS.num_shards:
215221
raise ValueError("--num_shards should not be set for registered Problem.")
216222
problem = registry.problem(problem_name)

tensor2tensor/data_generators/cipher.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
@registry.register_problem
32-
class CipherShift5(algorithmic.AlgorithmicProblem):
32+
class AlgorithmicCipherShift5(algorithmic.AlgorithmicProblem):
3333
"""Shift cipher."""
3434

3535
@property
@@ -62,7 +62,7 @@ def dev_length(self):
6262

6363

6464
@registry.register_problem
65-
class CipherVigenere5(algorithmic.AlgorithmicProblem):
65+
class AlgorithmicCipherVigenere5(algorithmic.AlgorithmicProblem):
6666
"""Vinegre cipher."""
6767

6868
@property
@@ -95,7 +95,7 @@ def dev_length(self):
9595

9696

9797
@registry.register_problem
98-
class CipherShift200(CipherShift5):
98+
class AlgorithmicCipherShift200(AlgorithmicCipherShift5):
9999
"""Shift cipher."""
100100

101101
@property
@@ -110,7 +110,7 @@ def distribution(self):
110110

111111

112112
@registry.register_problem
113-
class CipherVigenere200(CipherVigenere5):
113+
class AlgorithmicCipherVigenere200(AlgorithmicCipherVigenere5):
114114
"""Vinegre cipher."""
115115

116116
@property

tensor2tensor/data_generators/desc2code.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def generator_target():
209209
}
210210

211211

212-
@registry.register_problem("desc2code_py")
213-
class Desc2CodePyProblem(Desc2CodeProblem):
212+
@registry.register_problem
213+
class ProgrammingDesc2codePy(Desc2CodeProblem):
214214
"""Description2Code for python problem."""
215215

216216
@property
@@ -222,8 +222,8 @@ def preprocess_target(self, target):
222222
return target.replace("\t", " ")
223223

224224

225-
@registry.register_problem("desc2code_cpp")
226-
class Desc2CodeCppProblem(Desc2CodeProblem):
225+
@registry.register_problem
226+
class ProgrammingDesc2codeCpp(Desc2CodeProblem):
227227
"""Description2Code for C++ problem."""
228228

229229
@property

tensor2tensor/data_generators/desc2code_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Desc2codeTest(tf.test.TestCase):
4747

4848
def testCppPreprocess(self):
4949
"""Check that the file correctly preprocess the code source."""
50-
cpp_pb = desc2code.Desc2CodeCppProblem()
50+
cpp_pb = desc2code.ProgrammingDesc2codeCpp()
5151

5252
self.assertEqual( # Add space between two lines
5353
cpp_pb.preprocess_target("firstline//comm1\nsecondline//comm2\n"),

tensor2tensor/data_generators/gene_expression.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ def eval_metrics(self):
176176
return [metrics.Metrics.LOG_POISSON, metrics.Metrics.R2]
177177

178178

179-
@registry.register_problem("gene_expression_cage10")
180-
class GeneExpressionCAGE10(GeneExpressionProblem):
179+
@registry.register_problem
180+
class GenomicsExpressionCage10(GeneExpressionProblem):
181181

182182
@property
183183
def download_url(self):
@@ -188,8 +188,8 @@ def h5_file(self):
188188
return "cage10.h5"
189189

190190

191-
@registry.register_problem("gene_expression_gm12878")
192-
class GeneExpressionGM12878(GeneExpressionProblem):
191+
@registry.register_problem
192+
class GenomicsExpressionGm12878(GeneExpressionProblem):
193193

194194
@property
195195
def download_url(self):
@@ -200,8 +200,8 @@ def h5_file(self):
200200
return "gm12878.h5"
201201

202202

203-
@registry.register_problem("gene_expression_l262k")
204-
class GeneExpressionL262k(GeneExpressionProblem):
203+
@registry.register_problem
204+
class GenomicsExpressionL262k(GeneExpressionProblem):
205205

206206
@property
207207
def h5_file(self):

tensor2tensor/data_generators/ice_parsing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def tabbed_parsing_character_generator(tmp_dir, train):
6262
return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
6363

6464

65-
@registry.register_problem("ice_parsing_tokens")
66-
class IceParsingTokens(problem.Problem):
65+
@registry.register_problem
66+
class ParsingIcelandic16k(problem.Problem):
6767
"""Problem spec for parsing tokenized Icelandic text to constituency trees."""
6868

6969
@property

tensor2tensor/data_generators/image.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,21 @@ def dataset_filename(self):
214214
def is_small(self):
215215
return True # Modalities like for CIFAR.
216216

217-
def preprocess_examples(self, examples, mode):
218-
examples = imagenet_preprocess_examples(examples, mode)
219-
examples["inputs"] = tf.to_int64(
220-
tf.image.resize_images(examples["inputs"], [32, 32]))
217+
@property
218+
def num_classes(self):
219+
return 1000
220+
221+
def preprocess_examples(self, examples, mode, hparams):
222+
# Just resize with area.
223+
if self._was_reversed:
224+
examples["inputs"] = tf.to_int64(
225+
tf.image.resize_images(examples["inputs"], [32, 32],
226+
tf.image.ResizeMethod.AREA))
227+
else:
228+
examples = imagenet_preprocess_examples(examples, mode)
229+
examples["inputs"] = tf.to_int64(
230+
tf.image.resize_images(examples["inputs"], [32, 32]))
231+
return examples
221232

222233

223234
def image_generator(images, labels):

tensor2tensor/data_generators/problem.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def preprocess_examples_common(examples, hparams):
9898
examples["inputs"] = examples["inputs"][:hparams.max_input_seq_length]
9999
if hparams.max_target_seq_length > 0:
100100
examples["targets"] = examples["targets"][:hparams.max_target_seq_length]
101-
if hparams.prepend_inputs_to_targets:
101+
if hparams.prepend_mode != "none":
102102
examples["targets"] = tf.concat(
103103
[examples["inputs"], [0], examples["targets"]], 0)
104104
return examples
@@ -410,11 +410,12 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
410410
generator_utils.generate_files(
411411
self.generator(data_dir, tmp_dir, True), all_paths)
412412
generator_utils.shuffle_dataset(all_paths)
413-
generator_utils.generate_dataset_and_shuffle(
414-
self.generator(data_dir, tmp_dir, True),
415-
self.training_filepaths(data_dir, self.num_shards, shuffled=False),
416-
self.generator(data_dir, tmp_dir, False),
417-
self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False))
413+
else:
414+
generator_utils.generate_dataset_and_shuffle(
415+
self.generator(data_dir, tmp_dir, True),
416+
self.training_filepaths(data_dir, self.num_shards, shuffled=False),
417+
self.generator(data_dir, tmp_dir, False),
418+
self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False))
418419

419420
def feature_encoders(self, data_dir):
420421
if self.is_character_level:

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,16 +492,16 @@ def image_celeba(unused_model_hparams):
492492
lambda p: audio_wsj_tokens(p, 2**13),
493493
"audio_wsj_tokens_8k_test":
494494
lambda p: audio_wsj_tokens(p, 2**13),
495-
"lm1b_characters":
495+
"languagemodel_1b_characters":
496496
lm1b_characters,
497-
"lm1b_32k":
497+
"languagemodel_1b32k":
498498
lm1b_32k,
499-
"wmt_parsing_tokens_8k":
499+
"parsing_english_ptb8k":
500500
lambda p: wmt_parsing_tokens(p, 2**13),
501-
"wsj_parsing_tokens_16k":
501+
"parsing_english_ptb16k":
502502
lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda
503503
p, "wsj", 2**14, 2**9),
504-
"wmt_ende_bpe32k":
504+
"translate_ende_wmt_bpe32k":
505505
wmt_ende_bpe32k,
506506
"image_celeba_tune":
507507
image_celeba,

tensor2tensor/data_generators/ptb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ def _generator(self, filename, encoder):
157157
yield {"inputs": [0], "targets": tok}
158158

159159

160-
@registry.register_problem("lm_ptb_10k")
161-
class LmPtb10k(PTBProblem):
160+
@registry.register_problem
161+
class LanguagemodelPtb10k(PTBProblem):
162162
"""A class for generating PTB data, 10k vocab."""
163163

164164
@property
@@ -167,7 +167,7 @@ def is_character_level(self):
167167

168168

169169
@registry.register_problem
170-
class LmPtbCharacters(PTBProblem):
170+
class LanguagemodelPtbCharacters(PTBProblem):
171171
"""A class for generating PTB data, character-level."""
172172

173173
@property

tensor2tensor/data_generators/wiki.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def _page_title(page):
8181

8282

8383
@registry.register_problem
84-
class Wiki32k(problem.Text2TextProblem):
85-
"""A class for generating PTB data."""
84+
class LanguagemodelWikiFull32k(problem.Text2TextProblem):
85+
"""A language model on full English Wikipedia."""
8686

8787
@property
8888
def is_character_level(self):
@@ -129,3 +129,12 @@ def generator(self, data_dir, tmp_dir, _):
129129
encoded = encoder.encode(page) + [EOS]
130130
encoded_title = encoder.encode(title) + [EOS]
131131
yield {"inputs": encoded_title, "targets": encoded}
132+
133+
134+
@registry.register_problem
135+
class LanguagemodelWikiFull8k(problem.Text2TextProblem):
136+
"""A language model on full English Wikipedia."""
137+
138+
@property
139+
def targeted_vocab_size(self):
140+
return 2**13 # 8192

0 commit comments

Comments
 (0)