Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 121ee60

Browse files
MechCodercopybara-github
authored andcommitted
Adds MovingMnist to the T2T dataset generators using utilites from tensorflow-datasets wherever appropriate.
PiperOrigin-RevId: 250750540
1 parent ed518a1 commit 121ee60

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"tensor2tensor.data_generators.lm1b_imdb",
5555
"tensor2tensor.data_generators.lm1b_mnli",
5656
"tensor2tensor.data_generators.mnist",
57+
"tensor2tensor.data_generators.moving_mnist",
5758
"tensor2tensor.data_generators.mrpc",
5859
"tensor2tensor.data_generators.mscoco",
5960
"tensor2tensor.data_generators.multinli",
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# coding=utf-8
2+
# Copyright 2019 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Moving MNIST dataset.
17+
18+
Unsupervised Learning of Video Representations using LSTMs
19+
Nitish Srivastava, Elman Mansimov, Ruslan Salakhutdinov
20+
https://arxiv.org/abs/1502.04681
21+
22+
"""
23+
24+
from __future__ import absolute_import
25+
from __future__ import division
26+
from __future__ import print_function
27+
28+
import os
29+
import numpy as np
30+
31+
from tensor2tensor.data_generators import generator_utils
32+
from tensor2tensor.data_generators import problem
33+
from tensor2tensor.data_generators import video_utils
34+
from tensor2tensor.layers import modalities
35+
from tensor2tensor.utils import registry
36+
37+
import tensorflow as tf
38+
import tensorflow_datasets as tfds
39+
from tensorflow_datasets.video import moving_sequence
40+
41+
42+
DATA_URL = (
43+
"http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy")
44+
SPLIT_TO_SIZE = {
45+
problem.DatasetSplit.TRAIN: 100000,
46+
problem.DatasetSplit.EVAL: 10000,
47+
problem.DatasetSplit.TEST: 10000}
48+
49+
50+
@registry.register_problem
51+
class VideoMovingMnist(video_utils.VideoProblem):
52+
"""MovingMnist Dataset."""
53+
54+
@property
55+
def num_channels(self):
56+
return 1
57+
58+
@property
59+
def frame_height(self):
60+
return 64
61+
62+
@property
63+
def frame_width(self):
64+
return 64
65+
66+
@property
67+
def is_generate_per_split(self):
68+
return True
69+
70+
# num_videos * num_frames
71+
@property
72+
def total_number_of_frames(self):
73+
return 100000 * 20
74+
75+
def max_frames_per_video(self, hparams):
76+
return 20
77+
78+
@property
79+
def random_skip(self):
80+
return False
81+
82+
def eval_metrics(self):
83+
return []
84+
85+
@property
86+
def dataset_splits(self):
87+
"""Splits of data to produce and number of output shards for each."""
88+
return [
89+
{"split": problem.DatasetSplit.TRAIN, "shards": 10},
90+
{"split": problem.DatasetSplit.EVAL, "shards": 1},
91+
{"split": problem.DatasetSplit.TEST, "shards": 1}]
92+
93+
@property
94+
def extra_reading_spec(self):
95+
"""Additional data fields to store on disk and their decoders."""
96+
data_fields = {
97+
"frame_number": tf.FixedLenFeature([1], tf.int64),
98+
}
99+
decoders = {
100+
"frame_number": tf.contrib.slim.tfexample_decoder.Tensor(
101+
tensor_key="frame_number"),
102+
}
103+
return data_fields, decoders
104+
105+
def hparams(self, defaults, unused_model_hparams):
106+
p = defaults
107+
p.modality = {"inputs": modalities.ModalityType.VIDEO,
108+
"targets": modalities.ModalityType.VIDEO}
109+
p.vocab_size = {"inputs": 256,
110+
"targets": 256}
111+
112+
def get_test_iterator(self, tmp_dir):
113+
path = generator_utils.maybe_download(
114+
tmp_dir, os.path.basename(DATA_URL), DATA_URL)
115+
with tf.io.gfile.GFile(path, "rb") as fp:
116+
mnist_test = np.load(fp)
117+
mnist_test = np.transpose(mnist_test, (1, 0, 2, 3))
118+
mnist_test = np.expand_dims(mnist_test, axis=-1)
119+
mnist_test = tf.data.Dataset.from_tensor_slices(mnist_test)
120+
return mnist_test.make_initializable_iterator()
121+
122+
def map_fn(self, image, label):
123+
sequence = moving_sequence.image_as_moving_sequence(
124+
image, sequence_length=20)
125+
return sequence.image_sequence
126+
127+
def get_train_iterator(self):
128+
mnist_ds = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
129+
mnist_ds = mnist_ds.repeat()
130+
moving_mnist_ds = mnist_ds.map(self.map_fn).batch(2)
131+
moving_mnist_ds = moving_mnist_ds.map(lambda x: tf.reduce_max(x, axis=0))
132+
return moving_mnist_ds.make_initializable_iterator()
133+
134+
def generate_samples(self, data_dir, tmp_dir, dataset_split):
135+
with tf.Graph().as_default():
136+
# train and eval set are generated on-the-fly.
137+
# test set is the official test-set.
138+
if dataset_split == problem.DatasetSplit.TEST:
139+
moving_ds = self.get_test_iterator(tmp_dir)
140+
else:
141+
moving_ds = self.get_train_iterator()
142+
143+
next_video = moving_ds.get_next()
144+
with tf.Session() as sess:
145+
sess.run(moving_ds.initializer)
146+
147+
n_samples = SPLIT_TO_SIZE[dataset_split]
148+
for _ in range(n_samples):
149+
next_video_np = sess.run(next_video)
150+
for frame_number, frame in enumerate(next_video_np):
151+
yield {
152+
"frame_number": [frame_number],
153+
"frame": frame,
154+
}

0 commit comments

Comments
 (0)