diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index dd19189332cf1e..71b8ac979ab7b8 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -29,7 +29,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile
 
 import torch
@@ -570,6 +570,65 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
+
+
+def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
+
+
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -2382,6 +2441,8 @@ def save_pretrained(
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
             warn_names = set()
+            error_names = set()
+            to_delete_names = set()
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
@@ -2392,25 +2453,42 @@ def save_pretrained(
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
-                                del state_dict[name]
-
-                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-                # If the link between tensors was done at runtime then `from_pretrained` will not get
-                # the key back leading to random tensor. A proper warning will be shown
-                # during reload (if applicable), but since the file is not necessarily compatible with
-                # the config, better show a proper warning.
-                found = 0
-                for name in names:
-                    if name in state_dict:
-                        found += 1
-                        if found > 1:
-                            del state_dict[name]
-                            warn_names.add(name)
-
+                                to_delete_names.add(name)
+            # We are entering a place where the weights and the transformers configuration do NOT match.
+            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+            for name in disjoint_names:
+                state_dict[name] = state_dict[name].clone()
+
+            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+            # If the link between tensors was done at runtime then `from_pretrained` will not get
+            # the key back leading to random tensor. A proper warning will be shown
+            # during reload (if applicable), but since the file is not necessarily compatible with
+            # the config, better show a proper warning.
+            shared_names, identical_names = _find_identical(shared_names, state_dict)
+            # delete tensors that have identical storage
+            for inames in identical_names:
+                known = inames.intersection(to_delete_names)
+                for name in known:
+                    del state_dict[name]
+                unknown = sorted(inames.difference(to_delete_names))
+                for name in unknown[1:]:
+                    del state_dict[name]
+                    warn_names.add(name)
+
+            error_names.update(shared_names)
+
             if len(warn_names) > 0:
                 logger.warning_once(
                     f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
                 )
+            if len(error_names) > 0:
+                raise RuntimeError(
+                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the "
+                    "transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
+                )
+
             # Shard the model if it is too big.
             if not _hf_peft_config_loaded:
                 weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index cef56822dc3e95..f7878cb68d803d 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -257,6 +257,26 @@ def test_model_from_pretrained_subfolder(self):
 
         self.assertTrue(check_models_equal(model, model_loaded))
 
+    def test_model_manually_shared_disjointed_tensors_optimum(self):
+        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model = BertModel(config)
+
+        # Let's fuse qkv
+        attn = model.encoder.layer[0].attention.self
+        q = attn.query.weight
+        k = attn.key.weight
+        v = attn.value.weight
+        # Force some shared storage
+        qkv = torch.stack([q, k, v], dim=0)
+        attn.query.weight = torch.nn.Parameter(qkv[0])
+        attn.key.weight = torch.nn.Parameter(qkv[1])
+        attn.value.weight = torch.nn.Parameter(qkv[2])
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            model_loaded = BertModel.from_pretrained(tmp_dir)
+
+        self.assertTrue(check_models_equal(model, model_loaded))
+
     def test_model_from_pretrained_subfolder_sharded(self):
         config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
         model = BertModel(config)
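
For context on how the two new helpers interact, here is a minimal standalone sketch (not part of the diff) that walks a fused-qkv style state dict through `_find_disjoint` and `_find_identical` in the same order `save_pretrained` now calls them. It assumes the private helpers can be imported from the patched `transformers.modeling_utils` (being private, they may move or change), and the tensor names `q`, `k`, `v`, `a` and the hand-built grouping are purely illustrative.

```python
import torch

# Assumption: the patched transformers is installed; these helpers are private.
from transformers.modeling_utils import _find_disjoint, _find_identical

# One storage shared four ways: three disjoint slices plus one exact alias.
qkv = torch.randn(3, 4, 4)
state_dict = {
    "q": qkv[0],  # disjoint slices of the same storage
    "k": qkv[1],
    "v": qkv[2],
    "a": qkv[0],  # exact alias of "q" (same start and end pointer)
}

# `save_pretrained` builds these groups from storage pointers; here we group by hand.
groups = [{"q", "k", "v", "a"}]

shared, disjoint = _find_disjoint(groups, state_dict)
print(disjoint)  # e.g. ['k', 'v']   -> non-overlapping slices, cloned before saving
print(shared)    # e.g. [{'q', 'a'}] -> still overlapping, checked next

shared, identical = _find_identical(shared, state_dict)
print(identical)  # e.g. [{'q', 'a'}] -> true duplicates, all but one key dropped
print(shared)     # []                -> anything left here triggers the new RuntimeError
```

In other words, disjoint slices are cloned so safetensors can serialize them independently, exact aliases are deduplicated as before, and only tensors that overlap without being identical fall through to the new `RuntimeError`.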