Description
Maintain Constant Global Batch Size Upon Failure
With the current implementation of DistributedSampler, the global_batch_size is group_batch_size * num_replica_group.
It may be preferable to implement DistributedSampler so that the global batch size stays constant. In that case, group_batch_size is global_batch_size // num_replica_group. If global_batch_size is not evenly divisible by num_replica_group, the remaining samples are distributed among the available replica groups.
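A minimal sketch of the split described above, assuming the remainder goes one extra sample each to the first replica groups (the ordering is an assumption; any deterministic choice works). The function name split_global_batch is hypothetical.

```python
def split_global_batch(global_batch_size: int, num_replica_groups: int) -> list[int]:
    """Per-group batch sizes that always sum exactly to global_batch_size."""
    if num_replica_groups <= 0:
        raise ValueError("need at least one replica group")
    base, remainder = divmod(global_batch_size, num_replica_groups)
    # The first `remainder` groups each take one extra sample.
    return [base + 1 if i < remainder else base for i in range(num_replica_groups)]
```

For example, split_global_batch(1024, 4) gives [256, 256, 256, 256], and after one group fails split_global_batch(1024, 3) gives [342, 341, 341], so the global batch stays at 1024.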
For hardware efficiency, one may want to keep local_batch_size a multiple of a certain integer. If so, global_batch_size should still be as close to the targeted size as possible, subject to that constraint.
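One way to satisfy that constraint, sketched under stated assumptions: every local batch is a whole number of chunks of size multiple, so the achievable global sizes are multiples of multiple. The sketch picks the one nearest the target (ties round down), then spreads the chunks as evenly as possible. Giving every group at least one chunk is an added assumption, since an empty local batch is unlikely to be useful. The function name is hypothetical.

```python
def split_with_multiple(target_global: int, num_groups: int, multiple: int) -> list[int]:
    """Local batch sizes that are multiples of `multiple`, with a sum close to target_global."""
    if num_groups <= 0 or multiple <= 0:
        raise ValueError("num_groups and multiple must be positive")
    lower_units = target_global // multiple
    upper_units = lower_units + 1
    # Choose the achievable total (a multiple of `multiple`) nearest the target; ties round down.
    if target_global - lower_units * multiple <= upper_units * multiple - target_global:
        units = lower_units
    else:
        units = upper_units
    # Assumption: each group receives at least one chunk.
    units = max(units, num_groups)
    base, remainder = divmod(units, num_groups)
    return [(base + (1 if i < remainder else 0)) * multiple for i in range(num_groups)]
```

With a target of 1000, 3 groups and a multiple of 64, the candidates are 960 and 1024; 1024 is closer, so the result is [384, 320, 320].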
Implementation
The dataloader control logic could be folded into the quorum. The Lighthouse would then coordinate loading by specifying the local_batch_size and start_index for each replica group.
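A sketch of what the Lighthouse-side assignment could look like, assuming each step covers a contiguous block of global_batch_size sample indices beginning at step_offset, and that replica groups are ordered by sorting their IDs. BatchAssignment and assign_batches are hypothetical names, not an existing Lighthouse API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchAssignment:
    start_index: int
    local_batch_size: int


def assign_batches(replica_ids: list[str], global_batch_size: int,
                   step_offset: int) -> dict[str, BatchAssignment]:
    """Map each replica group in the quorum to a contiguous slice of this step's samples."""
    if not replica_ids:
        raise ValueError("quorum is empty")
    ids = sorted(replica_ids)  # deterministic order so every member agrees
    base, remainder = divmod(global_batch_size, len(ids))
    assignments = {}
    start = step_offset
    for i, rid in enumerate(ids):
        size = base + (1 if i < remainder else 0)
        assignments[rid] = BatchAssignment(start, size)
        start += size  # next group's slice begins where this one ends
    return assignments
```

Computing the slices with a prefix sum means no sample is assigned twice and none is skipped, regardless of how many groups are in the quorum.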
To keep this efficient, a replica group does not need to call self.wait_quorum() before loading a batch. Since the local batch size can only increase upon failure, the group can load its batch assuming num_replica_group is unchanged from the previous iteration, then load any additional samples it needs once the quorum is received.
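A sketch of the speculative loading step, under an assumption the issue does not spell out. prev_slices maps each group to its slice of the current step, computed with the previous membership. Surviving groups keep the slice they already loaded, and only the orphaned samples of failed groups are split among them as top-ups. Recomputing contiguous slices from scratch would instead shift every group's start_index and invalidate the prefetch. Rejoining groups, which shrink local batches, are not handled. All names are hypothetical.

```python
def top_up(prev_slices: dict[str, range], survivors: list[str]) -> dict[str, list[int]]:
    """Extra sample indices each surviving group must load after the quorum arrives."""
    alive = sorted(survivors)
    if not alive:
        raise ValueError("no surviving replica groups")
    alive_set = set(alive)
    # Samples that were assigned to groups missing from the new quorum.
    orphaned: list[int] = []
    for rid, r in sorted(prev_slices.items()):
        if rid not in alive_set:
            orphaned.extend(r)
    base, remainder = divmod(len(orphaned), len(alive))
    extras, pos = {}, 0
    for i, rid in enumerate(alive):
        size = base + (1 if i < remainder else 0)
        extras[rid] = orphaned[pos:pos + size]
        pos += size
    return extras


def indices_for_step(my_id: str, prev_slices: dict[str, range],
                     survivors: list[str]) -> list[int]:
    # Loaded speculatively, before the quorum result is known.
    prefetched = list(prev_slices[my_id])
    # Loaded only after the quorum reveals which groups survived.
    return prefetched + top_up(prev_slices, survivors)[my_id]
```

If prev_slices is {"g0": range(0, 4), "g1": range(4, 8), "g2": range(8, 12)} and g1 fails, g0 ends up with [0, 1, 2, 3, 4, 5] and g2 with [8, 9, 10, 11, 6, 7], so all 12 samples are still covered.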