Skip to content

Commit

Permalink
Expose API for reading individual TFRecord files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 292534021
  • Loading branch information
adarob authored and copybara-github committed Jan 31, 2020
1 parent cf61126 commit 34ae6dd
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 26 deletions.
76 changes: 50 additions & 26 deletions tensorflow_datasets/core/tfrecords_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import division
from __future__ import print_function

import copy
import functools
import math
import os
Expand Down Expand Up @@ -160,34 +161,22 @@ def _make_file_instructions_from_absolutes(
)


def _read_single_instruction(
instruction,
def _read_files(
files,
parse_fn,
read_config,
name,
path,
split_infos,
shuffle_files):
"""Returns tf.data.Dataset for given instruction.
"""Returns tf.data.Dataset for given file instructions.
Args:
instruction (ReadInstruction or str): if str, a ReadInstruction will be
constructed using `ReadInstruction.from_spec(str)`.
files: List[dict(filename, skip, take)], the files information.
The filenames contain the absolute path, not relative.
skip/take indicates which example read in the shard: `ds.skip().take()`
parse_fn (callable): function used to parse each record.
read_config: `tfds.ReadConfig`, Additional options to configure the
input pipeline (e.g. seed, num parallel reads,...).
name (str): name of the dataset.
path (str): path to directory where to read tfrecords from.
split_infos: `SplitDict`, the `info.splits` container of `SplitInfo`.
shuffle_files (bool): Defaults to False. True to shuffle input files.
"""
file_instructions = make_file_instructions(name, split_infos, instruction)
for fi in file_instructions.file_instructions:
fi['filename'] = os.path.join(path, fi['filename'])
files = file_instructions.file_instructions
if not files:
msg = 'Instruction "%s" corresponds to no data!' % instruction
raise AssertionError(msg)
# Eventually apply a transformation to the instruction function.
# This allow the user to have direct control over the interleave order.
if read_config.experimental_interleave_sort_fn is not None:
Expand Down Expand Up @@ -276,16 +265,51 @@ def read(
ReadInstruction instance. Otherwise a dict/list of tf.data.Dataset
corresponding to given instructions param shape.
"""
read_instruction = functools.partial(
_read_single_instruction,
parse_fn=self._parser.parse_example,
def _read_instruction_to_file_instructions(instruction):
file_instructions = make_file_instructions(name, split_infos, instruction)
files = file_instructions.file_instructions
if not files:
msg = 'Instruction "%s" corresponds to no data!' % instruction
raise AssertionError(msg)
return tuple(files)

files = utils.map_nested(
_read_instruction_to_file_instructions, instructions, map_tuple=False)
return utils.map_nested(
functools.partial(
self.read_files, read_config=read_config,
shuffle_files=shuffle_files),
files,
map_tuple=False)

def read_files(
self,
files,
read_config,
shuffle_files
):
"""Returns single tf.data.Dataset instance for the set of file instructions.
Args:
files: List[dict(filename, skip, take)], the files information.
The filenames contains the relative path, not absolute.
skip/take indicates which example read in the shard: `ds.skip().take()`
read_config: `tfds.ReadConfig`, the input pipeline options
shuffle_files (bool): If True, input files are shuffled before being read.
Returns:
a tf.data.Dataset instance.
"""
# Prepend path to filename
files = copy.deepcopy(files)
for f in files:
f.update(filename=os.path.join(self._path, f['filename']))
dataset = _read_files(
files=files,
read_config=read_config,
split_infos=split_infos,
name=name,
path=self._path,
parse_fn=self._parser.parse_example,
shuffle_files=shuffle_files)
datasets = utils.map_nested(read_instruction, instructions, map_tuple=True)
return datasets
return dataset


@attr.s(frozen=True)
Expand Down
10 changes: 10 additions & 0 deletions tensorflow_datasets/core/tfrecords_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,16 @@ def test_4fold(self):
[b'a', b'b', b'c', b'd', b'e', b'f', b'j', b'k', b'l'],
[b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']])

def test_read_files(self):
self._write_tfrecord('train', 4, 'abcdefghijkl')
fname_pattern = 'mnist-train.tfrecord-0000%d-of-00004'
ds = self.reader.read_files(
[{'filename': fname_pattern % 1, 'skip': 0, 'take': -1},
{'filename': fname_pattern % 3, 'skip': 1, 'take': 1}],
read_config=read_config_lib.ReadConfig(),
shuffle_files=False)
read_data = list(tfds.as_numpy(ds))
self.assertEqual(read_data, [six.b(l) for l in 'defk'])

if __name__ == '__main__':
testing.test_main()

0 comments on commit 34ae6dd

Please sign in to comment.