diff --git a/tests/tests_fabric/helpers/datasets.py b/tests/tests_fabric/helpers/datasets.py
new file mode 100644
index 0000000000000..211e1f36a9ab5
--- /dev/null
+++ b/tests/tests_fabric/helpers/datasets.py
@@ -0,0 +1,27 @@
+from typing import Iterator
+
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset, IterableDataset
+
+
+class RandomDataset(Dataset):
+    def __init__(self, size: int, length: int) -> None:
+        self.len = length
+        self.data = torch.randn(length, size)
+
+    def __getitem__(self, index: int) -> Tensor:
+        return self.data[index]
+
+    def __len__(self) -> int:
+        return self.len
+
+
+class RandomIterableDataset(IterableDataset):
+    def __init__(self, size: int, count: int) -> None:
+        self.count = count
+        self.size = size
+
+    def __iter__(self) -> Iterator[Tensor]:
+        for _ in range(self.count):
+            yield torch.randn(self.size)
diff --git a/tests/tests_fabric/helpers/models.py b/tests/tests_fabric/helpers/models.py
deleted file mode 100644
index c204fecc25702..0000000000000
--- a/tests/tests_fabric/helpers/models.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from typing import Any, Iterator
-
-import torch
-import torch.nn as nn
-from lightning.fabric import Fabric
-from torch import Tensor
-from torch.nn import Module
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader, Dataset, IterableDataset
-
-
-class RandomDataset(Dataset):
-    def __init__(self, size: int, length: int) -> None:
-        self.len = length
-        self.data = torch.randn(length, size)
-
-    def __getitem__(self, index: int) -> Tensor:
-        return self.data[index]
-
-    def __len__(self) -> int:
-        return self.len
-
-
-class RandomIterableDataset(IterableDataset):
-    def __init__(self, size: int, count: int) -> None:
-        self.count = count
-        self.size = size
-
-    def __iter__(self) -> Iterator[Tensor]:
-        for _ in range(self.count):
-            yield torch.randn(self.size)
-
-
-class BoringFabric(Fabric):
-    def get_model(self) -> Module:
-        return nn.Linear(32, 2)
-
-    def get_optimizer(self, module: Module) -> Optimizer:
-        return torch.optim.Adam(module.parameters(), lr=0.1)
-
-    def get_dataloader(self) -> DataLoader:
-        return DataLoader(RandomDataset(32, 64))
-
-    def step(self, model: Module, batch: Any) -> Tensor:
-        output = model(batch)
-        return torch.nn.functional.mse_loss(output, torch.ones_like(output))
-
-    def after_backward(self, model: Module, optimizer: Optimizer) -> None:
-        pass
-
-    def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:
-        pass
-
-    def run(self) -> None:
-        with self.init_module():
-            model = self.get_model()
-        optimizer = self.get_optimizer(model)
-        model, optimizer = self.setup(model, optimizer)
-
-        dataloader = self.get_dataloader()
-        dataloader = self.setup_dataloaders(dataloader)
-
-        self.model = model
-        self.optimizer = optimizer
-        self.dataloader = dataloader
-
-        model.train()
-
-        data_iter = iter(dataloader)
-        batch = next(data_iter)
-        loss = self.step(model, batch)
-        self.backward(loss)
-        self.after_backward(model, optimizer)
-        optimizer.step()
-        self.after_optimizer_step(model, optimizer)
-        optimizer.zero_grad()
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 6b97fde2e0a49..2476cd3961504 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -25,7 +25,7 @@
 from lightning.fabric.strategies import DeepSpeedStrategy
 from torch.utils.data import DataLoader
 
-from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
+from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.test_fabric import BoringModel
 
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 704a86a997bc5..42fc1ba399e69 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -19,6 +19,7 @@
 
 import pytest
 import torch
+import torch.nn as nn
 from lightning.fabric import Fabric
 from lightning.fabric.plugins import FSDPPrecision
 from lightning.fabric.strategies import FSDPStrategy
@@ -31,13 +32,51 @@
 from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
 from torch.distributed.fsdp.wrap import always_wrap_policy, wrap
 from torch.nn import Parameter
+from torch.utils.data import DataLoader
 
-from tests_fabric.helpers.models import BoringFabric
+from tests_fabric.helpers.datasets import RandomDataset
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.test_fabric import BoringModel
 
 
-class _MyFabric(BoringFabric):
+class BasicTrainer:
+    """Implements a basic training loop for the end-to-end tests."""
+
+    def __init__(self, fabric):
+        self.fabric = fabric
+        self.model = self.optimizer = self.dataloader = None
+
+    def get_model(self):
+        return nn.Linear(32, 2)
+
+    def step(self, model, batch):
+        output = model(batch)
+        return torch.nn.functional.mse_loss(output, torch.ones_like(output))
+
+    def run(self) -> None:
+        with self.fabric.init_module():
+            model = self.get_model()
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+        model, optimizer = self.fabric.setup(model, optimizer)
+
+        dataloader = DataLoader(RandomDataset(32, 64))
+        dataloader = self.fabric.setup_dataloaders(dataloader)
+
+        self.model = model
+        self.optimizer = optimizer
+        self.dataloader = dataloader
+
+        model.train()
+
+        data_iter = iter(dataloader)
+        batch = next(data_iter)
+        loss = self.step(model, batch)
+        self.fabric.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()
+
+
+class _Trainer(BasicTrainer):
     def get_model(self):
         model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
         self.num_wrapped = 4
@@ -48,7 +87,7 @@ def step(self, model, batch):
         assert len(wrapped_layers) == self.num_wrapped
         assert (self.num_wrapped == 4) == isinstance(model._forward_module, FullyShardedDataParallel)
 
-        precision = self._precision
+        precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)
         if precision.precision == "16-mixed":
             param_dtype = torch.float32
@@ -72,7 +111,7 @@
         return torch.nn.functional.mse_loss(output, torch.ones_like(output))
 
 
-class _MyFabricManualWrapping(_MyFabric):
+class _TrainerManualWrapping(_Trainer):
     def get_model(self):
         model = super().get_model()
         for i, layer in enumerate(model):
@@ -87,29 +126,39 @@ def get_model(self):
 @pytest.mark.parametrize("manual_wrapping", [True, False])
 def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
     """Test FSDP training, saving and loading with different wrapping and precision settings."""
-    fabric_cls = _MyFabricManualWrapping if manual_wrapping else _MyFabric
-    fabric = fabric_cls(
-        accelerator="cuda", strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy), devices=2, precision=precision
+    trainer_cls = _TrainerManualWrapping if manual_wrapping else _Trainer
+    fabric = Fabric(
+        accelerator="cuda",
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+        devices=2,
+        precision=precision,
     )
-    fabric.run()
+    fabric.launch()
+    trainer = trainer_cls(fabric)
+    trainer.run()
 
     checkpoint_path = fabric.broadcast(str(tmp_path / "fsdp-checkpoint"))
 
-    params_before = deepcopy(list(fabric.model.parameters()))
-    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
+    params_before = deepcopy(list(trainer.model.parameters()))
+    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
     fabric.save(checkpoint_path, state)
 
     assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
 
     # re-init all objects and resume
-    fabric = fabric_cls(
-        accelerator="cuda", strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy), devices=2, precision=precision
+    fabric = Fabric(
+        accelerator="cuda",
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+        devices=2,
+        precision=precision,
     )
-    fabric.run()
+    fabric.launch()
+    trainer = trainer_cls(fabric)
+    trainer.run()
 
     # check correctness with loaded state
-    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 0}
+    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 0}
     metadata = fabric.load(checkpoint_path, state)
-    for p0, p1 in zip(params_before, fabric.model.parameters()):
+    for p0, p1 in zip(params_before, trainer.model.parameters()):
         torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
 
     # check user data in state reloaded
@@ -117,12 +166,12 @@ def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
     assert not metadata
 
     # attempt to load a key not in the metadata checkpoint
-    state = {"model": fabric.model, "coconut": 11}
+    state = {"model": trainer.model, "coconut": 11}
     with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
         fabric.load(checkpoint_path, state)
 
     # `strict=False` ignores the missing key
-    state = {"model": fabric.model, "coconut": 11}
+    state = {"model": trainer.model, "coconut": 11}
     fabric.load(checkpoint_path, state, strict=False)
     assert state["coconut"] == 11
 
@@ -130,16 +179,18 @@
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
 def test_fsdp_save_full_state_dict(tmp_path):
     """Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
-    fabric = BoringFabric(
+    fabric = Fabric(
         accelerator="cuda",
         strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="full"),
         devices=2,
     )
-    fabric.run()
+    fabric.launch()
+    trainer = BasicTrainer(fabric)
+    trainer.run()
 
     checkpoint_path = Path(fabric.broadcast(str(tmp_path / "fsdp-checkpoint.pt")))
 
-    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
+    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
     fabric.save(checkpoint_path, state)
 
     checkpoint = torch.load(checkpoint_path)
@@ -147,58 +198,61 @@ def test_fsdp_save_full_state_dict(tmp_path):
     loaded_state_dict = checkpoint["model"]
 
     # assert the correct state model was saved
-    with FullyShardedDataParallel.summon_full_params(fabric.model):
-        state_dict = fabric.model.state_dict()
+    with FullyShardedDataParallel.summon_full_params(trainer.model):
+        state_dict = trainer.model.state_dict()
         assert set(loaded_state_dict.keys()) == set(state_dict.keys())
         for param_name in state_dict:
            assert torch.equal(loaded_state_dict[param_name], state_dict[param_name].cpu())
 
-    params_before = [p.cpu() for p in fabric.model.parameters()]
+    params_before = [p.cpu() for p in trainer.model.parameters()]
 
     # assert the correct optimizer state was saved
     optimizer_state_before = FullyShardedDataParallel.full_optim_state_dict(
-        fabric.model, fabric.optimizer, rank0_only=False
+        trainer.model, trainer.optimizer, rank0_only=False
     )
     assert set(checkpoint["optimizer"].keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
 
     # 1. verify the FSDP state can be loaded back into a FSDP model/strategy directly
-    fabric = BoringFabric(
+    fabric = Fabric(
         accelerator="cuda",
         strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
         devices=2,
     )
-    fabric.run()
-    metadata = fabric.load(checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
+    fabric.launch()
+    trainer = BasicTrainer(fabric)
+    trainer.run()
+    metadata = fabric.load(checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
     assert metadata == {"steps": 1}
 
-    with FullyShardedDataParallel.summon_full_params(fabric.model):
-        params_after = list(fabric.model.parameters())
+    with FullyShardedDataParallel.summon_full_params(trainer.model):
+        params_after = list(trainer.model.parameters())
         assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
 
     # assert the correct optimizer state was loaded
     optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
-        fabric.model, fabric.optimizer, rank0_only=False
+        trainer.model, trainer.optimizer, rank0_only=False
     )
     assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
     torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
     assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
 
     # run a step to verify the optimizer state is correct
-    fabric.run()
+    trainer.run()
 
     # 2. verify the FSDP state can be loaded back into a single-device model/strategy
-    fabric = BoringFabric(accelerator="cpu", devices=1)
-    fabric.run()
-    metadata = fabric.load(checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
+    fabric = Fabric(accelerator="cpu", devices=1)
+    trainer = BasicTrainer(fabric)
+    trainer.run()
+    metadata = fabric.load(checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
     assert metadata == {"steps": 1}
 
-    params_after = list(fabric.model.parameters())
+    params_after = list(trainer.model.parameters())
     assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))
 
     # get optimizer state after loading
     normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
-    fabric.save(normal_checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 2})
+    fabric.save(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 2})
     optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
     optimizer_state_after = FullyShardedDataParallel.rekey_optim_state_dict(
-        optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=fabric.model
+        optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=trainer.model
     )
 
     # assert the correct optimizer state was loaded
@@ -206,32 +260,34 @@
     torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
 
     # run a step to verify the optimizer state is correct
-    fabric.run()
+    trainer.run()
 
     # 3. verify that a single-device model/strategy states can be loaded into a FSDP model/strategy
-    fabric = BoringFabric(
+    fabric = Fabric(
         accelerator="cuda",
         strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
         devices=2,
     )
-    fabric.run()
-    metadata = fabric.load(normal_checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
+    fabric.launch()
+    trainer = BasicTrainer(fabric)
+    trainer.run()
+    metadata = fabric.load(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
     assert metadata == {"steps": 2}
 
-    with FullyShardedDataParallel.summon_full_params(fabric.model):
-        params_after = list(fabric.model.parameters())
+    with FullyShardedDataParallel.summon_full_params(trainer.model):
+        params_after = list(trainer.model.parameters())
         assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
 
     # assert the correct optimizer state was loaded
     optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
-        fabric.model, fabric.optimizer, rank0_only=False
+        trainer.model, trainer.optimizer, rank0_only=False
    )
     assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
     torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
     assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
 
     # run a step to verify the optimizer state is correct
-    fabric.run()
+    trainer.run()
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@@ -239,34 +295,37 @@ def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
     """Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-    fabric = BoringFabric(accelerator="cuda", devices=1)
+    fabric = Fabric(accelerator="cuda", devices=1)
     fabric.seed_everything(0)
-    fabric.run()
+    trainer = BasicTrainer(fabric)
+    trainer.run()
 
     # Save a full-state-dict checkpoint
     checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
-    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
+    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
     fabric.save(checkpoint_path, state)
 
     # Gather all weights and store a copy manually
-    with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
-        params_before = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
+    with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
+        params_before = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
 
     # Create a FSDP sharded model
-    fabric = BoringFabric(
+    fabric = Fabric(
         accelerator="cuda",
         strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
         devices=2,
     )
-    fabric.run()
+    fabric.launch()
+    trainer = BasicTrainer(fabric)
+    trainer.run()
 
-    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 44}
+    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 44}
     fabric.load(checkpoint_path, state)
     assert state["steps"] == 1
 
     # Gather all weights and compare
-    with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
-        params_after = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
+    with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
+        params_after = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
     assert torch.equal(params_before, params_after)
 
     # Create a raw state-dict checkpoint to test `Fabric.load_raw` too
@@ -276,12 +335,12 @@ def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
     torch.save(checkpoint["model"], raw_checkpoint_path)
     fabric.barrier()
 
-    fabric.run()
-    fabric.load_raw(raw_checkpoint_path, fabric.model)
+    trainer.run()
+    fabric.load_raw(raw_checkpoint_path, trainer.model)
 
     # Gather all weights and compare
-    with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
-        params_after = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
+    with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
+        params_after = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
     assert torch.equal(params_before, params_after)
 
 
@@ -429,9 +488,9 @@ def _run_setup_assertions(empty_init, expected_device):
 
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
 def test_fsdp_save_filter(tmp_path):
-    fabric = BoringFabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
+    fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
     fabric.launch()
-    model = fabric.get_model()
+    model = nn.Linear(32, 2)
     model = fabric.setup_module(model)
 
     tmp_path = Path(fabric.broadcast(str(tmp_path)))
diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py
index 45967a6a9af68..f711eb3470b45 100644
--- a/tests/tests_fabric/strategies/test_xla.py
+++ b/tests/tests_fabric/strategies/test_xla.py
@@ -25,7 +25,7 @@
 from lightning.fabric.utilities.seed import seed_everything
 from torch.utils.data import DataLoader
 
-from tests_fabric.helpers.models import RandomDataset
+from tests_fabric.helpers.datasets import RandomDataset
 from tests_fabric.helpers.runif import RunIf
 
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py
index b5e6adb32c78f..999b8473b28aa 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py
@@ -22,7 +22,7 @@
 from lightning.fabric.strategies import XLAFSDPStrategy
 from torch.utils.data import DataLoader
 
-from tests_fabric.helpers.models import RandomDataset
+from tests_fabric.helpers.datasets import RandomDataset
 from tests_fabric.helpers.runif import RunIf
 
diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py
index 6519591cb4788..656b9cac3d77e 100644
--- a/tests/tests_fabric/utilities/test_data.py
+++ b/tests/tests_fabric/utilities/test_data.py
@@ -25,7 +25,7 @@
 from torch import Tensor
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 
-from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
+from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset
 
 
 def test_has_iterable_dataset():
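
Reviewer note: the `BoringFabric` subclass is gone; the tests now compose a plain `Fabric` with the `BasicTrainer` helper defined in `test_fsdp_integration.py`. A minimal sketch of the new driving pattern, assembled from the calls in the diff above (the strategy and device arguments are illustrative, not prescriptive):

```python
# Sketch of the pattern the updated tests follow, not a new public API:
# a plain Fabric is launched first, then BasicTrainer runs one training
# step and exposes the model/optimizer for checkpointing.
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

from tests_fabric.strategies.test_fsdp_integration import BasicTrainer

fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(), devices=2)
fabric.launch()  # spawn/connect the processes before training

trainer = BasicTrainer(fabric)
trainer.run()  # one forward/backward/optimizer step over a single batch

# state now lives on the trainer, not on a Fabric subclass
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
fabric.save("fsdp-checkpoint", state)
```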