From ac22cc358739cb408f1fbea625571d490a838a7e Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 4 Mar 2021 12:42:00 -0800 Subject: [PATCH] Parametrization Functionality (#33344) Summary: Provides the implementation for the feature requested in issue https://github.com/pytorch/pytorch/issues/28937. Adds the `Parametrization` functionality and implements `Pruning` on top of it. It adds the `auto` mode, in which the parametrization is computed just once per forward pass. The previous implementation computed the pruning on every forward, which is not optimal when pruning RNNs, for example. It implements a caching mechanism for parameters. This is implemented through the mechanism proposed at the end of the discussion in https://github.com/pytorch/pytorch/issues/7313. In particular, it assumes that the user will not manually change the updated parameters between the call to `backward()` and `optimizer.step()`. If they do so, they need to manually call the `.invalidate()` function provided in the implementation. This could be made into a function that takes a model and invalidates all the parameters in it. It might be the case that this function has to be called from `.cuda()`, `.to()`, and related functions. As described in https://github.com/pytorch/pytorch/issues/7313, this could be used to implement the `weight_norm` and `spectral_norm` functions in a cleaner way. It also allows, as described in https://github.com/pytorch/pytorch/issues/28937, for the implementation of constrained optimization on manifolds (e.g. orthogonal constraints, positive definite matrices, invertible matrices, weights on the sphere or the hyperbolic space...) TODO (when implementation is validated): - More thorough tests - Documentation Resolves https://github.com/pytorch/pytorch/issues/28937 albanD Pull Request resolved: https://github.com/pytorch/pytorch/pull/33344 Reviewed By: zhangguanheng66 Differential Revision: D26816708 Pulled By: albanD fbshipit-source-id: 07c8f0da661f74e919767eae31335a9c60d9e8fe --- docs/source/nn.rst | 10 + test/test_nn.py | 349 ++++++++++++++++++++++++++++++ torch/nn/utils/parametrize.py | 364 ++++++++++++++++++++++++++++++++++ 3 files changed, 723 insertions(+) create mode 100644 torch/nn/utils/parametrize.py diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 21dc99dc430a5..f2fd12284f0f2 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -346,11 +346,21 @@ From the ``torch.nn.utils`` module parameters_to_vector vector_to_parameters +.. autosummary:: + :toctree: generated + :nosignatures: + + parametrize.register_parametrization + parametrize.remove_parametrizations + parametrize.cached + parametrize.is_parametrized + .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst + parametrize.ParametrizationList prune.BasePruningMethod ..
autosummary:: diff --git a/test/test_nn.py b/test/test_nn.py index 065b7cd8a238a..4e07fd9fabff9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -27,6 +27,7 @@ import torch.nn.init as init import torch.nn.utils.rnn as rnn_utils from torch.nn.utils import clip_grad_norm_, clip_grad_value_ +import torch.nn.utils.parametrize as parametrize import torch.nn.utils.prune as prune from torch.nn.utils import parameters_to_vector, vector_to_parameters from torch.nn import Parameter @@ -1939,6 +1940,354 @@ def test_vector_to_parameters(self): sample = next(model.parameters())[0, 0, 0] self.assertTrue(torch.equal(sample.data, vec.data[:5])) + # torch/nn/utils/parametrize + def test_register_and_remove_parametrization(self): + r"""Test that it is possible to add a few parametrizations + on a parameter or a buffer and that removing them restores the initial state + It also tests that backpropagating through them works as expected + """ + # Define a couple matrix parametrizations + class Skew(nn.Module): + def forward(self, X): + X = X.tril(-1) + return X - X.T + + class Orthogonal(nn.Module): + def forward(self, X): + # Cayley map + # If X is skew-symmetric it returns an orthogonal matrix + Id = torch.eye(X.size(0), device=X.device) + return torch.solve(Id - X, Id + X).solution + + # Define a couple vector parametrizations + class FirstZero(nn.Module): + def forward(self, x): + return torch.cat([x.new_zeros(1), x[1:]]) + + class LastZero(nn.Module): + def forward(self, x): + return torch.cat([x[:-1], x.new_zeros(1)]) + + model = nn.Linear(8, 8) + initial_weight_id = id(model.weight) + initial_bias_id = id(model.bias) + initial_model = deepcopy(model) + + # Test one parametrization + parametrize.register_parametrization(model, "weight", Skew()) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertNotIn("weight", model._parameters) + # Result should be skew-symmetric + A = model.weight + self.assertTrue(torch.allclose(A, -A.T)) + # Remove and check consistency + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.weight, initial_model.weight) + self.assertEqual(id(model.weight), initial_weight_id) + self.assertEqual(model.__class__, nn.Linear) + + # Test two parametrizations at the same time and removing them + parametrize.register_parametrization(model, "weight", Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + # Result should be orthogonal + X = model.weight + Id = torch.eye(X.size(0), device=X.device) + self.assertTrue(torch.allclose(X.T @ X, Id)) + # Structure tests + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertIn("weight", model.parametrizations) + self.assertNotIn("weight", model._parameters) + # Remove + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertEqual(model.weight, initial_model.weight) + self.assertEqual(id(model.weight), initial_weight_id) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.__class__, nn.Linear) + + # Add everything + parametrize.register_parametrization(model, "weight", 
Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + parametrize.register_parametrization(model, "bias", FirstZero()) + parametrize.register_parametrization(model, "bias", LastZero()) + + # Basic tests + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertTrue(parametrize.is_parametrized(model, "bias")) + self.assertEqual(model.bias[0].item(), 0.) + self.assertEqual(model.bias[-1].item(), 0.) + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened + # Should not throw + (model.weight.T @ model.bias).sum().backward() + with torch.no_grad(): + for p in model.parameters(): + p.add_(- p.grad, alpha=0.01) + + # Remove the first parametrization. + # Check that the model is still parametrized and so is the second parameter + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertTrue(parametrize.is_parametrized(model)) # Still parametrized + self.assertFalse(parametrize.is_parametrized(model, "weight")) # Parametrization removed + self.assertTrue(parametrize.is_parametrized(model, "bias")) # Still parametrized + self.assertEqual(model.bias[0].item(), 0.) # Still parametrized + self.assertEqual(model.bias[-1].item(), 0.) # Still parametrized + self.assertNotEqual(model.weight, initial_model.weight) # Has been updated + self.assertEqual(id(model.weight), initial_weight_id) # Keeps the same id + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened + # Should not throw + (model.weight.T @ model.bias).sum().backward() + with torch.no_grad(): + for p in model.parameters(): + p.add_(- p.grad, alpha=0.01) + + # Remove the second parametrization. + # Check that the module is not parametrized + parametrize.remove_parametrizations(model, "bias", leave_parametrized=False) + self.assertFalse(parametrize.is_parametrized(model)) # Not parametrized anymore + self.assertNotEqual(model.bias, initial_model.bias) # Has been updated + self.assertNotEqual(model.bias[0].item(), 0.) # No longer forced to zero + self.assertNotEqual(model.bias[-1].item(), 0.) # No longer forced to zero + self.assertEqual(id(model.bias), initial_bias_id) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.__class__, nn.Linear) + self.assertEqual(len(list(model.parameters())), 2) + # Should not throw + (model.weight.T @ model.bias).sum().backward() + with torch.no_grad(): + for p in model.parameters(): + p.add_(- p.grad, alpha=0.01) + + def test_register_and_remove_buffer_parametrization(self): + r"""Test that it is possible to add and remove parametrizations on buffers""" + # Define a couple vector parametrizations + class FirstZero(nn.Module): + def forward(self, x): + return torch.cat([x.new_zeros(1), x[1:]]) + + class LastZero(nn.Module): + def forward(self, x): + return torch.cat([x[:-1], x.new_zeros(1)]) + + model = nn.Linear(8, 8) + + # Instantiate parametrizations on buffers. It should work as expected + delattr(model, "bias") + model.register_buffer("bias", torch.ones(8)) + parametrize.register_parametrization(model, "bias", FirstZero()) + parametrize.register_parametrization(model, "bias", LastZero()) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "bias")) + self.assertEqual(model.bias[0].item(), 0.) + self.assertEqual(model.bias[-1].item(), 0.)
+ self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) + self.assertEqual(len(list(model.parameters())), 1) + + # Remove parametrizations on buffers. It should work as expected + parametrize.remove_parametrizations(model, "bias", leave_parametrized=True) + self.assertFalse(parametrize.is_parametrized(model)) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertEqual(model.bias[0].item(), 0.) + self.assertEqual(model.bias[-1].item(), 0.) + self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) + self.assertEqual(len(list(model.parameters())), 1) + + def test_serialization_parametrization(self): + r"""Test that it is possible to serialize a parametrized model via state_dict""" + # A stateful parametrization + class Orthogonal(nn.Module): + def __init__(self, n): + super().__init__() + self.register_buffer("id", torch.eye(n)) + self.register_buffer("B", torch.empty(n, n)) + init.orthogonal_(self.B) + + def forward(self, X): + A = X.triu(1) + A = A - A.T + return self.B @ torch.solve(self.id - A, self.id + A).solution + + def get_model(): + model = torch.nn.Sequential( + torch.nn.Linear(5, 5), + torch.nn.ReLU(), + torch.nn.Linear(5, 1), + ) + + parametrize.register_parametrization(model[0], "weight", Orthogonal(5)) + return model + + model = get_model() + + prev_weight = model[0].weight + prev_B = model[0].parametrizations.weight[0].B + + new_model = get_model() + with TemporaryFileName() as fname: + torch.save(model.state_dict(), fname) + new_model.load_state_dict(torch.load(fname)) + + # Integrity tests + self.assertTrue(parametrize.is_parametrized(new_model[0], "weight")) + self.assertEqual(prev_weight, new_model[0].weight) + self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B) + + # Trying to save the whole parametrized model raises + with self.assertRaisesRegex(RuntimeError, "state_dict"): + with TemporaryFileName() as fname: + torch.save(model, fname) + + def test_initialization_parametrization(self): + r"""Test that it is possible to initialize a parametrization when it + implements a `right_inverse` method + """ + class Skew(nn.Module): + def forward(self, X): + A = X.triu(1) + return A - A.T + + def is_skew(self, A): + return torch.allclose(A, -A.T, atol=1e-6) + + def right_inverse(self, X): + if not self.is_skew(X): + raise ValueError("The matrix is not skew-symmetric.") + return X.triu(1) + + # Implements a Cayley map where right_inverse is not quite the inverse of forward + class Orthogonal(nn.Module): + def __init__(self, n): + super().__init__() + self.register_buffer("B", torch.eye(n)) + + def forward(self, A): + Id = torch.eye(A.size(0)) + return self.B @ torch.solve(Id - A, Id + A).solution + + def is_orthogonal(self, X): + Id = torch.eye(X.size(0)) + return torch.allclose(X.T @ X, Id, atol=1e-4) + + def right_inverse(self, X): + if not self.is_orthogonal(X): + raise ValueError("The input is not orthogonal.") + # cayley(0) == Id, so B @ cayley(0) == B + self.B = X + return torch.zeros_like(X) + + N = 5 + model = nn.Linear(N, N) + # Register the skew-symmetric constraint.
The result is now skew-symmetric + parametrize.register_parametrization(model, "weight", Skew()) + X = torch.rand(N, N) + # X is not skew-symmetric, so it throws an error + with self.assertRaises(ValueError): + model.weight = X + # Make X skew-symmetric + X = X - X.T + model.weight = X + self.assertEqual(model.parametrizations.weight.original, X.triu(1)) + self.assertEqual(model.weight, X) + + # Having several parametrizations registered should work in the same way + parametrize.register_parametrization(model, "weight", Orthogonal(N)) + # Register now the Cayley map. The result is now orthogonal + X = torch.rand(N, N) + # X is not orthogonal, so it throws an error + with self.assertRaises(ValueError): + model.weight = X + init.orthogonal_(X) + model.weight = X + self.assertEqual(model.weight, X) + self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X)) + + def test_errors_parametrization(self): + # A parametrization shall not change the size of the parameter + class ChangeSize(nn.Module): + def forward(self, x): + return x[:-1] + + # A simple parametrization that does not implement a right_inverse + class Double(nn.Module): + def forward(self, x): + return 2 * x + + module = nn.Linear(3, 4) + # This should not throw when registering + parametrize.register_parametrization(module, "weight", ChangeSize()) + # It throws in the forward + with self.assertRaisesRegex(RuntimeError, "may not change the size"): + module(torch.rand(2)) + # Undo + parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) + self.assertFalse(parametrize.is_parametrized(module)) + + # Removing a parametrization from an unparametrized tensor throws + with self.assertRaisesRegex(ValueError, "does not have a parametrization"): + parametrize.remove_parametrizations(module, "bias") + # Nothing odd happens + self.assertFalse(parametrize.is_parametrized(module)) + + # Register a parametrization on a non-existing parameter breaks + with self.assertRaisesRegex(ValueError, "does not have a parameter"): + parametrize.register_parametrization(module, "foo", ChangeSize()) + self.assertFalse(parametrize.is_parametrized(module)) + + # Try to assign to a parametrization that does not implement `right_inverse` + parametrize.register_parametrization(module, "weight", Double()) + with self.assertRaisesRegex(RuntimeError, "right_inverse"): + module.weight = torch.rand(4, 3) + # Undo + parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) + self.assertFalse(parametrize.is_parametrized(module)) + + def test_caching_parametrization(self): + r"""Test the caching system of a parametrization""" + # Define a couple matrix parametrizations + class Skew(nn.Module): + def forward(self, X): + X = X.tril(-1) + return X - X.T + + class Orthogonal(nn.Module): + def forward(self, X): + Id = torch.eye(X.size(0), device=X.device) + return torch.solve(Id - X, Id + X).solution + + model = nn.Linear(5, 5) + parametrize.register_parametrization(model, "weight", Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + + # Test that the caching system works + with parametrize.cached(): + X = model.weight + Y = model.weight + self.assertEqual(id(X), id(Y)) + + def test_dtype_parametrization(self): + r"""Test a case that is not allowed when removing a parametrization""" + class ChangeType(nn.Module): + def forward(self, X): + return X.double() + + module = nn.Linear(4, 4).float() + input_ = torch.rand(4).double() + # It is allowed to register a parametrization that changes 
the dtype + parametrize.register_parametrization(module, "weight", ChangeType()) + module(input_) + # We can remove it leaving the original tensor + parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) + # But leaving it parametrized breaks + parametrize.register_parametrization(module, "weight", ChangeType()) + with self.assertRaisesRegex(ValueError, "changes the dtype"): + parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) + # torch/nn/utils/prune.py @unittest.skipIf(not TEST_NUMPY, "numpy not found") def test_validate_pruning_amount_init(self): diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py new file mode 100644 index 0000000000000..b90404d4b6945 --- /dev/null +++ b/torch/nn/utils/parametrize.py @@ -0,0 +1,364 @@ +import torch +from torch.nn.modules.container import ModuleList, ModuleDict, Module +from torch.nn.parameter import Parameter +from torch import Tensor +from typing import Union, Optional, Iterable, Dict, Tuple +from contextlib import contextmanager + + +_cache_enabled = 0 +_cache: Dict[Tuple[int, str], Optional[Tensor]] = {} + + +@contextmanager +def cached(): + r"""Context manager that enables the caching system within parametrizations + registered with :func:`register_parametrization`. + The value of the parametrized objects is computed and cached the first time + they are required when this context manager is active. The cached values are + discarded when leaving the context manager. + This is useful when using a parametrized parameter more than once in the forward pass. + An example of this is when parametrizing the recurrent kernel of an RNN or when + sharing weights. + The simplest way to activate the cache is by wrapping the forward pass of the neural network + + .. code-block:: python + + import torch.nn.utils.parametrize as P + ... + with P.cached(): + output = model(inputs) + + in training and evaluation. One may also wrap the parts of the modules that use + several times the parametrized tensors. For example, the loop of an RNN with a + parametrized recurrent kernel: + + .. code-block:: python + + with P.cached(): + for x in xs: + out_rnn = self.rnn_cell(x, out_rnn) + """ + global _cache + global _cache_enabled + _cache_enabled += 1 + try: + yield + finally: + _cache_enabled -= 1 + if not _cache_enabled: + _cache = {} + + +class ParametrizationList(ModuleList): + r"""A sequential container that holds and manages the ``original`` parameter of + a parametrized :class:`~nn.Parameter` or buffer. It is the type of + ``module.parametrizations[tensor_name]`` when ``tensor_name`` has been parametrized + with :func:`register_parametrization` + + .. note :: + This class is used internally by :func:`register_parametrization`. It is documented + here for completeness. It should not be instantiated by the user. + + Args: + modules (iterable): an iterable of modules representing the parametrizations + original (Parameter or Tensor): parameter or buffer that is parametrized + """ + original: Tensor + + def __init__( + self, modules: Iterable[Module], original: Union[Tensor, Parameter] + ) -> None: + super().__init__(modules) + if isinstance(original, Parameter): + self.register_parameter("original", original) + else: + self.register_buffer("original", original) + + def set_original_(self, value: Tensor) -> None: + r"""This method is called when assigning to a parametrized tensor. 
+ It calls the ``right_inverse`` methods (see :func:`register_parametrization`) + of the parametrizations in the reverse order in which they were registered. + Then, it assigns the result to ``self.original``. + + Args: + value (Tensor): value used to initialize the ``original`` tensor + + Raises: + RuntimeError: if any of the parametrizations does not implement a ``right_inverse`` method + """ + # See https://github.com/pytorch/pytorch/issues/53103 + for module in reversed(self): # type: ignore + if not hasattr(module, "right_inverse"): + raise RuntimeError( + "The parametrization '{}' does not implement a 'right_inverse' method. " + "Assigning to a parametrized tensor is only possible when all the parametrizations " + "implement a 'right_inverse' method.".format( + module.__class__.__name__ + ) + ) + + with torch.no_grad(): + # See https://github.com/pytorch/pytorch/issues/53103 + for module in reversed(self): # type: ignore + value = module.right_inverse(value) + self.original.copy_(value) + + def forward(self) -> Tensor: + x = self.original + for module in self: + x = module(x) + if x.size() != self.original.size(): + raise RuntimeError( + "The parametrization may not change the size of the parametrized tensor. " + "Size of original tensor: {} " + "Size of parametrized tensor: {}".format(self.original.size(), x.size()) + ) + return x + + +def _inject_new_class(module: Module) -> None: + r"""Sets up the parametrization mechanism used by parametrizations. + This works by substituting the class of the module with a class + that extends it, so that a property can be injected + + Args: + module (nn.Module): module into which to inject the property + """ + cls = module.__class__ + + def getstate(self): + raise RuntimeError( + "Serialization of parametrized modules is only " + "supported through state_dict(). See:\n" + "https://pytorch.org/tutorials/beginner/saving_loading_models.html" + "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" + ) + + param_cls = type( + "Parametrized{}".format(cls.__name__), + (cls,), + { + "__getstate__": getstate, + }, + ) + + module.__class__ = param_cls + + +def _inject_property(module: Module, tensor_name: str) -> None: + r"""Injects a property into module[tensor_name]. + It assumes that the class in the module has already been modified from its + original one using _inject_new_class and that the tensor under `tensor_name` + has already been moved out + + Args: + module (nn.Module): module into which to inject the property + tensor_name (str): name of the property to create + """ + # We check the precondition. + # This should never fire if register_parametrization is correctly implemented + assert not hasattr(module, tensor_name) + + def get_parametrized(self) -> Tensor: + global _cache + + parametrization = self.parametrizations[tensor_name] + if _cache_enabled: + key = (id(module), tensor_name) + tensor = _cache.get(key) + if tensor is None: + tensor = parametrization() + _cache[key] = tensor + return tensor + else: + # If caching is not active, this function just evaluates the parametrization + return parametrization() + + def set_original(self, value: Tensor) -> None: + self.parametrizations[tensor_name].set_original_(value) + + setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) + + +def register_parametrization( + module: Module, tensor_name: str, parametrization: Module +) -> Module: + r"""Adds a parametrization to a tensor in a module.
+ When accessing ``module[tensor_name]``, the module will return the + parametrized version ``parametrization(module[tensor_name])``. The backward + pass will differentiate through the ``parametrization``, and if the original + tensor is a :class:`~Parameter`, it will be updated accordingly by the optimizer. + The first time that a parametrization is registered on a module, this function will add an attribute + ``parametrizations`` to the module. The parametrizations on a tensor will be accessible under + ``module.parametrizations[tensor_name]``, which is of type :class:`~ParametrizationList`. + The original tensor will be accessible under + ``module.parametrizations[tensor_name].original``. + Parametrizations may be composed by registering several parametrizations + on the same attribute. + Parametrized parameters and buffers have a built-in caching system that can be activated + using :func:`cached`. + A ``parametrization`` may optionally implement a method with signature + + .. code-block:: python + + def right_inverse(self, X: Tensor) -> Tensor + + If this method is implemented, it will be possible to assign to the parametrized tensor. + This may be used to initialize the tensor: + + >>> import torch + >>> import torch.nn.utils.parametrize as P + >>> + >>> class Symmetric(torch.nn.Module): + >>> def forward(self, X): + >>> return X.triu() + X.triu(1).T # Return a symmetric matrix + >>> + >>> def right_inverse(self, A): + >>> return A.triu() + >>> + >>> m = torch.nn.Linear(5, 5) + >>> P.register_parametrization(m, "weight", Symmetric()) + >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric + True + >>> A = torch.rand(5, 5) + >>> A = A + A.T # A is now symmetric + >>> m.weight = A # Initialize the weight to be the symmetric matrix A + >>> print(torch.allclose(m.weight, A)) + True + + In most situations, ``right_inverse`` will be a function such that + ``forward(right_inverse(X)) == X``, that is, ``right_inverse`` will be a right inverse of ``forward``. + Sometimes, when the parametrization is not surjective, it may be reasonable + to relax this, as we did with ``Symmetric`` in the example above.
+ + Args: + module (nn.Module): module on which to register the parametrization + tensor_name (str): name of the parameter or buffer on which to register + the parametrization + parametrization (nn.Module): the parametrization to register + + Returns: + Module: module + + Raises: + ValueError: if the module does not have a parameter or a buffer named ``tensor_name`` + """ + if is_parametrized(module, tensor_name): + # Just add the new parametrization to the parametrization list + module.parametrizations[tensor_name].append(parametrization) # type: ignore + elif tensor_name in module._buffers or tensor_name in module._parameters: + # Set the parametrization mechanism + # Fetch the original buffer or parameter + original = getattr(module, tensor_name) + # Delete the previous parameter or buffer + delattr(module, tensor_name) + # If this is the first parametrization of a buffer or parameter of the module, + # we prepare the module to inject the property + if not is_parametrized(module): + # Change the class + _inject_new_class(module) + # Inject a ``ModuleDict`` into the instance under module.parametrizations + module.parametrizations = ModuleDict() + # Add a property into the class + _inject_property(module, tensor_name) + # Add a ParametrizationList + module.parametrizations[tensor_name] = ParametrizationList( # type: ignore + [parametrization], original + ) + else: + raise ValueError( + "Module '{}' does not have a parameter, a buffer, nor a " + "parametrized element with name '{}'".format(module, tensor_name) + ) + return module + + +def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: + r"""Returns ``True`` if module has an active parametrization. + If the argument ``tensor_name`` is specified, it returns ``True`` if + ``module[tensor_name]`` is parametrized. + + Args: + module (nn.Module): module to query + tensor_name (str, optional): attribute in the module to query + Default: ``None`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 + else: + return tensor_name in parametrizations + + +def remove_parametrizations( + module: Module, tensor_name: str, leave_parametrized: bool = True +) -> Module: + r"""Removes the parametrizations on a tensor in a module. + + - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to + its current output. In this case, the parametrization shall not change the ``dtype`` + of the tensor. + - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to + the unparametrized tensor in ``module.parametrizations[tensor_name].original``. + + Args: + module (nn.Module): module from which to remove the parametrization + tensor_name (str): name of the parametrized tensor + leave_parametrized (bool, optional): leave the attribute ``tensor_name`` parametrized.
+ Default: ``True`` + + Returns: + Module: module + + Raises: + ValueError: if ``module[tensor_name]`` is not parametrized + ValueError: if ``leave_parametrized=True`` and the parametrization changes the ``dtype`` of the tensor + """ + + if not is_parametrized(module, tensor_name): + raise ValueError( + "Module {} does not have a parametrization on {}".format( + module, tensor_name + ) + ) + + # Fetch the original tensor + original = module.parametrizations[tensor_name].original # type: ignore + if leave_parametrized: + t = getattr(module, tensor_name) + # If they have the same dtype, we reuse the original tensor. + # We do this so that the parameter does not change its id(). + # This way the user does not need to update the optimizer + if t.dtype == original.dtype: + original.set_(t) + else: + raise ValueError( + "The parametrization changes the dtype of the tensor from {} to {}. " + "It is not supported to leave the tensor parametrized (`leave_parametrized=True`) " + "in this case.".format(original.dtype, t.dtype) + ) + # Delete the property that manages the parametrization + delattr(module.__class__, tensor_name) + # Delete the ParametrizationList + del module.parametrizations[tensor_name] # type: ignore + + # Restore the parameter / buffer into the main class + if isinstance(original, Parameter): + module.register_parameter(tensor_name, original) + else: + module.register_buffer(tensor_name, original) + + # Roll back the parametrized class if no other buffer or parameter + # is currently parametrized in this class + if not is_parametrized(module): + delattr(module, "parametrizations") + # Restore class + orig_cls = module.__class__.__bases__[0] + module.__class__ = orig_cls + return module
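Editor's note: the snippet below is a minimal end-to-end usage sketch of the API added by this patch (``register_parametrization``, ``cached``, assignment through ``right_inverse``, and ``remove_parametrizations``), assuming the module is importable as ``torch.nn.utils.parametrize`` as introduced above. The ``Symmetric`` parametrization mirrors the one in the ``register_parametrization`` docstring.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Symmetric(nn.Module):
        def forward(self, X):
            # Build a symmetric matrix from the upper-triangular part of X
            return X.triu() + X.triu(1).T

        def right_inverse(self, A):
            # Enables the assignment `layer.weight = A` for an already-symmetric A
            return A.triu()

    layer = nn.Linear(4, 4)
    parametrize.register_parametrization(layer, "weight", Symmetric())
    assert parametrize.is_parametrized(layer, "weight")
    assert torch.allclose(layer.weight, layer.weight.T)

    # Outside `cached()` every access to `layer.weight` re-evaluates the
    # parametrization; inside the context it is computed once and reused.
    with parametrize.cached():
        out = layer(torch.rand(2, 4)) + layer(torch.rand(2, 4))

    # Initialize through `right_inverse`, then remove the parametrization
    # while keeping the current (symmetric) value of the weight.
    A = torch.rand(4, 4)
    layer.weight = A + A.T
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
    assert not parametrize.is_parametrized(layer)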
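The summary also notes that ``weight_norm`` and ``spectral_norm`` could be reimplemented on top of this API. A rough sketch of the ``weight_norm`` case follows; ``_WeightNorm`` is a hypothetical illustration, not part of this patch.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class _WeightNorm(nn.Module):
        # Hypothetical weight_norm as a parametrization: w = g * v / ||v||,
        # with the norm taken per output row of a Linear weight.
        def __init__(self, weight):
            super().__init__()
            # Learn the magnitude separately; the direction comes from the
            # original (hidden) tensor that is passed to forward()
            self.g = nn.Parameter(weight.norm(dim=1, keepdim=True).detach().clone())

        def forward(self, v):
            return self.g * v / v.norm(dim=1, keepdim=True)

    layer = nn.Linear(3, 5)
    parametrize.register_parametrization(layer, "weight", _WeightNorm(layer.weight))
    # `layer.weight` is now recomputed from the original tensor (direction)
    # and `g` (magnitude) on every access; both are trained by the optimizer.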