Skip to content

Commit

Permalink
If version is the same, then only update dynamic properties when
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 224604269
  • Loading branch information
afrozenator authored and Copybara-Service committed Dec 8, 2018
1 parent 55bb2a8 commit 866e0c3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
16 changes: 14 additions & 2 deletions tensorflow_datasets/core/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,20 @@ def read_from_directory(self, dataset_info_dir):
dataset_info_json_str = f.read()

# Parse it back into a proto.
self._info_proto = json_format.Parse(dataset_info_json_str,
dataset_info_pb2.DatasetInfo())
parsed_proto = json_format.Parse(dataset_info_json_str,
dataset_info_pb2.DatasetInfo())

# If the version in the code and version in the given file match, then only
# update the stats and schema, everything else is specified in the code and
# let it be.
if parsed_proto.version == self._info_proto.version:
self.splits = splits_lib.SplitDict.from_proto(parsed_proto.splits)
self.as_proto.schema.CopyFrom(parsed_proto.schema)
self._fully_initialized = True
return True

# Update our representation.
self._info_proto = parsed_proto

# Restore the Splits
self.splits = splits_lib.SplitDict.from_proto(self.as_proto.splits)
Expand Down
34 changes: 34 additions & 0 deletions tensorflow_datasets/core/dataset_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,39 @@ def test_statistics_generation_variable_sizes(self):
self.assertEqual(-1, schema_feature.shape.dim[1].size)
self.assertEqual(3, schema_feature.shape.dim[2].size)

def test_updates_dynamic_properties_on_version_match(self):
info = dataset_info.DatasetInfo(version="1.0.0",
description="won't be updated")
# No statistics in the above.
self.assertEqual(0, info.num_examples)
self.assertEqual(0, len(info.as_proto.schema.feature))

# Partial update will happen here.
info.read_from_directory(_INFO_DIR)

# Assert that description (things specified in the code) didn't change
# but statistics are updated.
self.assertEqual("won't be updated", info.description)

# These are dynamically computed, so will be updated.
self.assertEqual(70000, info.num_examples)
self.assertEqual(2, len(info.as_proto.schema.feature))

def test_full_update_on_version_mismatch(self):
info = dataset_info.DatasetInfo(version="2.0.0",
description="will be updated")
# No statistics in the above.
self.assertEqual(0, info.num_examples)
self.assertEqual(0, len(info.as_proto.schema.feature))

# Full update should happen here.
info.read_from_directory(_INFO_DIR)

# Assert that description (things specified in the code) didn't change
# but statistics are updated.
self.assertNotEqual("will be updated", info.description)
self.assertEqual(70000, info.num_examples)
self.assertEqual(2, len(info.as_proto.schema.feature))

if __name__ == "__main__":
tf.test.main()

0 comments on commit 866e0c3

Please sign in to comment.