Skip to content

Commit

Permalink
Fix propagation of device and dtype properties in Lite modules (#10559)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Nov 16, 2021
1 parent af4af3d commit d50e169
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))


-
- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#10559](https://github.com/PyTorchLightning/pytorch-lightning/issues/10559))


-
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

Expand Down Expand Up @@ -64,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None:
)


class _LiteModule(nn.Module):
class _LiteModule(DeviceDtypeModuleMixin):
def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
"""The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
Expand Down
22 changes: 22 additions & 0 deletions tests/lite/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from tests.helpers.runif import RunIf
Expand Down Expand Up @@ -65,6 +66,27 @@ def check_autocast(forward_input):
assert out.dtype == input_type or out.dtype == torch.get_default_dtype()


@pytest.mark.parametrize(
"device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))]
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_lite_module_device_dtype_propagation(device, dtype):
"""Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""

class DeviceModule(DeviceDtypeModuleMixin):
pass

device_module = DeviceModule()
lite_module = _LiteModule(device_module, Mock())
lite_module.to(device)
assert device_module.device == device
assert lite_module.device == device

lite_module.to(dtype)
assert device_module.dtype == dtype
assert lite_module.dtype == dtype


def test_lite_dataloader_iterator():
"""Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
device placement)."""
Expand Down

0 comments on commit d50e169

Please sign in to comment.