Skip to content

Add blocksparse_int_addmm. Eliminate unnecessary contiguous calls which leads to performance increase. #891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions test/sparsity/test_sparse_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,33 @@
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
from torch.testing._internal.common_utils import TestCase

from torch.ao.pruning import WeightNormSparsifier


logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

def apply_fake_block_sparsity(model, **kwargs):
"""
This function simulates 2:4 sparsity on all linear layers in a model.
It uses the torch.ao.pruning flow.
"""
filter_fn = kwargs.pop("filter_fn", _is_linear)
# torch.ao.pruning flow
sparse_config = []
for name, mod in model.named_modules():
if filter_fn(mod, name):
sparse_config.append({"tensor_fqn": f"{name}.weight"})

sparsifier = WeightNormSparsifier(
sparsity_level=0.5, sparse_block_shape=(64, 64)
)
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()


class TestSemiStructuredSparse(TestCase):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
Expand Down Expand Up @@ -73,5 +95,70 @@ def test_quant_semi_sparse(self):

assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)



class TestBlockSparseWeight(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_sparse(self):
input = torch.rand((1024, 1024)).half().cuda()
model = (
nn.Sequential(
nn.Linear(1024, 2048),
nn.Linear(2048, 1024),
)
.half()
.cuda()
)

from torchao.sparsity.utils import create_block_sparse_tensor
M, N = model[0].weight.shape
model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
M, N = model[1].weight.shape
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
dense_result = model(input)

from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
sparsify_(model, block_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

class TestQuantBlockSparseWeight(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_sparse(self):
input = torch.rand((128, 128)).to(torch.bfloat16).cuda()
model = (
nn.Sequential(
nn.Linear(128, 256),
nn.Linear(256, 128),
)
.to(torch.bfloat16)
.cuda()
)

from torchao.sparsity.utils import create_block_sparse_tensor
M, N = model[0].weight.shape
model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) * torch.rand(M, N, dtype=torch.bfloat16).cuda()
print(model[0].weight)
M, N = model[1].weight.shape
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
print(model[1].weight)

model_copy = copy.deepcopy(model)

quantize_(model_copy, int8_dynamic_activation_int8_weight())
reference = model_copy(input)

from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType(), ))
sparse_result = model(input)

print(reference)
print(sparse_result)
assert torch.allclose(reference, sparse_result, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
unittest.main()
19 changes: 10 additions & 9 deletions torchao/_models/sam/benchmark.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# baseline
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
# int8 dynamic quant (all)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
# 2:4 sparsity (all)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
# 2:4 sparsity (mlp only)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
## int8 dynamic quant (all)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
## 2:4 sparsity (all)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
## 2:4 sparsity (mlp only)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
## int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_block_sparse
6 changes: 5 additions & 1 deletion torchao/_models/sam/eval_combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,11 @@ def mlp_only(mod, name):
mlp_lin2_only)
if not TORCH_VERSION_AT_LEAST_2_5:
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

elif compress == "int8_dynamic_quant_block_sparse":
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name
from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()), mlp_only)
else:
assert compress is None, f"Unsupported compress mode {compress}"

Expand Down
6 changes: 6 additions & 0 deletions torchao/_models/sam/results.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
cuda,vit_h,32,15172,18,22.787559123509425,43.88359431477336,0.5809962729163862,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15153,18,24.872293344547476,40.20537978333312,0.5821541984818872,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15640,19,24.64409232721636,40.5776762528853,0.5674436009126148,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.710537332827382,40.46856555691013,0.530554119734646,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.5429434697436,37.67479673608557,0.566992236284673,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
Loading