Description
Proposed refactoring or deprecation
Introduce a device selection dataclass that holds the user's device selection in a standardized format. Idea by @ananthsub.
Motivation
We have a _parse_devices function, used in the Trainer and in Lite, that returns a tuple of parsed device indices.
From @ananthsub in #10230 (comment):
Returning a tuple isn't going to scale well with more device types. It's not easy to tell which positional index maps to which device id type. It could be better to introduce a dataclass to represent the schema concretely. That would also naturally allow for extensions like IPUs.
Pitch
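from dataclasses import dataclass, field
from typing import List

from pytorch_lightning.utilities.enums import DeviceType

@dataclass
class DeviceSelection:
    # a mutable default like [] is rejected by dataclasses, so use a factory
    devices: List[int] = field(default_factory=list)
    type: DeviceType = DeviceType.CPU

def parse_input(gpus, tpu_cores, ipus):  # ... plus further device arguments
    # validate user inputs
    # map various input formats to the standardized one in this dataclass
    ...
    return DeviceSelection(devices=..., type=...)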
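For illustration, here is a self-contained sketch of how parse_input could fill in the elided body for a few input formats. It assumes simplified, hypothetical rules: an int is a device count, a string is a comma-separated list of indices such as "0,2", a list is taken as explicit indices, and at most one accelerator flag may be set. The DeviceType enum and the _to_indices helper are local stand-ins invented for this sketch, and special values the Trainer accepts (such as gpus=-1) are not handled.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Sequence, Union


class DeviceType(str, Enum):
    # Hypothetical local stand-in for Lightning's device-type enum.
    CPU = "CPU"
    GPU = "GPU"
    TPU = "TPU"
    IPU = "IPU"


@dataclass
class DeviceSelection:
    devices: List[int] = field(default_factory=list)
    type: DeviceType = DeviceType.CPU


DeviceInput = Optional[Union[int, str, Sequence[int]]]


def _to_indices(value: Union[int, str, Sequence[int]]) -> List[int]:
    """Hypothetical helper: normalize a count, a "0,2" string, or a list of indices."""
    if isinstance(value, bool):
        raise TypeError("booleans are not valid device specifications")
    if isinstance(value, int):
        if value < 0:
            raise ValueError(f"device count must be non-negative, got {value}")
        return list(range(value))
    if isinstance(value, str):
        return [int(part) for part in value.split(",") if part.strip()]
    return [int(index) for index in value]


def parse_input(
    gpus: DeviceInput = None,
    tpu_cores: DeviceInput = None,
    ipus: Optional[int] = None,
) -> DeviceSelection:
    requested = {DeviceType.GPU: gpus, DeviceType.TPU: tpu_cores, DeviceType.IPU: ipus}
    # validate user inputs: at most one accelerator flag may be set
    chosen = {kind: value for kind, value in requested.items() if value is not None}
    if len(chosen) > 1:
        names = ", ".join(kind.value for kind in chosen)
        raise ValueError(f"only one device type may be requested, got: {names}")
    if not chosen:
        return DeviceSelection()  # nothing requested: default to CPU
    kind, value = next(iter(chosen.items()))
    # map the user's format to the standardized list of indices
    indices = _to_indices(value)
    if not indices:
        return DeviceSelection()  # e.g. gpus=0 falls back to CPU
    return DeviceSelection(devices=indices, type=kind)
```

With these rules, parse_input(gpus=2) is equal to DeviceSelection(devices=[0, 1], type=DeviceType.GPU), and passing both gpus and tpu_cores raises a ValueError. Supporting a new device type then means one more enum member and one more entry in the requested mapping, rather than another positional slot in a tuple.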
The AcceleratorConnector would then receive the DeviceSelection instance as input instead of a growing list of arguments. It currently takes: devices, gpus, gpu_ids, tpu_cores, ipus, num_processes.
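A hypothetical sketch of the narrower constructor: the class name AcceleratorConnectorSketch and its use_gpu and num_devices properties are invented for illustration and are not Lightning's API. It shows consumers reading named fields from a single DeviceSelection instead of taking six separate arguments; the DeviceType enum is again a local stand-in.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class DeviceType(str, Enum):
    # Hypothetical local stand-in for Lightning's device-type enum.
    CPU = "CPU"
    GPU = "GPU"
    TPU = "TPU"
    IPU = "IPU"


@dataclass
class DeviceSelection:
    devices: List[int] = field(default_factory=list)
    type: DeviceType = DeviceType.CPU


class AcceleratorConnectorSketch:
    """Hypothetical: one argument replaces devices, gpus, gpu_ids, tpu_cores, ipus, num_processes."""

    def __init__(self, selection: DeviceSelection) -> None:
        self.selection = selection

    @property
    def use_gpu(self) -> bool:
        # named field instead of guessing a positional tuple index
        return self.selection.type is DeviceType.GPU

    @property
    def num_devices(self) -> int:
        # assumed convention: an empty selection runs on a single device
        return len(self.selection.devices) or 1
```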
Additional context
Alternative to #10231
cc @justusschock @awaelchli @akihironitta @rohitgr7 @tchaton @Borda