migrate sparsify_ to configs #1856

Merged · 50 commits by vkuzo (Mar 7–12, 2025) · merged Mar 12, 2025
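In short: this PR changes `sparsify_` to take an `AOBaseConfig` subclass instead of a bare callable, reusing the config-handler machinery that `quantize_` already uses. A minimal before/after sketch based on the README changes in this PR (it assumes a CUDA-capable `model`, matching the README's own examples):

```py
from torchao.sparsity.sparse_api import sparsify_, SemiSparseWeightConfig

model = model.cuda()

# before this PR: sparsify_(model, semi_sparse_weight())
# after this PR: pass a config instance, not a callable
sparsify_(model, SemiSparseWeightConfig())
```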
11 changes: 3 additions & 8 deletions test/sparsity/test_supermask.py
@@ -6,8 +6,6 @@
 from torch import nn
 from torch.testing._internal import common_utils
 
-from torchao.sparsity import sparsify_
-
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
@@ -30,13 +28,10 @@ def test_supermask(self, sparsity_level, blocksize):
         from torchao.sparsity import SupermaskLinear
 
         M, N = model[0].weight.shape
-        sparsify_(
-            model,
-            lambda x: SupermaskLinear.from_linear(
-                x, sparsity_level=sparsity_level, blocksize=blocksize
-            ),
+        model[0] = SupermaskLinear.from_linear(
+            model[0], sparsity_level=sparsity_level, blocksize=blocksize
         )
-        sparsify_(model, SupermaskLinear.to_linear)
+        model[0] = SupermaskLinear.to_linear(model[0])
         weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize)
 
         # Test correct sparsity level
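Because `sparsify_` now dispatches on the config's type rather than accepting an arbitrary callable, the Supermask test above drops its lambda-based `sparsify_` calls and swaps the module directly via `SupermaskLinear.from_linear` / `to_linear`.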
8 changes: 4 additions & 4 deletions torchao/sparsity/README.md
@@ -78,22 +78,22 @@ quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
 ### 2:4 sparsity
 
 ```py
-from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
+from torchao.sparsity.sparse_api import sparsify_, SemiSparseWeightConfig
 from torchao.dtypes import SemiSparseLayout
 
 model = model.cuda()
-sparsify_(model, semi_sparse_weight())
+sparsify_(model, SemiSparseWeightConfig())
 ```
 
 ### Block sparsity
 We offer prototype support for accelerating block sparsity with our triton kernels for bfloat16/float16 workloads.
 
 ```py
 from torchao.sparsity.sparse_api import sparsify_
-from torchao.sparsity import block_sparse_weight
+from torchao.sparsity import BlockSparseWeightConfig
 
 model = model.cuda()
-sparsify_(model, block_sparse_weight())
+sparsify_(model, BlockSparseWeightConfig())
 ```
 
 # Goal
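Note that the old lowercase names still work: the sparse_api.py diff below keeps `semi_sparse_weight = SemiSparseWeightConfig` and `block_sparse_weight = BlockSparseWeightConfig` as aliases "for bc", so the README change recommends the new spelling without breaking existing code.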
65 changes: 52 additions & 13 deletions torchao/sparsity/sparse_api.py
@@ -1,17 +1,23 @@
-from functools import partial
+import types
+from dataclasses import dataclass
 from typing import Callable, Optional
 
 import torch
 from torch.sparse import to_sparse_semi_structured
 
+from torchao.core.config import AOBaseConfig
 from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
     WeightNormSparsifier,
 )
 from torchao.quantization.quant_api import (
-    _get_linear_subclass_inserter,
     _is_linear,
+    _linear_extra_repr,
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
+    register_quantize_module_handler,
+)
 from torchao.sparsity.blocksparse import BlockSparseTensor
 
 
@@ -35,22 +41,53 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.squash_mask()
 
 
-def block_sparse_weight(blocksize=64):
-    return _get_linear_subclass_inserter(
-        partial(BlockSparseTensor.from_dense, blocksize=blocksize)
-    )
+@dataclass
+class BlockSparseWeightConfig(AOBaseConfig):
+    blocksize: int = 64
+
+
+# for bc
+block_sparse_weight = BlockSparseWeightConfig
+
+
+@register_quantize_module_handler(BlockSparseWeightConfig)
+def _block_sparse_weight_transform(
+    module: torch.nn.Module,
+    config: BlockSparseWeightConfig,
+):
+    blocksize = config.blocksize
+    new_weight = BlockSparseTensor.from_dense(module.weight, blocksize)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
-def semi_sparse_weight():
+class SemiSparseWeightConfig(AOBaseConfig):
     """
-    Convert the weight of linear moduels to semi-structured (2:4) sparsity
+    Configuration for converting the weight of linear modules to semi-structured (2:4) sparsity
     """
-    return _get_linear_subclass_inserter(to_sparse_semi_structured)
+
+    pass
+
+
+# for bc
+semi_sparse_weight = SemiSparseWeightConfig
+
+
+@register_quantize_module_handler(SemiSparseWeightConfig)
+def _semi_sparse_weight_transform(
+    module: torch.nn.Module,
+    config: SemiSparseWeightConfig,
+) -> torch.nn.Module:
+    new_weight = to_sparse_semi_structured(module.weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def sparsify_(
     model: torch.nn.Module,
-    apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
+    config: AOBaseConfig,
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
 ) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
@@ -63,8 +100,8 @@ def sparsify_(
 
     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
+        config (AOBaseConfig): a workflow configuration object
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to apply the specified workflow to this module.
 
     **Example:**
     ::
@@ -85,8 +122,10 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         from torchao.dtypes import SemiSparseLayout
         m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
     """
+    handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
    _replace_with_custom_fn_if_matches_filter(
         model,
-        apply_tensor_subclass,
+        handler,
         _is_linear if filter_fn is None else filter_fn,
+        extra_args=(config,),
     )
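To illustrate the extension point this diff creates: a config class carries the workflow's parameters, and a transform registered via `register_quantize_module_handler` applies them when `sparsify_` looks the handler up in `_QUANTIZE_CONFIG_HANDLER[type(config)]`. The sketch below is hypothetical — `MagnitudeThresholdConfig` and its handler are illustrative, not part of torchao — but it follows the same pattern as `BlockSparseWeightConfig` above:

```py
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import register_quantize_module_handler
from torchao.sparsity.sparse_api import sparsify_


@dataclass
class MagnitudeThresholdConfig(AOBaseConfig):
    """Hypothetical config: zero out weights below a magnitude threshold."""

    threshold: float = 1e-3


@register_quantize_module_handler(MagnitudeThresholdConfig)
def _magnitude_threshold_transform(
    module: torch.nn.Module,
    config: MagnitudeThresholdConfig,
) -> torch.nn.Module:
    # sparsify_ calls this handler on every module that matches filter_fn
    # (linear layers by default), passing the config as an extra argument.
    w = module.weight.detach()
    new_weight = torch.where(w.abs() >= config.threshold, w, torch.zeros_like(w))
    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    return module


model = torch.nn.Sequential(torch.nn.Linear(128, 128))
sparsify_(model, MagnitudeThresholdConfig(threshold=1e-2))
```

Keeping the handler separate from the config is the same design choice `quantize_` made: configs stay serializable dataclasses, while module mutation lives in registered transforms.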