Add metadata field to DatasetInfo
PiperOrigin-RevId: 246604783
Conchylicultor authored and copybara-github committed May 4, 2019
1 parent 0c8841e commit ab791a4
Showing 4 changed files with 103 additions and 11 deletions.
6 changes: 6 additions & 0 deletions docs/release_notes.md
@@ -0,0 +1,6 @@
# Release notes

## Nightly

* It is now possible to add arbitrary metadata to `tfds.core.DatasetInfo`,
  which will be stored/restored with the dataset. See `tfds.core.Metadata`.
4 changes: 4 additions & 0 deletions tensorflow_datasets/core/__init__.py
@@ -21,6 +21,8 @@
from tensorflow_datasets.core.dataset_builder import GeneratorBasedBuilder

from tensorflow_datasets.core.dataset_info import DatasetInfo
from tensorflow_datasets.core.dataset_info import Metadata
from tensorflow_datasets.core.dataset_info import MetadataDict

from tensorflow_datasets.core.lazy_imports import lazy_imports

@@ -41,6 +43,8 @@
"get_tfds_path",
"DatasetInfo",
"NamedSplit",
"Metadata",
"MetadataDict",
"SplitBase",
"SplitDict",
"SplitGenerator",
84 changes: 76 additions & 8 deletions tensorflow_datasets/core/dataset_info.py
@@ -34,14 +34,17 @@
from __future__ import division
from __future__ import print_function

import abc
import collections
import json
import os
import posixpath
import pprint
import tempfile

from absl import logging
import numpy as np
import six
import tensorflow as tf

from tensorflow_datasets.core import api_utils
@@ -57,7 +60,6 @@

# Name of the file to output the DatasetInfo protobuf object.
DATASET_INFO_FILENAME = "dataset_info.json"

LICENSE_FILENAME = "LICENSE"

INFO_STR = """tfds.core.DatasetInfo(
@@ -97,6 +99,7 @@ def __init__(self,
               supervised_keys=None,
               urls=None,
               citation=None,
               metadata=None,
               redistribution_info=None):
    """Constructs DatasetInfo.
@@ -110,6 +113,9 @@
        supervised learning, if applicable for the dataset.
      urls: `list(str)`, optional, the homepage(s) for this dataset.
      citation: `str`, optional, the citation to use for this dataset.
      metadata: `tfds.core.Metadata`, an additional object which will be
        stored/restored with the dataset. This allows for storing additional
        information with the dataset.
      redistribution_info: `dict`, optional, information needed for
        redistribution, as specified in `dataset_info_pb2.RedistributionInfo`.
        The content of the `license` subfield will automatically be written to a
@@ -135,6 +141,12 @@ def __init__(self,
      self._info_proto.supervised_keys.input = supervised_keys[0]
      self._info_proto.supervised_keys.output = supervised_keys[1]

    if metadata and not isinstance(metadata, Metadata):
      raise ValueError(
          "Metadata should be a `tfds.core.Metadata` instance. Received "
          "{}".format(metadata))
    self._metadata = metadata

    # Is this object initialized with both the static and the dynamic data?
    self._fully_initialized = False

@@ -179,6 +191,10 @@ def size_in_bytes(self, size):
  def features(self):
    return self._features

  @property
  def metadata(self):
    return self._metadata

  @property
  def supervised_keys(self):
    if not self.as_proto.HasField("supervised_keys"):
@@ -236,10 +252,10 @@ def initialized(self):
    """Whether DatasetInfo has been fully initialized."""
    return self._fully_initialized

-  def _dataset_info_filename(self, dataset_info_dir):
+  def _dataset_info_path(self, dataset_info_dir):
    return os.path.join(dataset_info_dir, DATASET_INFO_FILENAME)

-  def _license_filename(self, dataset_info_dir):
+  def _license_path(self, dataset_info_dir):
    return os.path.join(dataset_info_dir, LICENSE_FILENAME)

  def compute_dynamic_properties(self):
@@ -287,13 +303,15 @@ def write_to_directory(self, dataset_info_dir):
    if self.features:
      self.features.save_metadata(dataset_info_dir)

    # Save any additional metadata
    if self.metadata is not None:
      self.metadata.save_metadata(dataset_info_dir)

    if self.redistribution_info.license:
-     with tf.io.gfile.GFile(self._license_filename(dataset_info_dir),
-                            "w") as f:
+     with tf.io.gfile.GFile(self._license_path(dataset_info_dir), "w") as f:
        f.write(self.redistribution_info.license)

-   with tf.io.gfile.GFile(self._dataset_info_filename(dataset_info_dir),
-                          "w") as f:
+   with tf.io.gfile.GFile(self._dataset_info_path(dataset_info_dir), "w") as f:
      f.write(self.as_json)

  def read_from_directory(self, dataset_info_dir):
@@ -312,7 +330,7 @@ def read_from_directory(self, dataset_info_dir):
      raise ValueError(
          "Calling read_from_directory with undefined dataset_info_dir.")

-   json_filename = self._dataset_info_filename(dataset_info_dir)
+   json_filename = self._dataset_info_path(dataset_info_dir)

    # Load the metadata from disk
    parsed_proto = read_from_json(json_filename)
@@ -324,6 +342,9 @@
    if self.features:
      self.features.load_metadata(dataset_info_dir)

    if self.metadata is not None:
      self.metadata.load_metadata(dataset_info_dir)

    # Update fields which are not defined in the code. This means that
    # the code will overwrite fields which are present in
    # dataset_info.json.
@@ -564,3 +585,50 @@ def read_from_json(json_filename):
  parsed_proto = json_format.Parse(dataset_info_json_str,
                                   dataset_info_pb2.DatasetInfo())
  return parsed_proto


@six.add_metaclass(abc.ABCMeta)
class Metadata(object):
  """Abstract base class for DatasetInfo metadata container.

  `builder.info.metadata` allows the dataset to expose additional general
  information about the dataset which is not specific to a feature or
  individual example.

  To implement the interface, overwrite `save_metadata` and
  `load_metadata`.

  See `tfds.core.MetadataDict` for a simple implementation that acts as a
  dict that saves data to/from a JSON file.
  """

  @abc.abstractmethod
  def save_metadata(self, data_dir):
    """Save the metadata."""
    raise NotImplementedError()

  @abc.abstractmethod
  def load_metadata(self, data_dir):
    """Restore the metadata."""
    raise NotImplementedError()


class MetadataDict(Metadata, dict):
  """A `tfds.core.Metadata` object that acts as a `dict`.

  By default, the metadata will be serialized as JSON.
  """

  def _build_filepath(self, data_dir):
    return os.path.join(data_dir, "metadata.json")

  def save_metadata(self, data_dir):
    """Save the metadata."""
    with tf.io.gfile.GFile(self._build_filepath(data_dir), "w") as f:
      json.dump(self, f)

  def load_metadata(self, data_dir):
    """Restore the metadata."""
    self.clear()
    with tf.io.gfile.GFile(self._build_filepath(data_dir), "r") as f:
      self.update(json.load(f))
20 changes: 17 additions & 3 deletions tensorflow_datasets/core/dataset_info_test.py
@@ -50,9 +50,12 @@ def _info(self):
        builder=self,
        features=features.FeaturesDict({"im": features.Image()}),
        supervised_keys=("im", "im"),
        metadata=dataset_info.MetadataDict(),
    )

  def _generate_examples(self):
    self.info.metadata["some_key"] = 123

    for _ in range(30):
      height = np.random.randint(5, high=10)
      width = np.random.randint(5, high=10)
@@ -128,19 +131,19 @@ def test_writing(self):
    info.read_from_directory(_INFO_DIR)

    # Read the json file into a string.
-   with tf.io.gfile.GFile(info._dataset_info_filename(_INFO_DIR)) as f:
+   with tf.io.gfile.GFile(info._dataset_info_path(_INFO_DIR)) as f:
      existing_json = json.load(f)

    # Now write to a temp directory.
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      info.write_to_directory(tmp_dir)

    # Read the newly written json file into a string.
-   with tf.io.gfile.GFile(info._dataset_info_filename(tmp_dir)) as f:
+   with tf.io.gfile.GFile(info._dataset_info_path(tmp_dir)) as f:
      new_json = json.load(f)

    # Read the newly written LICENSE file into a string.
-   with tf.io.gfile.GFile(info._license_filename(tmp_dir)) as f:
+   with tf.io.gfile.GFile(info._license_path(tmp_dir)) as f:
      license_ = f.read()

    # Assert what was read and then written and read again is the same.
@@ -262,6 +265,17 @@ def test_statistics_generation_variable_sizes(self):
    self.assertEqual(-1, schema_feature.shape.dim[1].size)
    self.assertEqual(3, schema_feature.shape.dim[2].size)

  def test_metadata(self):
    with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
      builder = RandomShapedImageGenerator(data_dir=tmp_dir)
      builder.download_and_prepare()
      # Metadata should have been created
      self.assertEqual(builder.info.metadata, {"some_key": 123})

      # Metadata should have been restored
      builder2 = RandomShapedImageGenerator(data_dir=tmp_dir)
      self.assertEqual(builder2.info.metadata, {"some_key": 123})

  def test_updates_on_bucket_info(self):

    info = dataset_info.DatasetInfo(builder=self._builder,