Currently, in a distributed setting, users can pass tfds.ReadConfig(input_context=) as described in
https://www.tensorflow.org/datasets/performances#auto-shard_your_data_across_workers. This makes sure that each worker reads a different slice of the data.
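For reference, the auto-shard pattern from the linked guide looks roughly like this (the dataset name and the surrounding strategy object are illustrative):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def make_ds(input_context: tf.distribute.InputContext):
  # TFDS shards the input files across workers based on the input_context.
  read_config = tfds.ReadConfig(input_context=input_context)
  return tfds.load('cifar10', split='train', read_config=read_config)

ds = strategy.distribute_datasets_from_function(make_ds)
```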
However, this only works when the number of file shards is at least as large as the number of workers. Otherwise, users have to fall back to the subsplit API:
```python
import tensorflow as tf
import tensorflow_datasets as tfds

def make_ds(input_context: tf.distribute.InputContext):
  # Each worker reads only its own even subsplit of the 'train' split.
  split = tfds.even_splits('train', n=input_context.num_input_pipelines)[
      input_context.input_pipeline_id]
  return tfds.load('cifar10', split=split)

ds = strategy.distribute_datasets_from_function(make_ds)
```
It would be nice if the input_context automatically applied the subsplit API when info.splits[split].num_shards < read_config.input_context.num_input_pipelines, so that the code in https://www.tensorflow.org/datasets/performances#auto-shard_your_data_across_workers would work for all datasets.
We might want to wait for the new TFDS file format before doing this.
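In the meantime, a user-side fallback along these lines could emulate the requested behavior. This is only a minimal sketch, not the proposed TFDS internals; the dataset name, split, strategy object, and the manual shard-count check are illustrative assumptions:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def make_ds(input_context: tf.distribute.InputContext):
  # Sketch of the requested fallback, done manually on the user side.
  builder = tfds.builder('cifar10')
  num_shards = builder.info.splits['train'].num_shards
  if num_shards >= input_context.num_input_pipelines:
    # Enough file shards: let TFDS shard by file via the input_context.
    read_config = tfds.ReadConfig(input_context=input_context)
    return tfds.load('cifar10', split='train', read_config=read_config)
  # Too few file shards: fall back to the subsplit API.
  split = tfds.even_splits('train', n=input_context.num_input_pipelines)[
      input_context.input_pipeline_id]
  return tfds.load('cifar10', split=split)

ds = strategy.distribute_datasets_from_function(make_ds)
```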