
Commit fc6b3dc

Merge pull request pytorch#22 from pytorch-labs/jcaip/sparsity
[sparse] add sparsity, add wanda sparsifier to ao
2 parents a55e2d2 + 9c84256 commit fc6b3dc

File tree

4 files changed: +285 -0 lines

- test/sparsity/test_wanda.py
- torchao/sparsity/__init__.py
- torchao/sparsity/utils.py
- torchao/sparsity/wanda.py

test/sparsity/test_wanda.py (new file: 114 additions, 0 deletions)

```python
import logging
import unittest

import torch
from torch import nn
from torchao.sparsity import WandaSparsifier
from torch.ao.pruning import FakeSparsity
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import SimpleLinear
from torch.testing._internal.common_utils import TestCase

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestWandaSparsifier(TestCase):
    """
    Test Wanda Sparsifier
    """

    def test_prepare(self):
        model = SimpleLinear()
        sparsifier = WandaSparsifier()
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            module = g["module"]
            # Check mask exists
            assert hasattr(module.parametrizations["weight"][0], "mask")
            # Check parametrization exists and is correct
            assert is_parametrized(module, "weight")
            assert type(module.parametrizations.weight[0]) == FakeSparsity
            # check activation observer is present
            assert hasattr(module, "activation_post_process")

    def test_squash_mask(self):
        # check observers and parametrizations removed
        model = SimpleLinear()
        sparsifier = WandaSparsifier()
        sparsifier.prepare(model, config=None)
        sparsifier.squash_mask()
        for g in sparsifier.groups:
            module = g["module"]
            assert not is_parametrized(module, "weight")
            assert not hasattr(module, "mask")
            assert not hasattr(module, "activation_post_process")

    def test_one_layer_mlp_2x4(self):
        model = nn.Sequential(nn.Linear(8, 1))
        weights = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
        model[0].weight.data.copy_(weights.data)
        X = torch.ones(1, 8)

        sparsifier = WandaSparsifier(semi_structured_block_size=4)
        sparsifier.prepare(model, config=None)

        model(X)

        sparsifier.step()
        sparsifier.squash_mask()

        sparsity = (model[0].weight == 0).float().mean()
        assert sparsity == 0.5

        expected_fc = torch.tensor([[0, 0, 3, 4, 0, 0, 7, 8]], dtype=torch.float32)
        assert torch.allclose(model[0].weight.data, expected_fc, rtol=1e-05, atol=1e-07)

    def test_one_layer_mlp_unstructured(self):
        model = nn.Sequential(nn.Linear(4, 1))
        weights = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
        model[0].weight.data.copy_(weights.data)
        X = torch.tensor([[100, 10, 1, 0.1]], dtype=torch.float32)

        sparsifier = WandaSparsifier(sparsity_level=0.5)
        sparsifier.prepare(model, config=None)

        model(X)

        sparsifier.step()
        sparsifier.squash_mask()

        sparsity = (model[0].weight == 0).float().mean()
        assert sparsity == 0.5

        expected_fc = torch.tensor([[1, 2, 0, 0]], dtype=torch.float32)
        assert torch.allclose(model[0].weight.data, expected_fc, rtol=1e-05, atol=1e-07)

    def test_two_layer_mlp_unstructured(self):
        model = nn.Sequential(
            nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
        )  # C_in by C_out
        X1 = torch.randn(100, 128)  # B1 by C_in
        X2 = torch.randn(50, 128)  # B2 by C_in

        sparsifier = WandaSparsifier(sparsity_level=0.5)
        sparsifier.prepare(model, config=None)

        model(X1)
        model(X2)
        sparsifier.step()

        cnt = 0
        for m in model.modules():
            if isinstance(m, nn.Linear):
                cnt += 1
                sparsity_level = (m.weight == 0).float().mean()
                assert (
                    sparsity_level == 0.5
                ), f"sparsity for linear layer {cnt} should be 0.5"

        sparsifier.squash_mask()


if __name__ == "__main__":
    unittest.main()
```
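
As a sanity check on the expected result in `test_one_layer_mlp_2x4`, here is a small standalone sketch (not part of the commit) that reproduces the expected 2:4 mask by hand, using the same pruning metric the sparsifier computes (weight magnitude times the squared per-channel activation norm stored by the observer):

```python
import torch

W = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
X = torch.ones(1, 8)

# Squared per-channel activation norm, matching what PerChannelNormObserver accumulates.
act_sq_norm = (X * X).sum(dim=0)   # all ones here
metric = W.abs() * act_sq_norm     # Wanda metric, here simply [1..8]

block = 4                          # semi_structured_block_size=4 -> 2:4 sparsity
prune_inds = metric.view(-1, block).argsort(dim=1)[:, : block // 2]  # 2 smallest per block
mask = torch.ones_like(W).view(-1, block).scatter(1, prune_inds, 0.0).view_as(W)

print(W * mask)  # tensor([[0., 0., 3., 4., 0., 0., 7., 8.]])
```

The two smallest metric entries in each block of four are zeroed, which is exactly the `expected_fc` tensor asserted in the test above.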

torchao/sparsity/__init__.py (new file: 13 additions, 0 deletions)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .wanda import WandaSparsifier  # noqa: F403
from .utils import PerChannelNormObserver  # noqa: F403

__all__ = [
    "WandaSparsifier",
    "PerChannelNormObserver"
]
```

torchao/sparsity/utils.py (new file: 48 additions, 0 deletions)

```python
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase

__all__ = ["PerChannelNormObserver"]


# Observers
class PerChannelNormObserver(UniformQuantizationObserverBase):
    """
    A custom observer that computes the squared L2 norm of each channel
    and accumulates it in a buffer across forward passes.
    """

    def __init__(self, **kwargs) -> None:
        # init with fixed qparams for quantization flow
        super().__init__(
            dtype=torch.quint8,
            qscheme=torch.per_channel_affine,
            reduce_range=False,
            quant_min=None,
            quant_max=None,
            eps=torch.finfo(torch.float32).eps,
            **kwargs
        )
        # set averaging constant so quantization flow knows observer is memoryless.
        self.averaging_constant = 1.0
        self.register_buffer("norm", torch.tensor([]))

    def forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()  # avoid keeping autograd tape

        # the channel axis is always the last dimension; move it to the front
        new_axis_list = [i for i in range(x.dim())]  # noqa: C416
        new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0]
        y = x.permute(new_axis_list)
        y = torch.flatten(y, start_dim=1)
        norm = torch.norm(y, dim=1) ** 2

        if self.norm.numel() == 0:
            self.norm.resize_(norm.shape)
            self.norm.copy_(norm)
        else:
            self.norm += norm

        return x_orig

    def calculate_qparams(self):
        raise NotImplementedError("PerChannelNormObserver is designed to store activations only.")
```
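
A minimal sketch of what the observer does, assuming it is constructed directly rather than through a `QConfig` as in the commit (the commit only uses it via `torch.ao.quantization.prepare`): feeding two batches should leave the `norm` buffer equal to the sum of squared per-channel L2 norms across both batches.

```python
import torch
from torchao.sparsity import PerChannelNormObserver

obs = PerChannelNormObserver()
x1 = torch.randn(16, 8)  # batch of 16 activations, 8 channels (channels last)
x2 = torch.randn(4, 8)

obs(x1)  # first call initializes the "norm" buffer
obs(x2)  # subsequent calls accumulate into it

# Sum of squared per-channel norms over both batches.
expected = (x1 * x1).sum(dim=0) + (x2 * x2).sum(dim=0)
print(torch.allclose(obs.norm, expected, rtol=1e-4, atol=1e-4))  # True
```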

torchao/sparsity/wanda.py (new file: 110 additions, 0 deletions)

```python
import warnings

from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.ao.pruning import BaseSparsifier
from torch.ao.quantization import default_placeholder_observer, QConfig
from torch.ao.quantization.quantize import _remove_qconfig
from .utils import PerChannelNormObserver

__all__ = ["WandaSparsifier"]


class WandaSparsifier(BaseSparsifier):
    r"""Wanda sparsifier

    Wanda (Pruning by Weights and activations), proposed in https://arxiv.org/abs/2306.11695,
    is an activation-aware pruning method. The sparsifier removes weights based on the product
    of the input activation norm and the weight magnitude.

    This sparsifier is controlled by two variables:
    1. `sparsity_level` defines the fraction of weights that are zeroed out;
    2. `semi_structured_block_size`, if set, enables semi-structured (m//2:m) sparsity
       and overrides `sparsity_level` (fixed at 50% within each block).

    Args:
        sparsity_level: The target level of sparsity;
        semi_structured_block_size: The block size m for semi-structured (m//2:m) sparsity;
    """

    def __init__(
        self,
        sparsity_level: float = 0.5,
        semi_structured_block_size: Optional[int] = None,
    ):
        defaults = {
            "sparsity_level": sparsity_level,
            "semi_structured_block_size": semi_structured_block_size,
        }
        if semi_structured_block_size is not None:
            m = semi_structured_block_size
            warnings.warn(
                f"WandaSparsifier got semi_structured_block_size={m}, sparsity_level fixed to 50% ({m // 2}:{m}) sparsity"
            )
        super().__init__(defaults=defaults)

    def prepare(self, model: nn.Module, config: List[Dict]) -> None:
        # activation: use PerChannelNormObserver
        # use no-op placeholder weight observer
        model.qconfig = QConfig(
            activation=PerChannelNormObserver, weight=default_placeholder_observer
        )  # type: ignore[assignment]
        torch.ao.quantization.prepare(model, inplace=True)

        # call superclass prepare
        super().prepare(model, config)

    def update_mask(  # type: ignore[override]
        self, module: nn.Module, tensor_name: str, sparsity_level: float, **kwargs
    ) -> None:
        r"""Pruning function for WandaSparsifier

        The activation statistics are retrieved first into `activation_norm_per_channel`.
        Then the Wanda pruning metric is computed. The weight matrix is then pruned
        by comparing this metric across the whole current layer.
        """

        # Step 1: get the tensor and the mask from the parametrizations
        mask = getattr(module.parametrizations, tensor_name)[0].mask
        tensor = getattr(module.parametrizations, tensor_name).original
        activation_norm_per_channel = module.activation_post_process.norm

        # Step 2: Calculate Wx
        pruning_metric = torch.abs(tensor) * activation_norm_per_channel

        # defaults for unstructured sparsity
        block_size = pruning_metric.numel()
        num_specified = int(block_size * sparsity_level)
        # if set to use semi-structured, ignore sparsity_level
        if kwargs.get("semi_structured_block_size", None) is not None:
            block_size = kwargs["semi_structured_block_size"]
            num_specified = block_size // 2

        # get indices to prune
        pruning_inds = pruning_metric.view(-1, block_size).argsort(dim=1)[
            :, :num_specified
        ]
        # update mask
        mask.data.view(-1, block_size).scatter_(
            1, pruning_inds, torch.zeros_like(pruning_inds, dtype=mask.dtype)
        )

    def squash_mask(
        self,
        params_to_keep: Optional[Tuple[str, ...]] = None,
        params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
        *args,
        **kwargs,
    ):
        # remove quantization config
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            _remove_qconfig(module)

        # remove parametrizations
        super().squash_mask(
            params_to_keep=params_to_keep,
            params_to_keep_per_layer=params_to_keep_per_layer,
        )
```
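
A minimal end-to-end sketch of how the sparsifier added in this commit is driven, mirroring the flow exercised by the tests above; the model shape and the calibration batches are made-up placeholders:

```python
import torch
from torch import nn
from torchao.sparsity import WandaSparsifier

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
calibration_data = [torch.randn(32, 128) for _ in range(8)]  # stand-in for real data

sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=None)   # attaches observers and FakeSparsity masks

with torch.no_grad():
    for batch in calibration_data:       # observers accumulate per-channel activation norms
        model(batch)

sparsifier.step()                        # compute the Wanda metric and update the masks
sparsifier.squash_mask()                 # fold masks into weights, remove observers/qconfig

for m in model.modules():
    if isinstance(m, nn.Linear):
        print((m.weight == 0).float().mean())  # ~0.5 sparsity per linear layer
```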
