Merge pull request tensorflow#3308 from 8bitmp3:patch-2
PiperOrigin-RevId: 403977716
copybara-github committed Oct 18, 2021
2 parents 0f97767 + 0a77a81 commit d00e675
Showing 1 changed file with 21 additions and 20 deletions: docs/performances.md

# Performance tips

This document provides TensorFlow Datasets (TFDS)-specific performance tips.
Note that TFDS provides datasets as `tf.data.Dataset` objects, so the advice
from the
[`tf.data` guide](https://www.tensorflow.org/guide/data_performance#optimize_performance)
still applies.

```python
tfds.benchmark(ds, batch_size=32)
```
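
The snippet above shows only the benchmark call itself. Here is a minimal,
self-contained sketch of a pipeline it could be applied to; the dataset name
and batch size are illustrative assumptions, not the exact code omitted above:

```python
import tensorflow_datasets as tfds

# Build a simple input pipeline to measure (illustrative dataset and batch size).
ds = tfds.load('mnist', split='train')
ds = ds.batch(32).prefetch(1)

# First pass: cold read from disk.
tfds.benchmark(ds, batch_size=32)
# A second pass over a small dataset is usually much faster once auto-caching kicks in.
tfds.benchmark(ds, batch_size=32)
```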

## Small datasets (less than 1 GB)

All TFDS datasets store the data on disk in the
[`TFRecord`](https://www.tensorflow.org/tutorials/load_data/tfrecord) format.
For small datasets (e.g. MNIST, CIFAR-10/-100), reading from `.tfrecord` can add
significant overhead.

As those datasets fit in memory, it is possible to significantly improve the
performance by caching or pre-loading the dataset. Note that TFDS automatically
caches small datasets (the following section has the details).

### Caching the dataset

```python
ds, ds_info = tfds.load(
    # ...
)
# Note: Random transformations (e.g. image augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
```
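
The snippet above omits how the dataset is loaded and how `normalize_img` is
defined. A self-contained sketch of the same caching pattern, in which the
dataset name, the normalization function, and the batch size are illustrative
assumptions:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

ds, ds_info = tfds.load('mnist', split='train', as_supervised=True, with_info=True)

def normalize_img(image, label):
  # Cast uint8 pixels to float32 in [0, 1].
  return tf.cast(image, tf.float32) / 255.0, label

ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()  # Cache after the deterministic map, before shuffling.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
ds = ds.batch(128)
ds = ds.prefetch(tf.data.AUTOTUNE)
```
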
Tensor or NumPy array. It is possible to do so by setting `batch_size=-1` to
batch all examples in a single `tf.Tensor`. Then use `tfds.as_numpy` for the
conversion from `tf.Tensor` to `np.array`.

```python
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,  # Batch the full split into a single Tensor.
    as_supervised=True,  # Yield (image, label) tuples to match the unpacking above.
))
```

## Large datasets

Large datasets are sharded (split in multiple files) and typically do not fit
in memory, so they should not be cached.

### Shuffle and training

During training, it's important to shuffle the data well - poorly shuffled data
can result in lower training accuracy.

In addition to using `ds.shuffle` to shuffle records, you should also set
`shuffle_files=True` to get good shuffling behavior for larger datasets that are
sharded into multiple files. Otherwise, epochs will read the shards in the same
order, and so data won't be truly randomized.

```python
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
```

Additionally, when `shuffle_files=True`, TFDS disables
[`options.experimental_deterministic`](https://www.tensorflow.org/api_docs/python/tf/data/Options#experimental_deterministic),
which may give a slight performance boost. To get deterministic shuffling, it is
possible to opt-out of this feature with `tfds.ReadConfig`: either by setting
`read_config.shuffle_seed` or overwriting

```python
read_config = tfds.ReadConfig(
    # ...
)
ds = tfds.load('dataset', split='train', read_config=read_config)
```
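
For the `read_config.shuffle_seed` option mentioned above, a minimal sketch;
the seed value and dataset name are illustrative:

```python
import tensorflow_datasets as tfds

# A fixed seed makes the file shuffling reproducible across runs.
read_config = tfds.ReadConfig(shuffle_seed=42)
ds = tfds.load(
    'imagenet2012',
    split='train',
    shuffle_files=True,
    read_config=read_config,
)
```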

This is complementary to the subsplit API. First, the subsplit API is applied:
`train[:50%]` is converted into a list of files to read. Then, a `ds.shard()` op
is applied on those files. For example, when using `train[:50%]` with
`num_input_pipelines=2`, each of the 2 workers will read 1/4 of the data.

When `shuffle_files=True`, files are shuffled within one worker, but not across
workers. Each worker will read the same subset of files between epochs.

Note: When using `tf.distribute.Strategy`, the `input_context` can be
automatically created with
[distribute_datasets_from_function](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#distribute_datasets_from_function).
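
As a concrete illustration of the sharding described above, here is a hedged
sketch of wiring an `input_context` into `tfds.ReadConfig` by hand; the worker
counts and dataset name are illustrative, and with `tf.distribute.Strategy` the
context would instead come from `distribute_datasets_from_function`:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Describe this worker's position in the input pipeline (illustrative values).
input_context = tf.distribute.InputContext(
    num_input_pipelines=2,  # Total number of workers reading the data.
    input_pipeline_id=0,    # Index of the current worker.
)
read_config = tfds.ReadConfig(input_context=input_context)

# With `train[:50%]` and num_input_pipelines=2, each worker reads 1/4 of the data.
ds = tfds.load('dataset', split='train[:50%]', read_config=read_config)
```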

### Auto-shard your data across workers (Jax)

```python
ds = tfds.load('my_dataset', split=splits[jax.process_index()])
```
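
In the line above, `splits` is a list with one sub-split per Jax process. A
minimal sketch of building it, assuming `tfds.even_splits` is used to partition
the split evenly (this helper and the process-count wiring are assumptions, not
necessarily the code omitted here):

```python
import jax
import tensorflow_datasets as tfds

# One even sub-split of 'train' per Jax process, so each process reads a
# disjoint slice of the data via `splits[jax.process_index()]`.
splits = tfds.even_splits('train', n=jax.process_count())
```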

### Faster image decoding

By default, TFDS automatically decodes images. However, there are cases where it
can be more performant to skip the image decoding with
`tfds.decode.SkipDecoding` and manually apply the `tf.io.decode_image` op:

* When filtering examples (with `tf.data.Dataset.filter`), to decode images
  after examples have been filtered.
* When cropping images, to use the fused `tf.image.decode_and_crop_jpeg` op.
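
For the filtering case, a hedged sketch of what skipping the decoding might
look like; the dataset name and the filter predicate are illustrative
assumptions:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Skip the automatic image decoding, so `filter` operates on the raw bytes.
ds = tfds.load(
    'imagenet2012',
    split='train',
    decoders={'image': tfds.decode.SkipDecoding()},
)

# Filter first (no decoding cost), then decode only the examples that are kept.
ds = ds.filter(lambda ex: ex['label'] < 10)
ds = ds.map(
    lambda ex: {**ex, 'image': tf.io.decode_image(ex['image'])},
    num_parallel_calls=tf.data.AUTOTUNE,
)
```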

The code for both examples is available in the