Support ReadConfig(input_context=) when the number of shards is small #3025

Open
@Conchylicultor

Description

Currently, in a distributed setting, users can use tfds.ReadConfig(input_context=) as described in
https://www.tensorflow.org/datasets/performances#auto-shard_your_data_across_workers. This makes sure that each worker reads a different slice of the data.
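
For reference, the pattern from the linked guide looks roughly like this (a sketch; 'cifar10' and the choice of strategy are just examples, assuming e.g. tf.distribute.MultiWorkerMirroredStrategy):

import tensorflow as tf
import tensorflow_datasets as tfds

def make_ds(input_context: tf.distribute.InputContext):
  # Each worker reads a disjoint subset of the dataset's files.
  read_config = tfds.ReadConfig(input_context=input_context)
  return tfds.load('cifar10', split='train', read_config=read_config)

strategy = tf.distribute.MultiWorkerMirroredStrategy()
ds = strategy.distribute_datasets_from_function(make_ds)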

However, this only works when the number of shards is at least as large as the number of workers. Otherwise, users are required to use the subsplit API:

import tensorflow as tf
import tensorflow_datasets as tfds

def make_ds(input_context: tf.distribute.InputContext):
  # Pick this worker's subsplit out of num_input_pipelines equal slices.
  split = tfds.even_splits(
      'train', n=input_context.num_input_pipelines,
  )[input_context.input_pipeline_id]
  ds = tfds.load('cifar10', split=split)
  return ds

ds = strategy.distribute_datasets_from_function(make_ds)

It would be nice if the input_context automatically applied the subsplit API when info.splits[split].num_shards < read_config.input_context.num_input_pipelines, so that the code in https://www.tensorflow.org/datasets/performances#auto-shard_your_data_across_workers would work for all datasets.
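
A minimal sketch of what this fallback could look like (the helper name _maybe_subsplit is hypothetical, not an existing TFDS function):

import tensorflow_datasets as tfds

def _maybe_subsplit(split, builder: tfds.core.DatasetBuilder,
                    read_config: tfds.ReadConfig):
  """Hypothetical: fall back to even_splits when there are too few shards."""
  ctx = read_config.input_context
  num_shards = builder.info.splits[split].num_shards
  if ctx is not None and num_shards < ctx.num_input_pipelines:
    # Not enough files for file-level auto-sharding: subsplit instead,
    # so each worker still reads a disjoint slice of examples.
    return tfds.even_splits(split, n=ctx.num_input_pipelines)[ctx.input_pipeline_id]
  # Enough shards: keep the file-level sharding driven by input_context.
  return split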

We might want to wait for the new TFDS file format before doing this.

Metadata


Labels: enhancement (New feature or request)
