Skip to content

Commit

Permalink
Fix Compute statistics for graph mode.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 223107937
  • Loading branch information
afrozenator authored and Copybara-Service committed Nov 28, 2018
1 parent eae0f1e commit 10afbf9
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 38 deletions.
49 changes: 28 additions & 21 deletions tensorflow_datasets/core/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import tensorflow as tf

from tensorflow_datasets.core import api_utils
from tensorflow_datasets.core import dataset_utils
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.proto import dataset_info_pb2
from google.protobuf import json_format
Expand Down Expand Up @@ -249,19 +250,17 @@ def read_from_directory(self, dataset_info_dir):
#

# Maps a feature dtype to the schema_pb2 feature type written into the schema.
# The lookup that consumes this map falls back to schema_pb2.BYTES for any
# dtype that is not listed here.
# NOTE(review): this listing carries both tf.* and np.* keys and repeats
# tf.uint8 / tf.int64 -- presumably removed and added diff lines are shown
# together; confirm against the real file before relying on the tf.* entries.
_FEATURE_TYPE_MAP = {
tf.uint8: schema_pb2.INT,
tf.int64: schema_pb2.INT,
tf.float32: schema_pb2.FLOAT,
tf.float16: schema_pb2.FLOAT,
tf.float64: schema_pb2.FLOAT,
tf.int8: schema_pb2.INT,
tf.int16: schema_pb2.INT,
tf.int32: schema_pb2.INT,
tf.int64: schema_pb2.INT,
tf.uint8: schema_pb2.INT,
tf.uint16: schema_pb2.INT,
tf.uint32: schema_pb2.INT,
tf.uint64: schema_pb2.INT,
np.float16: schema_pb2.FLOAT,
np.float32: schema_pb2.FLOAT,
np.float64: schema_pb2.FLOAT,
np.int8: schema_pb2.INT,
np.int16: schema_pb2.INT,
np.int32: schema_pb2.INT,
np.int64: schema_pb2.INT,
np.uint8: schema_pb2.INT,
np.uint16: schema_pb2.INT,
np.uint32: schema_pb2.INT,
np.uint64: schema_pb2.INT,
}

_SCHEMA_TYPE_MAP = {
Expand Down Expand Up @@ -293,7 +292,7 @@ def get_dataset_feature_statistics(builder, split):
feature_to_min = {}
feature_to_max = {}

for example in dataset:
for example in dataset_utils.iterate_over_dataset(dataset):
statistics.num_examples += 1

assert isinstance(example, dict)
Expand All @@ -304,14 +303,22 @@ def get_dataset_feature_statistics(builder, split):
# Update the number of examples this feature appears in.
feature_to_num_examples[feature_name] += 1

feature_shape = example[feature_name].shape
feature_dtype = example[feature_name].dtype
feature_np = example[feature_name].numpy()
feature_np = example[feature_name]

# To stay compatible with both graph and eager mode, do not assume the value
# is a numpy ndarray here: it may also be a plain Python scalar (POD).

# TODO(afrozm): Use tf.data.Dataset.{output_types, output_shapes} here.
feature_shape = ()
feature_dtype = type(feature_np)

if isinstance(feature_np, np.ndarray):
feature_shape = feature_np.shape
feature_dtype = feature_np.dtype.type

feature_min, feature_max = None, None
is_numeric = (
feature_dtype.is_floating or feature_dtype.is_integer or
feature_dtype.is_bool)
is_numeric = (np.issubdtype(feature_dtype, np.number) or
feature_dtype == np.bool_)
if is_numeric:
feature_min = np.min(feature_np)
feature_max = np.max(feature_np)
Expand Down Expand Up @@ -354,7 +361,7 @@ def get_dataset_feature_statistics(builder, split):

# TODO(afrozm): What do we do for non fixed size shapes?
# What to do for scalars?
for dim in feature_to_shape[feature_name].as_list():
for dim in feature_to_shape[feature_name]:
feature.shape.dim.add().size = dim
feature_type = feature_to_dtype[feature_name]
feature.type = _FEATURE_TYPE_MAP.get(feature_type, schema_pb2.BYTES)
Expand Down
39 changes: 39 additions & 0 deletions tensorflow_datasets/core/dataset_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,39 @@
import os
import tensorflow as tf

from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import features
from tensorflow_datasets.core import splits
from tensorflow_datasets.core import test_utils

pkg_dir, _ = os.path.split(__file__)
_TESTDATA = os.path.join(pkg_dir, "test_data")
_NON_EXISTENT_DIR = os.path.join(pkg_dir, "non_existent_dir")


class DummyDatasetSharedGenerator(dataset_builder.GeneratorBasedDatasetBuilder):
  """Toy builder whose single sample generator backs both TRAIN and TEST."""

  def _split_generators(self, dl_manager):
    """Returns one generator shared by TRAIN (2 shards) and TEST (1 shard)."""
    del dl_manager  # Unused by this dummy builder.
    # The 30 generated examples are spread over 2 train shards and 1 test
    # shard.
    shared_generator = splits.SplitGenerator(
        name=[splits.Split.TRAIN, splits.Split.TEST],
        num_shards=[2, 1],
    )
    return [shared_generator]

  def _info(self):
    """Describes a single int64 feature "x", used as both input and target."""
    feature_spec = features.FeaturesDict({"x": tf.int64})
    return dataset_info.DatasetInfo(
        features=feature_spec,
        supervised_keys=("x", "x"),
    )

  def _generate_samples(self):
    """Yields the encoded samples for x = 0, 1, ..., 29."""
    for value in range(30):
      raw_sample = {"x": value}
      yield self.info.features.encode_sample(raw_sample)


class DatasetInfoTest(tf.test.TestCase):

def test_undefined_dir(self):
Expand Down Expand Up @@ -89,6 +114,20 @@ def test_writing(self):
# Assert what was read and then written and read again is the same.
self.assertEqual(existing_json, new_json)

@tf.contrib.eager.run_test_in_graph_and_eager_modes
def test_statistics_generation(self):
  """Checks the example counts produced when preparing with compute_stats."""
  with test_utils.tmp_dir(self.get_temp_dir()) as tmp_dir:
    builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
    builder.download_and_prepare(compute_stats=True)

    # Whole dataset.
    self.assertEqual(30, builder.info.num_examples)

    # Individual splits.
    test_proto = builder.info.splits["test"].get_proto()
    train_proto = builder.info.splits["train"].get_proto()
    self.assertEqual(10, test_proto.statistics.num_examples)
    self.assertEqual(20, train_proto.statistics.num_examples)

if __name__ == "__main__":
tf.test.main()
34 changes: 17 additions & 17 deletions tensorflow_datasets/core/test_data/dataset_info.json
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
{
"supervisedKeys": {
"input": "image",
"output": "label"
},
"citation": "Y. Lecun and C. Cortes, \"The MNIST database of handwritten digits,\" 1998.\n[Online]. Available: http://yann.lecun.com/exdb/mnist/",
"description": "The MNIST database of handwritten digits, has a training set of 60,000 examples, and a test set of 10,000 examples.",
"splits": [
{
"numShards": "1",
"name": "test",
"statistics": {
"numExamples": "10000",
"features": [
{
"name": "image",
Expand All @@ -29,14 +24,14 @@
"max": 9.0
}
}
],
"numExamples": "10000"
]
}
},
{
"numShards": "10",
"name": "train",
"statistics": {
"numExamples": "60000",
"features": [
{
"name": "image",
Expand All @@ -56,15 +51,16 @@
"max": 9.0
}
}
],
"numExamples": "60000"
]
}
}
],
"name": "mnist",
"sizeInBytes": "11534336",
"schema": {
"feature": [
{
"type": "INT",
"name": "image",
"shape": {
"dim": [
{
Expand All @@ -77,13 +73,11 @@
"size": "1"
}
]
},
"name": "image",
"type": "INT"
}
},
{
"name": "label",
"type": "INT"
"type": "INT",
"name": "label"
}
]
},
Expand All @@ -92,5 +86,11 @@
"http://yann.lecun.com/exdb/mnist/"
]
},
"sizeInBytes": "11534336"
"description": "The MNIST database of handwritten digits, has a training set of 60,000 examples, and a test set of 10,000 examples.",
"supervisedKeys": {
"output": "label",
"input": "image"
},
"name": "mnist",
"citation": "Y. Lecun and C. Cortes, \"The MNIST database of handwritten digits,\" 1998.\n[Online]. Available: http://yann.lecun.com/exdb/mnist/"
}

0 comments on commit 10afbf9

Please sign in to comment.