From 866e0c3b14d43f54040324c64b330ed19c3bf85b Mon Sep 17 00:00:00 2001 From: Afroz Mohiuddin Date: Fri, 7 Dec 2018 17:33:50 -0800 Subject: [PATCH] If version is the same, then only update dynamic properties when PiperOrigin-RevId: 224604269 --- tensorflow_datasets/core/dataset_info.py | 16 +++++++-- tensorflow_datasets/core/dataset_info_test.py | 34 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/tensorflow_datasets/core/dataset_info.py b/tensorflow_datasets/core/dataset_info.py index 9134376f282..d972de9e789 100644 --- a/tensorflow_datasets/core/dataset_info.py +++ b/tensorflow_datasets/core/dataset_info.py @@ -298,8 +298,20 @@ def read_from_directory(self, dataset_info_dir): dataset_info_json_str = f.read() # Parse it back into a proto. - self._info_proto = json_format.Parse(dataset_info_json_str, - dataset_info_pb2.DatasetInfo()) + parsed_proto = json_format.Parse(dataset_info_json_str, + dataset_info_pb2.DatasetInfo()) + + # If the version in the code and version in the given file match, then only + # update the stats and schema, everything else is specified in the code and + # let it be. + if parsed_proto.version == self._info_proto.version: + self.splits = splits_lib.SplitDict.from_proto(parsed_proto.splits) + self.as_proto.schema.CopyFrom(parsed_proto.schema) + self._fully_initialized = True + return True + + # Update our representation. 
+ self._info_proto = parsed_proto # Restore the Splits self.splits = splits_lib.SplitDict.from_proto(self.as_proto.splits) diff --git a/tensorflow_datasets/core/dataset_info_test.py b/tensorflow_datasets/core/dataset_info_test.py index 03ed77b5cc6..e3798752ca8 100644 --- a/tensorflow_datasets/core/dataset_info_test.py +++ b/tensorflow_datasets/core/dataset_info_test.py @@ -174,5 +174,39 @@ def test_statistics_generation_variable_sizes(self): self.assertEqual(-1, schema_feature.shape.dim[1].size) self.assertEqual(3, schema_feature.shape.dim[2].size) + def test_updates_dynamic_properties_on_version_match(self): + info = dataset_info.DatasetInfo(version="1.0.0", + description="won't be updated") + # No statistics in the above. + self.assertEqual(0, info.num_examples) + self.assertEqual(0, len(info.as_proto.schema.feature)) + + # Partial update will happen here. + info.read_from_directory(_INFO_DIR) + + # Assert that description (things specified in the code) didn't change + # but statistics are updated. + self.assertEqual("won't be updated", info.description) + + # These are dynamically computed, so will be updated. + self.assertEqual(70000, info.num_examples) + self.assertEqual(2, len(info.as_proto.schema.feature)) + + def test_full_update_on_version_mismatch(self): + info = dataset_info.DatasetInfo(version="2.0.0", + description="will be updated") + # No statistics in the above. + self.assertEqual(0, info.num_examples) + self.assertEqual(0, len(info.as_proto.schema.feature)) + + # Full update should happen here. + info.read_from_directory(_INFO_DIR) + + # Assert that the description specified in the code was replaced by the + # one read from disk, and that the statistics are updated as well. + self.assertNotEqual("will be updated", info.description) + self.assertEqual(70000, info.num_examples) + self.assertEqual(2, len(info.as_proto.schema.feature)) + if __name__ == "__main__": tf.test.main()