Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions src/compressed_tensors/transform/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Dict

import torch
Expand All @@ -38,12 +39,61 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
# attach config to model for compression/serialization
setattr(model, TRANSFORM_CONFIG_NAME, config)

# populate `_tied_weights_keys` for proper loading by transformers
_update_transforms_tied_weights(model)

# ensure that tied weight transforms can be serialized without aliases
# In the future, this could be done by transformers or model compressor
# which would make this more robust to changing dispatches after transforms
_tie_offloaded_tensors(model)


def _update_transforms_tied_weights(model: torch.nn.Module):
    """
    Register shared transform weights in the tied-weights attributes of the model.

    This function updates the `_tied_weights_keys` and `all_tied_weights_keys`
    attributes of the given model with transform weights.

    This function is needed because transformers only knows which weights are shared
    via the `_tied_weights_keys` attributes. These attributes are used to tie
    weights after the model has loaded.

    CompressedTensors does not enforce that a particular weight is the source weight.
    We rely on correctness of the following mapping in PreTrainedModel.tie_weights():
    ```
    B -> A
    C -> A
    D -> A

    Where any of A,B,C,D might be the loaded source weight
    ```
    This property is tested by `test_modeling_utils::BaseModelWithMultipleTiedWeights`

    :param model: model whose `TransformBase` submodules are scanned. Its
        `_tied_weights_keys` mapping is created if absent (or None), extended in
        place otherwise, and then aliased as `all_tied_weights_keys`
    """
    from compressed_tensors.transform import TransformBase

    # 1. find which transform weights are shared
    # create mapping: tensor_hash -> [keys]
    weight_to_keys: dict[int, list[str]] = defaultdict(list)
    for name, module in model.named_modules():
        if isinstance(module, TransformBase):
            for param_name, param_hash in module.tied_weights_hash.items():
                # the root module has an empty name, avoid a leading "."
                param_fqn = f"{name}.{param_name}" if name else param_name
                weight_to_keys[param_hash].append(param_fqn)

    # 2. assign each group of shared weights to the same value
    # create tied weights: key -> tied_keys[0]
    transform_tied_weights_keys = {}
    for keys in weight_to_keys.values():
        source_key = keys[0]
        for key in keys[1:]:  # skip A -> A
            transform_tied_weights_keys[key] = source_key

    # 3. update tied weights attributes
    # only create the mapping when the model does not already provide one, so that
    # tied weights declared by the model itself are preserved rather than wiped
    if getattr(model, "_tied_weights_keys", None) is None:
        model._tied_weights_keys = {}
    model._tied_weights_keys.update(transform_tied_weights_keys)
    model.all_tied_weights_keys = model._tied_weights_keys


def _tie_offloaded_tensors(model: torch.nn.Module):
"""
When accelerate replaces tensors with meta tensors during offloading, the meta
Expand Down
10 changes: 9 additions & 1 deletion src/compressed_tensors/transform/factory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,11 @@ class TransformBase(InternalModule, ABC):

args: TransformArgs
weight: Parameter
_dynamic_tied_weights_keys: List[str] = ["weight"]
tied_weights_hash: dict[str, int]

def __init__(self):
    """Initialise the parent module, then start with an empty identity table."""
    super().__init__()
    # parameter name -> id() of the parameter object registered under that name;
    # filled in by `register_parameter` as parameters are attached
    self.tied_weights_hash = dict()

@abstractmethod
def forward(self, value: Tensor) -> Tensor:
Expand All @@ -211,3 +215,7 @@ def right_inverse(self, value: Tensor) -> Tensor:

def __repr__(self):
    """Return `<ClassName>(inverse=<args.inverse>)` for this transform."""
    class_name = self.__class__.__name__
    is_inverse = self.args.inverse
    return f"{class_name}(inverse={is_inverse})"

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
    """
    Register a parameter and record its object identity in `tied_weights_hash`.

    The recorded `id()` values allow parameters registered under different names
    (or on different modules) to be recognised as the same shared object.

    :param name: name under which the parameter is registered
    :param param: parameter to register, or None to register an unset parameter
    """
    # register first so that a rejected registration leaves no stale entry behind
    super().register_parameter(name, param)

    if param is None:
        # all unset parameters would share `id(None)` and therefore look tied to
        # one another, so drop any previous entry instead of recording it
        self.tied_weights_hash.pop(name, None)
    else:
        self.tied_weights_hash[name] = id(param)
4 changes: 1 addition & 3 deletions src/compressed_tensors/transform/factory/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
Expand Down Expand Up @@ -83,8 +83,6 @@ def _create_permutation(self, weight: Parameter) -> Parameter:


class HadamardTransform(TransformBase):
_dynamic_tied_weights_keys: List[str] = ["weight", "perm"]

def __init__(
self,
weight: Parameter,
Expand Down
Loading