1313# limitations under the License.
1414
1515from abc import ABC , abstractmethod
16- from collections import defaultdict
17- from typing import List , Optional , Set , Tuple
16+ from typing import List , Optional , Set
1817
1918import torch
2019import torch .nn .utils .parametrize as P
@@ -101,8 +100,6 @@ def apply_to_model(self, model: Module, use_tqdm=True):
101100 for module , arg in tqdm .tqdm (modules_args , desc = desc , disable = (not use_tqdm )):
102101 self ._apply_to_module (module , arg )
103102
104- self ._update_tied_weights ()
105-
106103 def _apply_to_module (self , module : Module , args : TransformArgs ):
107104 """
108105 Create transforms and apply them to the module
@@ -165,31 +162,6 @@ def output_hook(_, _input, output):
165162 else :
166163 raise NotImplementedError ()
167164
168- def _update_tied_weights (self ):
169- """
170- Populate the `_dynamic_tied_weights_keys` attribute of transforms,
171- which is used by transformers to detect and remove shared pointers
172- during saving
173- """
174- # map from data_ptrs to keys
175- ptr_to_keys : dict [int , List [Tuple [TransformBase , str ]]] = defaultdict (list )
176- for transform in self .transforms :
177- for name , param in transform .named_parameters (recurse = False ):
178- # NOTE: previously asserted that parent._hf_hook.place_submodules=False
179- if has_offloaded_params (transform ):
180- param = transform ._hf_hook .weights_map [name ]
181- ptr_to_keys [param .data_ptr ()].append ((transform , name ))
182-
183- # populate `_dynamic_tied_weights_keys` if there is more than one key
184- # and ensure that they share tensors
185- for shared_keys in ptr_to_keys .values ():
186- if len (shared_keys ) > 1 :
187- tensor = getattr (shared_keys [0 ][0 ], shared_keys [0 ][1 ])
188-
189- for transform , name in shared_keys :
190- transform ._dynamic_tied_weights_keys .add (name )
191- setattr (transform , name , tensor )
192-
193165
194166class TransformBase (InternalModule , ABC ):
195167 """
@@ -198,11 +170,7 @@ class TransformBase(InternalModule, ABC):
198170
199171 args : TransformArgs
200172 weight : Parameter
201- _dynamic_tied_weights_keys : Set [str ]
202-
203- def __init__ (self ):
204- super ().__init__ ()
205- self ._dynamic_tied_weights_keys = set ()
173+ _dynamic_tied_weights_keys : List [str ] = ["weight" ]
206174
207175 @abstractmethod
208176 def forward (self , value : Tensor ) -> Tensor :
0 commit comments