Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 6143ab9

Browse files
afeldman-nm and LucasWilkinson
authored and committed
Abf149/fix semi structured sparse (#16)
SUMMARY: - Fix bug whereby 2:4 is not being invoked - Eschew SparseTensor based implementation TESTING: - examples/offline_inference_semi_structured_sparse.py --------- Co-authored-by: Lucas Wilkinson <wilkinson.lucas@gmail.com>
1 parent fbfd1aa commit 6143ab9

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
77
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
88
from vllm.model_executor.layers.parameters import LazyCompressedParameter
9-
from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)
9+
from magic_wand.semi_structured import (pad_tensor_to_multiple,
10+
extract_valid_rows)
11+
from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat,
12+
SparseSemiStructuredStorageFormat)
1013
from magic_wand.ops import be_ds_gemm
1114

1215

@@ -54,11 +57,18 @@ def apply_weights(
5457
if w.has_uncompressed_data:
5558
assert not w.has_compressed_data
5659
output = F.linear(x, w.uncompressed_data, bias)
57-
# The current 2:4 implementation was running dense so ignore it
58-
# for now and instead just explicitly decompress as usual
59-
# elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
60-
# assert bias is None
61-
# raise NotImplementedError
60+
elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
61+
assert bias is None
62+
w_encap = w.compressed_data.encapsulated_torch_sparse_tensor
63+
out_shape = (x.shape[:-1] + (w_encap.shape[0], ))
64+
reshaped_x, valid_rows_range = pad_tensor_to_multiple(
65+
x.reshape(-1, x.shape[-1]), 8)
66+
output = F.linear(
67+
reshaped_x, w_encap,
68+
torch.nn.Parameter(torch.zeros((w_encap.shape[0], ))).to(
69+
reshaped_x.dtype).to(reshaped_x.device)).contiguous()
70+
output = extract_valid_rows(output, valid_rows_range)
71+
return output.reshape(out_shape)
6272
elif self.storage_format_cls == SparseBEGemmStorageFormat:
6373
assert bias is None
6474
assert w.compress_transposed

0 commit comments

Comments (0)