Commit 5f11aff
Remove CSV from file_format_adapter.
PiperOrigin-RevId: 245254012
pierrot0 authored and copybara-github committed Apr 25, 2019
1 parent 7e86944 · commit 5f11aff
Showing 3 changed files with 3 additions and 120 deletions.
2 changes: 1 addition & 1 deletion tensorflow_datasets/core/dataset_builder.py
@@ -573,7 +573,7 @@ class FileAdapterBuilder(DatasetBuilder):

@utils.memoized_property
def _file_format_adapter(self):
-# Load the format adapter (CSV, TF-Record,...)
+# Load the format adapter (TF-Record,...)
file_adapter_cls = file_format_adapter.TFRecordExampleAdapter
serialized_info = self.info.features.get_serialized_info()
return file_adapter_cls(serialized_info)
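Editor's note: `utils.memoized_property` is not touched by this commit. For readers unfamiliar with the pattern, a hypothetical sketch of what such a decorator typically does (an assumption, not the TFDS implementation) is below; the net effect is that the format adapter is constructed once per builder and reused on later accesses.

```python
import functools

def memoized_property(fn):
  """Hypothetical sketch of a caching property decorator (not the TFDS source)."""
  attr_name = "_memoized_" + fn.__name__

  @property
  @functools.wraps(fn)
  def wrapper(self):
    if not hasattr(self, attr_name):
      setattr(self, attr_name, fn(self))  # first access computes and caches
    return getattr(self, attr_name)
  return wrapper
```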
109 changes: 2 additions & 107 deletions tensorflow_datasets/core/file_format_adapter.py
@@ -18,9 +18,8 @@
FileFormatAdapters implement methods to write and read data from a
particular file format.
-Currently, two FileAdapter are available:
+Currently, a single FileAdapter is available:
* TFRecordExampleAdapter: To store the pre-processed dataset as .tfrecord file
-* CSVAdapter: To store the dataset as CSV file
```python
return TFRecordExampleAdapter({
@@ -35,10 +34,7 @@
from __future__ import print_function

import abc
-import collections
import contextlib
-import csv
-import itertools
import random
import string

@@ -54,7 +50,6 @@
__all__ = [
"FileFormatAdapter",
"TFRecordExampleAdapter",
"CSVAdapter",
]


@@ -94,7 +89,7 @@ def dataset_from_filename(self, filename):

@abc.abstractproperty
def filetype_suffix(self):
"""Returns a str file type suffix (e.g. "csv")."""
"""Returns a str file type suffix (e.g. "tfrecord")."""
raise NotImplementedError


@@ -196,79 +191,6 @@ def _decode(self, record):
sequence_features=self._sequence_reading_spec)


-class CSVAdapter(FileFormatAdapter):
-  """Writes/reads features to/from CSV files.
-
-  Constraints on generators:
-
-  * The generator must yield feature dictionaries (`dict<str feature_name,
-    feature_value>`).
-  * The allowed feature types are `int`, `float`, and `str`. By default, only
-    scalar features are supported (that is, not lists).
-
-  You can modify how records are written by passing `csv_writer_ctor`.
-
-  You can modify how records are read by passing `csv_dataset_kwargs`.
-
-  Note that all CSV files produced will have a header row.
-  """
-
-  # TODO(rsepassi): Instead of feature_types, take record_defaults and
-  # infer the types from the default values if provided.
-  def __init__(self,
-               feature_types,
-               csv_dataset_kwargs=None,
-               csv_writer_ctor=csv.writer):
-    """Constructs CSVAdapter.
-
-    Args:
-      feature_types (dict<name, type>): specifies the dtypes of each of the
-        features (columns in the CSV file).
-      csv_dataset_kwargs (dict): forwarded to `tf.data.experimental.CsvDataset`.
-      csv_writer_ctor (function): takes file handle and returns writer.
-
-    Raises:
-      ValueError: if csv_dataset_kwargs["header"] is present.
-    """
-    self._csv_kwargs = csv_dataset_kwargs or {}
-    if "header" in self._csv_kwargs:
-      raise ValueError("header must not be present")
-    self._feature_types = collections.OrderedDict(sorted(feature_types.items()))
-    self._csv_kwargs["header"] = True
-    # TODO(epot): Should check feature_types and raise error is some are
-    # not supported with CSV. Currently CSV files only support single
-    # values, no array.
-    if "record_defaults" not in self._csv_kwargs:
-      types = [f.dtype for f in self._feature_types.values()]
-      self._csv_kwargs["record_defaults"] = types
-    self._csv_writer_ctor = csv_writer_ctor
-
-  # TODO(rsepassi): Add support for non-scalar features (e.g. list of integers).
-  def write_from_generator(self, generator_fn, output_files):
-    # Flatten the dict returned by the generator and add the header
-    header_keys = list(self._feature_types.keys())
-    rows_generator = ([d[k] for k in header_keys] for d in generator_fn())  # pylint: disable=g-complex-comprehension
-    generator_with_header = itertools.chain([header_keys], rows_generator)
-    _write_csv_from_generator(
-        generator_with_header,
-        output_files,
-        self._csv_writer_ctor)
-
-  def dataset_from_filename(self, filename):
-    dataset = tf.data.experimental.CsvDataset(filename, **self._csv_kwargs)
-    return dataset.map(self._decode,
-                       num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
-  def _decode(self, *record):
-    return {
-        k: v for k, v in zip(self._feature_types.keys(), record)
-    }
-
-  @property
-  def filetype_suffix(self):
-    return "csv"


def do_files_exist(filenames):
"""Whether any of the filenames exist."""
preexisting = [tf.io.gfile.exists(f) for f in filenames]
@@ -366,33 +288,6 @@ def _round_robin_write(writers, generator):
writers[i % len(writers)].write(example)
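Editor's note: the surviving context line above is the whole sharding policy, as example `i` goes to writer `i % len(writers)`. A standalone sketch of the pattern (the toy `ListWriter` is illustrative, not part of TFDS):

```python
class ListWriter:
  """Toy stand-in for a shard writer; records what it receives."""

  def __init__(self):
    self.records = []

  def write(self, example):
    self.records.append(example)

def round_robin_write(writers, generator):
  # Mirrors the line above: deal examples out to the shards in turn.
  for i, example in enumerate(generator):
    writers[i % len(writers)].write(example)

writers = [ListWriter() for _ in range(3)]
round_robin_write(writers, range(10))
assert writers[0].records == [0, 3, 6, 9]
```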


-def _write_csv_from_generator(generator, output_files, writer_ctor=None):
-  """Write records to CSVs using writer_ctor (defaults to csv.writer)."""
-  if do_files_exist(output_files):
-    raise ValueError(
-        "Pre-processed files already exists: {}.".format(output_files))
-
-  if writer_ctor is None:
-    writer_ctor = csv.writer
-
-  def create_csv_writer(filename):
-    with tf.io.gfile.GFile(filename, "wb") as f:
-      writer = writer_ctor(f)
-      # Simple way to give the writer a "write" method proxying writerow
-      writer = collections.namedtuple("_writer", ["write"])(
-          write=writer.writerow)
-    return f, writer
-
-  with _incomplete_files(output_files) as tmp_files:
-    handles, writers = zip(*[create_csv_writer(fname) for fname in tmp_files])
-    with _close_on_exit(handles):
-      logging.info("Writing CSVs")
-      header = next(generator)
-      for w in writers:
-        w.write(header)
-      _round_robin_write(writers, generator)
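Editor's note: one design choice in the deleted helper is worth recording. Rather than subclassing, `create_csv_writer` wraps `csv.writer` in a one-field namedtuple so the CSV writer exposes the same `write(record)` method as `tf.io.TFRecordWriter`, which lets `_round_robin_write` drive both writer kinds uniformly. The trick in isolation (standalone sketch):

```python
import collections
import csv
import io

buf = io.StringIO()
# Adapt csv.writer's writerow() to a generic write() interface.
writer = collections.namedtuple("_writer", ["write"])(write=csv.writer(buf).writerow)
writer.write(["name", "score"])
writer.write(["ada", "1"])
print(buf.getvalue())  # "name,score\r\nada,1\r\n"
```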


def _dicts_to_tf_sequence_example(context_dict, sequences_dict):
flists = {}
for k, flist in six.iteritems(sequences_dict):
12 changes: 0 additions & 12 deletions tensorflow_datasets/core/file_format_adapter_test.py
@@ -67,15 +67,6 @@ def _info(self):
)


-class DummyCSVBuilder(DummyTFRecordBuilder):
-
-  @property
-  def _file_format_adapter(self):
-    file_adapter_cls = file_format_adapter.CSVAdapter
-    serialized_info = self.info.features.get_serialized_info()
-    return file_adapter_cls(serialized_info)


class FileFormatAdapterTest(testing.TestCase):

def _test_generator_based_builder(self, builder_cls):
@@ -105,9 +96,6 @@ def validate_dataset(dataset, min_val, max_val, test_range=False):
def test_tfrecords(self):
self._test_generator_based_builder(DummyTFRecordBuilder)

-  def test_csv(self):
-    self._test_generator_based_builder(DummyCSVBuilder)


class TFRecordUtilsTest(testing.TestCase):

