15 changes: 11 additions & 4 deletions src/compressed_tensors/transform/factory/base.py
@@ -18,6 +18,7 @@

import torch
import torch.nn.utils.parametrize as P
import tqdm
from compressed_tensors.registry.registry import RegistryMixin, T
from compressed_tensors.transform import (
TransformArgs,
@@ -84,15 +85,21 @@ def create_transform(self, module: Module, args: TransformArgs) -> "TransformBas
"""
raise NotImplementedError()

def apply_to_model(self, model: Module):
def apply_to_model(self, model: Module, use_tqdm: bool = True):
"""
Create transforms and apply them to the model

:param model: module to apply transforms to
:param use_tqdm: if True, display a progress bar while applying transforms
"""
for arg in self.scheme.apply:
for _, module in match_named_modules(model, arg.targets, arg.ignore):
self._apply_to_module(module, arg)
modules_args = [
(module, arg)
for arg in self.scheme.apply
for _, module in match_named_modules(model, arg.targets, arg.ignore)
]

desc = f"Applying {self.name} transforms"
for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
self._apply_to_module(module, arg)

self._update_tied_weights()

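A quick, self-contained sketch of the pattern this hunk introduces (names such as `apply_all` and `targets` are illustrative, not the library's API): the work list is materialized up front so tqdm knows its total length and can render a real progress bar, and a flag lets callers silence it.

```python
# Minimal sketch, assuming a simple "match modules by type" stand-in for
# match_named_modules; apply_all and targets are hypothetical names.
import tqdm
import torch.nn as nn

def apply_all(model: nn.Module, targets: tuple, use_tqdm: bool = True) -> None:
    # materialize the work list first so tqdm can show a total
    work = [m for m in model.modules() if isinstance(m, targets)]
    for module in tqdm.tqdm(work, desc="Applying transforms", disable=not use_tqdm):
        pass  # per-module transform application would go here

apply_all(nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)), targets=(nn.Linear,))
```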
27 changes: 15 additions & 12 deletions src/compressed_tensors/transform/factory/hadamard.py
@@ -53,24 +53,28 @@ def create_transform(self, module: Module, args: TransformArgs):
"""
assert hasattr(module, "weight")
size = get_transform_size(module, args.location, self.scheme.head_dim)
dtype = self.scheme.precision
device = get_offloaded_device(module)
exec_device = get_execution_device(module)

factory_kwargs = {"construct_device": exec_device}
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
device = get_offloaded_device(module)
precision = self.scheme.precision if args.is_online() else torch.float64

factory_kwargs = {
"device": device,
"construct_device": exec_device,
"precision": precision,
}
weight = self.weights.get(size, factory_kwargs=factory_kwargs)
# TODO: permutations should be keyed by fused modules, not weight
perm = self.perms[weight] if self.scheme.randomize else None
return HadamardTransform(weight, perm, self.scheme, args, type(module))

def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
precision: dtype,
) -> Parameter:
# construct on execution device, cache on offload device
data = deterministic_hadamard_matrix(size, dtype, construct_device)
data = deterministic_hadamard_matrix(size, precision, construct_device)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)

@@ -94,8 +98,7 @@ def __init__(
self.scheme = scheme
self.args = args
self.module_type = module_type
self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
self._precision = scheme.precision if args.is_online() else torch.float64
self._scale = torch.tensor(weight.size(0), dtype=torch.float64).sqrt()

def forward(self, value: Tensor) -> Tensor:
weight = self.weight
@@ -108,8 +111,8 @@ def forward(self, value: Tensor) -> Tensor:

return (
apply_transform_weight(
weight.to(self._precision),
value.to(self._precision),
weight.to(device=value.device),
value.to(dtype=weight.dtype),
self.args.location,
self.module_type,
)
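The forward path above casts the activation to the weight's dtype, moves the weight to the activation's device, and scales by 1/sqrt(n) computed in float64. A minimal sketch of that math, assuming the transform is applied on the activation's right-hand side with a plain n×n Hadamard matrix (not the library code itself):

```python
# Minimal sketch of the scaled-Hadamard forward used by HadamardTransform.
import torch

def hadamard_forward(weight: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # divide by sqrt(n) so weight / sqrt(n) is orthonormal; scale kept in float64
    scale = torch.tensor(weight.size(0), dtype=torch.float64).sqrt()
    out = value.to(dtype=weight.dtype) @ weight.to(device=value.device)
    return (out / scale).to(value.dtype)

w = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.float64)  # 2x2 Hadamard
x = torch.randn(3, 2)
# applying the orthonormal transform twice recovers the input (H @ H = n * I)
assert torch.allclose(hadamard_forward(w, hadamard_forward(w, x)), x, atol=1e-5)
```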
25 changes: 14 additions & 11 deletions src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -21,8 +21,8 @@
apply_transform_weight,
get_transform_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from compressed_tensors.utils.offload import get_offloaded_device
from torch import Tensor, device, dtype
from torch.nn import Module, Parameter

@@ -52,19 +52,23 @@ def create_transform(self, module: Module, args: TransformArgs):
"""
assert hasattr(module, "weight")
size = get_transform_size(module, args.location, self.scheme.head_dim)
dtype = self.scheme.precision
device = get_offloaded_device(module)
precision = self.scheme.precision if args.is_online() else torch.float64

weight = self.weights[size, dtype, device]
factory_kwargs = {"device": device, "precision": precision}
weight = self.weights.get(size, factory_kwargs=factory_kwargs)
if args.inverse:
weight = self.inverses[weight]

return RandomMatrixTransform(weight, self.scheme, args, type(module))

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
# TODO: verify that weight is invertible (has non-zero determinant)
def _create_weight(self, size: int, device: device, precision: dtype) -> Parameter:
# TODO: construct such that weight is invertible (has non-zero determinant)
data = torch.rand(
(size, size), generator=self.generator, dtype=dtype, device=device
(size, size),
generator=self.generator,
dtype=precision,
device=device,
)
return Parameter(data, requires_grad=self.scheme.requires_grad)

@@ -87,21 +91,20 @@ def __init__(
self.scheme = scheme
self.args = args
self.module_type = module_type
self._precision = scheme.precision if args.is_online() else torch.float64

def forward(self, value: Tensor) -> Parameter:
return apply_transform_weight(
self.weight.to(self._precision),
value.to(self._precision),
self.weight.to(device=value.device),
value.to(dtype=self.weight.dtype),
self.args.location,
self.module_type,
).to(value.dtype)

def right_inverse(self, value: Tensor) -> Tensor:
inverse = high_precision_invert(self.weight)
return apply_transform_weight(
inverse.to(self._precision),
value.to(self._precision),
inverse.to(device=value.device),
value.to(dtype=inverse.dtype),
self.args.location,
self.module_type,
).to(value.dtype)
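Since the random matrix is only invertible with high probability (see the TODO in this hunk), the inverse path upcasts before inverting. A rough round-trip sketch, under the assumption that high_precision_invert is essentially a float64 torch.linalg.inv:

```python
# Sketch only: high_precision_invert's real implementation is assumed here.
import torch

def high_precision_invert(weight: torch.Tensor) -> torch.Tensor:
    return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)

w = torch.rand(8, 8, dtype=torch.float64)                        # random transform weight
x = torch.randn(2, 8)
y = (x.to(w.dtype) @ w).to(x.dtype)                              # forward: apply the weight
x_back = (y.to(w.dtype) @ high_precision_invert(w)).to(x.dtype)  # right_inverse: undo it
assert torch.allclose(x_back, x, atol=1e-3)
```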
5 changes: 2 additions & 3 deletions src/compressed_tensors/transform/factory/random_hadamard.py
@@ -31,11 +31,10 @@ class RandomHadamardFactory(HadamardFactory):
def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
precision: dtype,
) -> Parameter:
# construct on execution device, cache on offload device
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
data = random_hadamard_matrix(size, precision, construct_device, self.generator)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)
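Both Hadamard factories build the matrix on the execution device and then move it to the offload device for caching. A small stand-in illustration of that flow (the devices chosen here are illustrative, not taken from the PR):

```python
# Illustrative only: a stand-in matrix built on the execution device, then
# cached on the offload device, mirroring _create_weight's flow.
import torch

construct_device = "cuda" if torch.cuda.is_available() else "cpu"  # execution device
offload_device = "cpu"                                             # cache/offload device

data = torch.randn(64, 64, dtype=torch.float64, device=construct_device)
weight = torch.nn.Parameter(data.to(device=offload_device), requires_grad=False)
```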
5 changes: 4 additions & 1 deletion src/compressed_tensors/utils/offload.py
@@ -131,7 +131,10 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
first_key = list(module._hf_hook.weights_map.keys())[0]
prefix_dataset = module._hf_hook.weights_map.dataset
return prefix_dataset[first_key].device
return next(module.parameters()).device
else:
# if the module is not offloaded, then any added weights
# should be placed on the module's execution device
return get_execution_device(module)


@check_accelerate(fallback=None)
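With the new fallback, a module that is not offloaded reports the device on which freshly created transform weights should live. A short usage sketch, using the import path that already appears in this PR:

```python
# Usage sketch: for a plain (non-offloaded) module, get_offloaded_device now
# resolves to the module's execution device, which is where a newly created
# transform weight should be materialized.
import torch
from compressed_tensors.utils.offload import get_offloaded_device

module = torch.nn.Linear(16, 16)
device = get_offloaded_device(module)            # not offloaded -> execution device
new_weight = torch.nn.Parameter(torch.eye(16, device=device))
```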