Skip to content

Commit

Permalink
Resolve instantiation problem with init_meta_context (#10493)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Nov 15, 2021
1 parent ae71284 commit 1de3539
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 21 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))


- Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))


Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/core/mixins/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import torch
from torch.nn import Module

import pytorch_lightning as pl


class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
Expand Down Expand Up @@ -177,7 +179,9 @@ def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
# TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
# work when using `init_meta_context`.
if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
return
if device is not None:
module._device = device
Expand Down
15 changes: 13 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
Expand Down Expand Up @@ -1406,10 +1406,21 @@ def _call_setup_hook(self) -> None:

def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
materialize_module(self.lightning_module)
self._handle_meta_model()
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

def _handle_meta_model(self) -> None:
if not is_on_meta_device(self.lightning_module):
return

if isinstance(self.training_type_plugin, DDPSpawnPlugin):
raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")

materialize_module(self.lightning_module)
# the trainer reference is lost during materialization
self.lightning_module.trainer = proxy(self)

def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn

Expand Down
46 changes: 30 additions & 16 deletions pytorch_lightning/utilities/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
Expand Down Expand Up @@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:

# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


Expand Down Expand Up @@ -237,45 +237,52 @@ def _set_meta_device() -> None:

for subclass in get_all_subclasses(torch.nn.modules.module.Module):

if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
continue

# if a subclass has already been stored, we should use the cache
if str(subclass) in __STORAGE_META__:
# reset the class import package to its rightfull state.
# reset the class import package to its rightful state.
mods, subclass, meta_class = __STORAGE_META__[subclass]
for mod in mods:
setattr(mod, subclass.__name__, meta_class)
continue

class _IsinstanceMetaclass(type(subclass)):
def __instancecheck__(self, instance: Any) -> bool:
"""Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
return isinstance(instance, self.__bases__[0])

# Create a class subclassing current `subclass` overriding its new method.
# this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
# version of the current subclass module
class _MetaClass(subclass):
class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
@classmethod
@contextmanager
def instantiation_context(cls, materialize: bool):
def instantiation_context(cls):
_unset_meta_device(from_created=True)
yield
_set_meta_device_populated(from_created=True)

@classmethod
def materialize(cls, materialize_fn: Callable):
with cls.instantiation_context(materialize=True):
with cls.instantiation_context():
obj = materialize_fn()
return obj

@staticmethod
def add_subclasses(subclass):
"""This is used to unrol the instantion tree while creating the modules."""
__CREATED_MODULES__.add(subclass)
"""This is used to unroll the instantiation tree while creating the modules."""
# Don't store the LightningModule as skipped from the Meta process.
if subclass != pl.LightningModule:
__CREATED_MODULES__.add(subclass)
if subclass.__bases__[0] != torch.nn.modules.module.Module:
_MetaClass.add_subclasses(subclass.__bases__[0])
_MaterializerModule.add_subclasses(subclass.__bases__[0])

def __new__(cls, *args, **kwargs):
subclass = cls.__bases__[0]
cls.add_subclasses(subclass)
with cls.instantiation_context(materialize=False):
with cls.instantiation_context():
obj = init_meta(subclass, *args, **kwargs)

obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
Expand All @@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
# nn.Module class can be imported at different level and they all need to be mocked.
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
# Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
# needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
out = []
out.append(search(mod))
# needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule
out = [search(mod)]
for name in submodules[1:]:
mod = getattr(mod, name)
out.append(search(mod))
Expand All @@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
mods = [mod for mod in chain(*out) if mod]

# store the modules search so it doesn't have to be performed again for this class
__STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
__STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)

# replace all subclass by its meta form
for mod in mods:
setattr(mod, subclass.__name__, _MetaClass)
setattr(mod, subclass.__name__, _MaterializerModule)


@contextmanager
Expand All @@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
_set_meta_device()
yield
_unset_meta_device()


def is_on_meta_device(module: nn.Module) -> bool:
try:
param = next(module.parameters())
return param.device.type == "meta"
except StopIteration:
return False
9 changes: 7 additions & 2 deletions tests/utilities/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import nn

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
from tests.helpers.runif import RunIf


Expand All @@ -31,18 +31,23 @@ def __init__(self, num_layers: int):
self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)])


@RunIf(min_torch="1.10.0")
@RunIf(special=True, min_torch="1.10.0")
def test_init_meta_context():

with init_meta_context():
m = nn.Linear(in_features=1, out_features=1)
assert isinstance(m, nn.Linear)
assert m.weight.device.type == "meta"
assert is_on_meta_device(m)
mlp = MLP(4)
assert mlp.layer[0].weight.device.type == "meta"

mlp = materialize_module(mlp)
assert mlp.layer[0].weight.device.type == "cpu"

assert not is_on_meta_device(mlp)
assert not is_on_meta_device(nn.Module())

model = BoringModel(4)
assert model.layer[0].weight.device.type == "meta"
materialize_module(model)
Expand Down

0 comments on commit 1de3539

Please sign in to comment.