This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 75b75f2

Merge pull request #411 from rsepassi/push

v1.2.8

2 parents: 8594b4c + 8d191e4


45 files changed: +2373 -909 lines

.travis.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -24,6 +24,6 @@ script:
 - mkdir $T2T_TRAIN_DIR
 - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR
 - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR
-- t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10,use_last_position_only=True'
+- t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10'
 git:
   depth: 3
```

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -124,7 +124,7 @@ t2t-decoder \
   --model=$MODEL \
   --hparams_set=$HPARAMS \
   --output_dir=$TRAIN_DIR \
-  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA,use_last_position_only=True" \
+  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
   --decode_from_file=$DECODE_FILE

 cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
```

docs/cloud_tpu.md

Lines changed: 99 additions & 0 deletions (new file)

# Running on Cloud TPUs

Tensor2Tensor supports running on Google Cloud Platform's TPUs, chips specialized for ML training.

Not all models are supported yet; so far we've tested the Transformer (sequence model) and Xception (image model).

To run on TPUs, you need to be part of the alpha program; if you're not, these commands won't work for you currently, but access will expand soon, so get excited for your future ML supercomputers in the cloud.

## Tutorial: Transformer En-De translation on TPU

Set your default zone to a TPU-enabled zone. TPU machines are only available in certain zones for now.
```
gcloud config set compute/zone us-central1-f
```

Launch a GCE instance; this will run the Python trainer.
```
gcloud compute instances create $USER-vm \
  --machine-type=n1-standard-8 \
  --image-family=tf-nightly \
  --image-project=ml-images \
  --scopes=https://www.googleapis.com/auth/cloud-platform
```

Launch the TPU instance; the Python program will connect to this to train on the TPU device.
```
TPU_IP=10.240.0.2
gcloud alpha compute tpus create \
  $USER-tpu \
  --range=${TPU_IP/%2/0}/29 \
  --version=nightly
```

To see all TPU instances running: `gcloud alpha compute tpus list`. The `TPU_IP` should be unique among those listed and follow the format `10.240.i.2`.
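The `--range` flag above derives the TPU's network block from `TPU_IP` with bash parameter expansion, replacing the trailing `2` with `0`. A quick sanity check of that expansion (illustrative only):
```
TPU_IP=10.240.0.2
echo ${TPU_IP/%2/0}/29  # prints 10.240.0.0/29, the block passed to --range
```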
Generate data to GCS. If you already have the data locally, use `gsutil cp` to copy it to GCS.
```
DATA_DIR=gs://my-bucket/t2t/data/
t2t-datagen --problem=translate_ende_wmt8k --data_dir=$DATA_DIR
```

SSH in with port forwarding for TensorBoard:
```
gcloud compute ssh $USER-vm -L 6006:localhost:6006
```

Now that you're on the cloud instance, install T2T:
```
pip install tensor2tensor
```

Set up some variables used below. `TPU_IP` and `DATA_DIR` should be the same as what was used above. Note that `DATA_DIR` and `OUT_DIR` must be GCS buckets.
```
TPU_IP=<IP of TPU machine>
DATA_DIR=gs://my-bucket/t2t/data/
OUT_DIR=gs://my-bucket/t2t/training/
TPU_MASTER=grpc://$TPU_IP:8470
```

Launch TensorBoard in the background so you can monitor training:
```
tensorboard --logdir=$OUT_DIR > /tmp/tensorboard_logs.txt 2>&1 &
```

Train and evaluate.
```
t2t-tpu-trainer \
  --master=$TPU_MASTER \
  --data_dir=$DATA_DIR \
  --output_dir=$OUT_DIR \
  --problems=translate_ende_wmt8k \
  --model=transformer \
  --hparams_set=transformer_tiny_tpu \
  --train_steps=10 \
  --eval_steps=10 \
  --local_eval_frequency=10 \
  --iterations_per_loop=10
```

The above command will train for 10 steps, then evaluate for 10 steps. You can (and should) increase the number of total training steps with the `--train_steps` flag. Evaluation will happen every `--local_eval_frequency` steps, each time for `--eval_steps`. When you increase the number of training steps, also increase `--iterations_per_loop`, which controls how frequently the TPU machine returns control to the Python code (1000 seems like a fine number).
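For example, a longer run might use flags like these (the step counts are illustrative, not tuned values):
```
t2t-tpu-trainer \
  --master=$TPU_MASTER \
  --data_dir=$DATA_DIR \
  --output_dir=$OUT_DIR \
  --problems=translate_ende_wmt8k \
  --model=transformer \
  --hparams_set=transformer_tiny_tpu \
  --train_steps=100000 \
  --eval_steps=10 \
  --local_eval_frequency=1000 \
  --iterations_per_loop=1000
```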
Back on your local machine, open your browser and navigate to `localhost:6006` for TensorBoard.

Voila. Enjoy your new supercomputer.

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.2.7',
+    version='1.2.8',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',
```

tensor2tensor/data_generators/image.py

Lines changed: 45 additions & 9 deletions
```diff
@@ -51,15 +51,17 @@ def resize_by_area(img, size):

 class ImageProblem(problem.Problem):

-  def example_reading_spec(self, label_key=None):
-    if label_key is None:
-      label_key = "image/class/label"
+  def example_reading_spec(self, label_repr=None):
+    if label_repr is None:
+      label_repr = ("image/class/label", tf.FixedLenFeature((1,), tf.int64))

     data_fields = {
         "image/encoded": tf.FixedLenFeature((), tf.string),
         "image/format": tf.FixedLenFeature((), tf.string),
-        label_key: tf.VarLenFeature(tf.int64)
     }
+    label_key, label_type = label_repr  # pylint: disable=unpacking-non-sequence
+    data_fields[label_key] = label_type
+
     data_items_to_decoders = {
         "inputs":
             tf.contrib.slim.tfexample_decoder.Image(
@@ -244,8 +246,9 @@ def hparams(self, defaults, unused_model_hparams):

   def example_reading_spec(self):
     label_key = "image/unpadded_label"
+    label_type = tf.VarLenFeature(tf.int64)
     return super(ImageFSNS, self).example_reading_spec(
-        self, label_key=label_key)
+        label_repr=(label_key, label_type))


 class Image2ClassProblem(ImageProblem):
```
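The refactored `example_reading_spec` takes the label key and its feature type together as a `label_repr` tuple, so a subclass can choose fixed- or variable-length labels under any key. A minimal sketch of a hypothetical subclass using the new argument (the class and feature-key names are illustrative, not part of the library):

```python
import tensorflow as tf


class MyMultiLabelImageProblem(ImageProblem):
  """Hypothetical subclass: variable-length labels under a custom key."""

  def example_reading_spec(self):
    # Bundle the feature key with its feature type, as the new API expects.
    label_repr = ("image/my_labels", tf.VarLenFeature(tf.int64))
    return super(MyMultiLabelImageProblem, self).example_reading_spec(
        label_repr=label_repr)
```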
```diff
@@ -283,10 +286,8 @@ def generator(self, data_dir, tmp_dir, is_training):

   def hparams(self, defaults, unused_model_hparams):
     p = defaults
-    small_modality = "%s:small_image_modality" % registry.Modalities.IMAGE
-    modality = small_modality if self.is_small else registry.Modalities.IMAGE
-    p.input_modality = {"inputs": (modality, None)}
-    p.target_modality = ("%s:2d" % registry.Modalities.CLASS_LABEL,
+    p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
+    p.target_modality = (registry.Modalities.CLASS_LABEL,
                          self.num_classes)
     p.batch_size_multiplier = 4 if self.is_small else 256
     p.max_expected_batch_size_per_shard = 8 if self.is_small else 2
```
```diff
@@ -382,6 +383,38 @@ def preprocess_example(self, example, mode, unused_hparams):
     return example


+@registry.register_problem
+class ImageImagenet64(Image2ClassProblem):
+  """Imagenet rescaled to 64x64."""
+
+  def dataset_filename(self):
+    return "image_imagenet"  # Reuse Imagenet data.
+
+  @property
+  def is_small(self):
+    return True  # Modalities like for CIFAR.
+
+  @property
+  def num_classes(self):
+    return 1000
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    # TODO(lukaszkaiser): find a better way than printing this.
+    print("To generate the ImageNet dataset in the proper format, follow "
+          "instructions at https://github.com/tensorflow/models/blob/master"
+          "/inception/README.md#getting-started")
+
+  def preprocess_example(self, example, mode, unused_hparams):
+    inputs = example["inputs"]
+    # Just resize with area.
+    if self._was_reversed:
+      example["inputs"] = resize_by_area(inputs, 64)
+    else:
+      example = imagenet_preprocess_example(example, mode)
+      example["inputs"] = resize_by_area(inputs, 64)
+    return example
+
+
 @registry.register_problem
 class Img2imgImagenet(ImageProblem):
   """Imagenet rescaled to 8x8 for input and 32x32 for output."""
```
```diff
@@ -623,9 +656,11 @@ def class_labels(self):
     ]

   def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"].set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
     if mode == tf.estimator.ModeKeys.TRAIN:
       example["inputs"] = common_layers.cifar_image_augmentation(
           example["inputs"])
+    example["inputs"] = tf.to_int64(example["inputs"])
     return example

   def generator(self, data_dir, tmp_dir, is_training):
@@ -649,6 +684,7 @@ def generator(self, data_dir, tmp_dir, is_training):
 class ImageCifar10Plain(ImageCifar10):

   def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"].set_shape([_CIFAR10_IMAGE_SIZE, _CIFAR10_IMAGE_SIZE, 3])
     example["inputs"] = tf.to_int64(example["inputs"])
     return example
```

tensor2tensor/data_generators/problem.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -536,7 +536,7 @@ def _default_hparams():
       # During inference for autoregressive problems, if the batch_size is 1,
       # the inference will stop when the model predict a text_encoder.EOS_ID
       # token.
-      stop_at_eos=int(False),
+      stop_at_eos=False,

       # Modalities used to map from input features to a space compatible with
       # chosen model architecture. One modality spec (which is a 2-tuple,
```

tensor2tensor/data_generators/translate_enfr.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -151,6 +151,10 @@ class TranslateEnfrWmtSmallCharacters(translate.TranslateProblem):
   def is_character_level(self):
     return True

+  @property
+  def use_small_dataset(self):
+    return True
+
   @property
   def vocab_name(self):
     return "vocab.enfr"
```
