Skip to content

Fabric: Enable `auto` for `devices` and `accelerator` CLI arguments #20913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
12 changes: 10 additions & 2 deletions src/lightning/fabric/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.device_parser import _parse_gpu_ids, _select_auto_accelerator_fabric
from lightning.fabric.utilities.distributed import _suggested_max_num_threads
from lightning.fabric.utilities.load import _load_distributed_checkpoint

Expand All @@ -34,7 +34,7 @@
_CLICK_AVAILABLE = RequirementCache("click")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")


def _get_supported_strategies() -> list[str]:
Expand Down Expand Up @@ -187,6 +187,14 @@ def _set_env_variables(args: Namespace) -> None:

def _get_num_processes(accelerator: str, devices: str) -> int:
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""

if accelerator == "auto" or accelerator is None:
accelerator = _select_auto_accelerator_fabric()
if devices == "auto":
if accelerator == "cuda" or accelerator == "mps" or accelerator == "cpu":
devices = "1"
else:
raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'")
if accelerator == "gpu":
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
elif accelerator == "cuda":
Expand Down
15 changes: 15 additions & 0 deletions src/lightning/fabric/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,18 @@ def _check_data_type(device_ids: object) -> None:
raise TypeError(f"{msg} a sequence of {type(id_).__name__}.")
elif type(device_ids) not in (int, str):
raise TypeError(f"{msg} {device_ids!r}.")


def _select_auto_accelerator_fabric() -> str:
    """Pick the accelerator name for the first available backend.

    Returns:
        One of ``"tpu"``, ``"mps"``, ``"cuda"``, or ``"cpu"``, probed in that
        priority order; ``"cpu"`` is the fallback when nothing else is available.

    """
    # Imported lazily (function scope) to avoid import-time side effects and
    # circular imports between fabric utilities and accelerator modules.
    from lightning.fabric.accelerators.cuda import CUDAAccelerator
    from lightning.fabric.accelerators.mps import MPSAccelerator
    from lightning.fabric.accelerators.xla import XLAAccelerator

    # Probe order matters: more specialized hardware wins over generic CUDA.
    prioritized_backends = (
        ("tpu", XLAAccelerator),
        ("mps", MPSAccelerator),
        ("cuda", CUDAAccelerator),
    )
    for name, accelerator_cls in prioritized_backends:
        if accelerator_cls.is_available():
            return name
    return "cpu"
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
SLURMEnvironment,
TorchElasticEnvironment,
)
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device, _select_auto_accelerator_fabric
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
from lightning.pytorch.accelerators import AcceleratorRegistry
from lightning.pytorch.accelerators.accelerator import Accelerator
Expand Down Expand Up @@ -332,18 +332,12 @@ def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str
@staticmethod
def _choose_auto_accelerator() -> str:
"""Choose the accelerator type (str) based on availability."""
# NOTE(review): this span is a unified diff rendered with its +/- markers
# stripped, so the REMOVED lines (the inline XLA/MPS/CUDA checks ending in
# `return "cpu"`) appear interleaved with the ADDED delegation on the last
# line. As rendered, the final `return _select_auto_accelerator_fabric()` is
# unreachable — in the actual post-change file only the HPU special-case plus
# that delegation should remain. Verify against the real diff before relying
# on this text.
if XLAAccelerator.is_available():
return "tpu"
# HPU support lives in the optional `lightning_habana` plugin, so it is
# checked here (PyTorch side) rather than in the shared fabric helper.
if _habana_available_and_importable():
from lightning_habana import HPUAccelerator

if HPUAccelerator.is_available():
return "hpu"
if MPSAccelerator.is_available():
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
return "cpu"
return _select_auto_accelerator_fabric()

@staticmethod
def _choose_gpu_accelerator_backend() -> str:
Expand Down
6 changes: 3 additions & 3 deletions tests/tests_fabric/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
assert "LT_PRECISION" not in os.environ


@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", "auto", pytest.param("mps", marks=RunIf(mps=True))])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
assert f"Invalid value for '--strategy': '{strategy}'" in ioerr.getvalue()


@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1", "auto"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
Expand All @@ -97,7 +97,7 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):


@RunIf(mps=True)
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
@pytest.mark.parametrize("accelerator", ["mps", "gpu", "auto"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
Expand Down
Loading