Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve instantiation problem with init_meta_context #10493

Merged
merged 29 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


- Fix `isinstance` not working with `init_meta_context`, materialize model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))
tchaton marked this conversation as resolved.
Show resolved Hide resolved


- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))


Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/core/mixins/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import torch
from torch.nn import Module

import pytorch_lightning as pl


class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
Expand Down Expand Up @@ -177,7 +179,9 @@ def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
# TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
tchaton marked this conversation as resolved.
Show resolved Hide resolved
# work when using `init_meta_context`.
if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
return
if device is not None:
module._device = device
Expand Down
15 changes: 13 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.meta import is_meta_device, materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
Expand Down Expand Up @@ -1404,10 +1404,21 @@ def _call_setup_hook(self) -> None:

def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
materialize_module(self.lightning_module)
self._handle_meta_model()
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

def _handle_meta_model(self) -> None:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if not is_meta_device(self.lightning_module):
return

if isinstance(self.training_type_plugin, DDPSpawnPlugin):
raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
tchaton marked this conversation as resolved.
Show resolved Hide resolved

materialize_module(self.lightning_module)
# the trainer reference is lost during materialization
self.lightning_module.trainer = proxy(self)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn

Expand Down
46 changes: 30 additions & 16 deletions pytorch_lightning/utilities/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
Expand Down Expand Up @@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:

# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


Expand Down Expand Up @@ -237,45 +237,52 @@ def _set_meta_device() -> None:

for subclass in get_all_subclasses(torch.nn.modules.module.Module):

if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
continue

# if a subclass has already been stored, we should use the cache
if str(subclass) in __STORAGE_META__:
# reset the class import package to its rightfull state.
# reset the class import package to its rightful state.
mods, subclass, meta_class = __STORAGE_META__[subclass]
for mod in mods:
setattr(mod, subclass.__name__, meta_class)
continue

class _IsinstanceMetaclass(type(subclass)):
def __instancecheck__(self, instance: Any) -> bool:
"""Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
return isinstance(instance, self.__bases__[0])

# Create a class subclassing current `subclass` overriding its new method.
# this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
# version of the current subclass module
class _MetaClass(subclass):
class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
@classmethod
@contextmanager
def instantiation_context(cls, materialize: bool):
def instantiation_context(cls):
_unset_meta_device(from_created=True)
yield
_set_meta_device_populated(from_created=True)

@classmethod
def materialize(cls, materialize_fn: Callable):
with cls.instantiation_context(materialize=True):
with cls.instantiation_context():
obj = materialize_fn()
return obj

@staticmethod
def add_subclasses(subclass):
"""This is used to unrol the instantion tree while creating the modules."""
__CREATED_MODULES__.add(subclass)
"""This is used to unroll the instantiation tree while creating the modules."""
# Don't store the LightningModule as skipped from the Meta process.
if subclass != pl.LightningModule:
__CREATED_MODULES__.add(subclass)
if subclass.__bases__[0] != torch.nn.modules.module.Module:
_MetaClass.add_subclasses(subclass.__bases__[0])
_MaterializerModule.add_subclasses(subclass.__bases__[0])

def __new__(cls, *args, **kwargs):
subclass = cls.__bases__[0]
cls.add_subclasses(subclass)
with cls.instantiation_context(materialize=False):
with cls.instantiation_context():
obj = init_meta(subclass, *args, **kwargs)

obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
Expand All @@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
# nn.Module class can be imported at different level and they all need to be mocked.
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
# Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
# needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
out = []
out.append(search(mod))
# needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule
out = [search(mod)]
for name in submodules[1:]:
mod = getattr(mod, name)
out.append(search(mod))
Expand All @@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
mods = [mod for mod in chain(*out) if mod]

# store the modules search so it doesn't have to be performed again for this class
__STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
__STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)

# replace all subclass by its meta form
for mod in mods:
setattr(mod, subclass.__name__, _MetaClass)
setattr(mod, subclass.__name__, _MaterializerModule)


@contextmanager
Expand All @@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
_set_meta_device()
yield
_unset_meta_device()


def is_meta_device(module: nn.Module) -> bool:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
try:
param = next(module.parameters())
return param.device.type == "meta"
except StopIteration:
return False
7 changes: 6 additions & 1 deletion tests/utilities/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import nn

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
from pytorch_lightning.utilities.meta import init_meta_context, is_meta_device, materialize_module
from tests.helpers.runif import RunIf


Expand All @@ -36,13 +36,18 @@ def test_init_meta_context():

with init_meta_context():
m = nn.Linear(in_features=1, out_features=1)
assert isinstance(m, nn.Linear)
assert m.weight.device.type == "meta"
assert is_meta_device(m)
mlp = MLP(4)
assert mlp.layer[0].weight.device.type == "meta"

mlp = materialize_module(mlp)
assert mlp.layer[0].weight.device.type == "cpu"

assert not is_meta_device(mlp)
assert not is_meta_device(nn.Module())

model = BoringModel(4)
assert model.layer[0].weight.device.type == "meta"
materialize_module(model)
Expand Down