
Commit 6b3cf89

Implement sparsity as an AQT Layout (pytorch#498)
Summary: This PR adds sparsity as an AQTLayout; previously it was implemented using the `QuantizedLinearBase` subclass, which will be deprecated shortly. I also renamed `sparsify` to `sparsify_` and added a `semi_sparse_weight()` function to be in line with our other APIs. The main code changes are in `torchao/dtypes/affine_quantized_tensor.py`: for the semi-structured cuSPARSELt representation, we can reuse much of the existing PlainLayout implementation, since the compressed representation is stored in a single tensor like `int_data`.

Test Plan:
```
python test/sparsity/test_sparse_api.py
```
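For reference, a minimal sketch of the renamed API, adapted from the README changes below (this assumes a CUDA device with cuSPARSELt support; the toy model and shapes are illustrative, and `apply_fake_sparsity` is used to induce a 2:4 pattern before compression):

```python
import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity, sparsify_, semi_sparse_weight

# Toy model; dimensions chosen to satisfy the 2:4 kernel's alignment requirements.
model = nn.Sequential(nn.Linear(128, 128)).half().cuda()

# Prune weights to a 2:4 pattern, then swap each nn.Linear weight
# for its semi-structured sparse representation, in place.
apply_fake_sparsity(model)
sparsify_(model, semi_sparse_weight())

out = model(torch.rand(64, 128, dtype=torch.float16, device="cuda"))
```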
1 parent f8472f1 commit 6b3cf89

File tree

12 files changed: +168 additions, −383 deletions

README.md

Lines changed: 5 additions & 7 deletions

@@ -49,20 +49,18 @@ And a quick crash course on inference quantization to help parse the above table

 Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
 ```python
-from torchao.sparsity import sparsify
-from torch.sparse import to_sparse_semi_structured
+from torchao.sparsity import sparsify_, semi_sparse_weight

-m = sparsify(m, to_sparse_semi_structured)
+m = sparsify_(m, semi_sparse_weight())
 ```
 Sparsity can also be composed with int8 dynamic quantization for further speedups:

 ```python
-from torchao.sparsity import sparsify
-from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+from torchao.sparsity import sparsify_, int8_dynamic_activation_int8_semi_sparse_weight

-m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
+m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight())
 ```
-We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration.
+We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + semi-sparse (2:4) sparsity to MLP layer 1, and 2:4 sparsity to MLP layer 2 yielded the best configuration.
 We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.

 The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`:

scripts/sam/benchmark.sh

Lines changed: 0 additions & 1 deletion

@@ -8,4 +8,3 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
 # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
-

scripts/sam/eval_combo.py

Lines changed: 16 additions & 28 deletions

@@ -9,6 +9,10 @@
 import time
 import resource

+from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
+from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight
+from torchao.utils import unwrap_tensor_subclass
+
 torch._dynamo.config.cache_size_limit = 50000

 def unbind_jagged(device, data, sizes, offsets):
@@ -279,30 +283,17 @@ def run(
                 block.attn.use_rel_pos = use_rel_pos

     if compress == "int8_dynamic_quant":
-        from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
-        from torchao.utils import unwrap_tensor_subclass
         quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
         predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
     elif compress == "sparse_mlp_only":
         def mlp_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'mlp' in name
-        from torchao.sparsity import sparsify
-        from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
         apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
-        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
+        sparsify_(predictor.model.image_encoder, semi_sparse_weight(), filter_fn=mlp_only)
     elif compress == "sparse":
-        from torchao.sparsity import sparsify
-        from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
         apply_fake_sparsity(predictor.model.image_encoder)
-        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
+        sparsify_(predictor.model.image_encoder, semi_sparse_weight())
     elif compress == "int8_dynamic_quant_sparse":
-        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-        SparseSemiStructuredTensor._FORCE_CUTLASS = False
-        from torchao.sparsity import sparsify, apply_fake_sparsity
-        from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
-        from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
-        from torchao.utils import unwrap_tensor_subclass
-
         def attn_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'attn' in name
         def mlp_lin1_only(mod, name):
@@ -316,20 +307,17 @@ def mlp_only(mod, name):
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)

-        quantize_(
-            predictor.model.image_encoder,
-            int8_dynamic_activation_int8_weight(),
-            attn_only
-        )
+        quantize_(predictor.model.image_encoder,
+                  int8_dynamic_activation_int8_weight(),
+                  attn_only)
+        quantize_(predictor.model.image_encoder,
+                  int8_dynamic_activation_int8_semi_sparse_weight(),
+                  mlp_lin1_only)
+        sparsify_(predictor.model.image_encoder,
+                  semi_sparse_weight(),
+                  mlp_lin2_only)
         predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

-        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
-                                                 int8_dynamic_activation_int8_2x4_sparse_weight(),
-                                                 mlp_lin1_only, prune=False)
-
-        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
-                                                 to_sparse_semi_structured,
-                                                 mlp_lin2_only, prune=False)
     else:
         assert compress is None, f"Unsupported compress mode {compress}"

@@ -413,6 +401,6 @@ def mlp_only(mod, name):
        vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile,
                                  use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
        f.write(vals+"\n")
-
+

 if __name__ == '__main__':
     fire.Fire(run)
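The compress modes above all hinge on filter functions with the signature `(module, fully_qualified_name) -> bool`, which route each `nn.Linear` to quantization, sparsity, or both. A standalone sketch of that pattern (a toy module with hypothetical names, not the SAM encoder):

```python
import torch
from torch import nn

def attn_only(mod, name):
    return isinstance(mod, nn.Linear) and 'attn' in name

def mlp_only(mod, name):
    return isinstance(mod, nn.Linear) and 'mlp' in name

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(64, 64)
        self.mlp = nn.Sequential(nn.Linear(64, 256), nn.Linear(256, 64))

# quantize_ / sparsify_ apply their transform only where the filter
# returns True; we mimic that selection over named_modules() here:
for name, mod in Block().named_modules():
    if attn_only(mod, name):
        print("int8 quantize:", name)   # -> attn
    elif mlp_only(mod, name):
        print("2:4 sparsify:", name)    # -> mlp.0, mlp.1
```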

scripts/sam/results.csv

Lines changed: 5 additions & 5 deletions

@@ -1,6 +1,6 @@
 device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
-cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None

test/sparsity/test_sparse_api.py

Lines changed: 16 additions & 11 deletions

@@ -1,18 +1,24 @@
+import copy
 import logging
 import unittest

 import torch
 from torch import nn
-from torch.sparse import to_sparse_semi_structured

-from torchao.sparsity import apply_fake_sparsity, sparsify
-from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+from torchao.sparsity import (
+    apply_fake_sparsity,
+    sparsify_,
+    int8_dynamic_activation_int8_semi_sparse_weight,
+    semi_sparse_weight,
+)
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     _get_subclass_inserter,
     _is_linear,
+    int8_dynamic_activation_int8_weight,
+    quantize_,
 )
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass
 from torch.testing._internal.common_utils import TestCase


@@ -38,12 +44,11 @@ def test_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)

-        model = sparsify(model, to_sparse_semi_structured)
+        sparsify_(model, semi_sparse_weight())
         sparse_result = model(input)

         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

-
 class TestQuantSemiSparse(TestCase):

     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
@@ -58,15 +63,15 @@ def test_quant_semi_sparse(self):
             .half()
             .cuda()
         )
-
         apply_fake_sparsity(model)
-        dense_result = model(input)
+        model_copy = copy.deepcopy(model)
+        quantize_(model_copy, int8_dynamic_activation_int8_weight())
+        dense_result = model_copy(input)

-        sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
+        quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
        sparse_result = model(input)

-        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
-
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)


 if __name__ == "__main__":
     unittest.main()

torchao/dtypes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
     to_affine_quantized_static,
     LayoutType,
     PlainLayoutType,
+    SemiSparseLayoutType,
     TensorCoreTiledLayoutType,
 )

@@ -19,5 +20,6 @@
     "to_affine_quantized_static",
     "LayoutType",
     "PlainLayoutType",
+    "SemiSparseLayoutType",
     "TensorCoreTiledLayoutType",
 ]

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 77 additions & 0 deletions

@@ -32,6 +32,17 @@
 class PlainLayoutType(LayoutType):
     pass

+@dataclass(frozen=True)
+class SemiSparseLayoutType(LayoutType):
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        # prune to 2:4 if not already
+        temp = input.detach()
+        pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
+        temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
+        return temp
+
+
 @dataclass(frozen=True)
 class TensorCoreTiledLayoutType(LayoutType):
     inner_k_tiles: int = 8
@@ -473,6 +484,47 @@ def from_plain(
         assert isinstance(layout_type, PlainLayoutType)
         return cls(int_data, scale, zero_point, layout_type)

+@register_layout_cls(SemiSparseLayoutType)
+class SemiSparseAQTLayout(PlainAQTLayout):
+    """
+    Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor
+    """
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"SemiSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
+    def get_plain(self):
+        # Currently we don't have cuSPARSELt expansion routines, so we matmul by
+        # the identity matrix to get the original dense matrix. This is slow though.
+        cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0])
+        int_data_expanded = torch._cslt_sparse_mm(self.int_data,
+                                                  torch.eye(cols,
+                                                            dtype=self.int_data.dtype,
+                                                            device=self.int_data.device).t())
+        return int_data_expanded, self.scale, self.zero_point
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType,
+    ):
+        assert isinstance(layout_type, SemiSparseLayoutType)
+        int_data_compressed = torch._cslt_compress(int_data)
+        return cls(int_data_compressed, scale, zero_point, layout_type)
+
+
 @register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
     """
@@ -669,6 +721,31 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             if bias is not None:
                 y += bias
             return y
+        # handle int8 dynamic_quant + semi_structured_sparse
+        elif(
+            is_cuda and
+            input_is_int8 and
+            input_tensor.dtype == weight_qtensor.dtype and
+            isinstance(input_tensor.layout_type, PlainLayoutType) and
+            isinstance(weight_qtensor.layout_type, SemiSparseLayoutType)
+        ):
+            x_vals_int8 = input_tensor.layout_tensor.int_data
+            x_scales = input_tensor.layout_tensor.scale
+            w_vals_int8 = weight_qtensor.layout_tensor.int_data
+            w_scales = weight_qtensor.layout_tensor.scale
+            tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+            # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
+            y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
+                w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
+            ).t()
+            y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
+                *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
+            )
+            output_dtype = input_tensor.dtype
+            y = y.to(output_dtype)
+            if bias is not None:
+                y += bias
+            return y
         else:
             input_tensor = input_tensor.dequantize()
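The magnitude-based 2:4 pruning in `SemiSparseLayoutType.pre_process` can be tried in isolation; here is a minimal CPU-only sketch of the same logic (the `prune_2_4` helper name and shapes are illustrative, and unlike `pre_process` it clones rather than mutating the input):

```python
import torch

def prune_2_4(weight: torch.Tensor) -> torch.Tensor:
    # Mirror of SemiSparseLayoutType.pre_process: within each group of 4
    # consecutive elements, zero out the 2 smallest-magnitude entries.
    temp = weight.detach().clone()
    pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
    temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
    return temp

w = torch.randn(4, 8)
w_pruned = prune_2_4(w)
# at least 2 zeros in every group of 4 -> satisfies the 2:4 pattern
assert (w_pruned.view(-1, 4) == 0).sum(dim=1).ge(2).all()
```

Likewise, the new `_quantized_linear_op` branch is, arithmetically, a scaled int8 matmul: `_cslt_sparse_mm(w, x.t(), alpha=w_scales)` folds the per-channel weight scales into the sparse mm, and the per-token activation scales are applied afterwards. A dense reference for the same dequantization arithmetic, without cuSPARSELt (shapes illustrative):

```python
import torch

x_int8 = torch.randint(-127, 128, (8, 64), dtype=torch.int8)    # per-token quantized activations
w_int8 = torch.randint(-127, 128, (32, 64), dtype=torch.int8)   # per-channel quantized weight
x_scales = torch.rand(8, 1)    # one scale per token
w_scales = torch.rand(32)      # one scale per output channel

# densely, the fused kernel computes (W @ X^T)^T scaled by w_scales,
# then the result is scaled per token by x_scales:
y_dot = (w_int8.to(torch.int32) @ x_int8.to(torch.int32).t()).t().to(torch.float32)
y = y_dot * w_scales * x_scales    # broadcasts to (8, 32)
```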

torchao/quantization/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@
     "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
 ]

torchao/quantization/quant_api.py

Lines changed: 22 additions & 4 deletions

@@ -14,13 +14,14 @@
 come along with it and because that is how we access the intended quantized
 and mixed GEMM kernels
 """
-
+from functools import partial
 import torch
 import torchao
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional

+from torchao.dtypes import PlainLayoutType
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
@@ -57,6 +58,7 @@
     "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
 ]
@@ -410,7 +412,8 @@ def apply_int8wo_quant(weight):

     return _get_linear_subclass_inserter(apply_int8wo_quant)

-def int8_dynamic_activation_int8_weight():
+
+def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
@@ -432,16 +435,31 @@ def get_weight_block_size(x):
         zero_point_dtype = torch.int64

         # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
         input_mapping_type = MappingType.SYMMETRIC
         input_target_dtype = torch.int8
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight

     return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)
+
+
+def int8_dynamic_activation_int8_semi_sparse_weight():
+    """
+    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
+    quantization + 2:4 sparsity to linear layers.
+    """
+    from torchao.dtypes import SemiSparseLayoutType
+    return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
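The `get_per_token_block_size` helper added above makes per-token quantization concrete: every leading dimension of the activation collapses to a block of 1 while the last dimension stays whole, so each token row gets its own scale. A quick sketch of what it returns (shapes illustrative):

```python
import torch

def get_per_token_block_size(x):
    # one quantization group per token: all leading dims -> 1,
    # the last (feature) dim is kept whole
    block_size = list(x.shape)
    for i in range(len(block_size) - 1):
        block_size[i] = 1
    return block_size

x = torch.randn(2, 16, 64)   # (batch, seq_len, hidden)
assert get_per_token_block_size(x) == [1, 1, 64]
```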

torchao/sparsity/__init__.py

Lines changed: 9 additions & 2 deletions

@@ -6,11 +6,18 @@

 from .wanda import WandaSparsifier  # noqa: F403
 from .utils import PerChannelNormObserver  # noqa: F403
-from .sparse_api import apply_fake_sparsity, sparsify
+from .sparse_api import (
+    apply_fake_sparsity,
+    sparsify_,
+    semi_sparse_weight,
+    int8_dynamic_activation_int8_semi_sparse_weight
+)

 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
     "apply_fake_sparsity",
-    "sparsify"
+    "sparsify_",
+    "semi_sparse_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight"
 ]
