@@ -6,7 +6,10 @@
 from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
 from vllm.model_executor.layers.parameters import LazyCompressedParameter
-from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)
+from magic_wand.semi_structured import (pad_tensor_to_multiple,
+                                        extract_valid_rows)
+from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat,
+                        SparseSemiStructuredStorageFormat)
 from magic_wand.ops import be_ds_gemm
 
 
@@ -54,11 +57,18 @@ def apply_weights(
         if w.has_uncompressed_data:
             assert not w.has_compressed_data
             output = F.linear(x, w.uncompressed_data, bias)
-        # The current 2:4 implementation was running dense so ignore it
-        # for now and instead just explicitly decompress as usual
-        # elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
-        #     assert bias is None
-        #     raise NotImplementedError
+        elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
+            assert bias is None
+            w_encap = w.compressed_data.encapsulated_torch_sparse_tensor
+            out_shape = (x.shape[:-1] + (w_encap.shape[0], ))
+            reshaped_x, valid_rows_range = pad_tensor_to_multiple(
+                x.reshape(-1, x.shape[-1]), 8)
+            output = F.linear(
+                reshaped_x, w_encap,
+                torch.nn.Parameter(torch.zeros((w_encap.shape[0], ))).to(
+                    reshaped_x.dtype).to(reshaped_x.device)).contiguous()
+            output = extract_valid_rows(output, valid_rows_range)
+            return output.reshape(out_shape)
         elif self.storage_format_cls == SparseBEGemmStorageFormat:
             assert bias is None
             assert w.compress_transposed
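The new branch flattens the activations, pads the row count up to a multiple of 8 before dispatching to the semi-structured kernel, then slices the padding back off and restores the original shape. Below is a minimal sketch of that pad-then-extract pattern in plain PyTorch; the exact semantics of the magic_wand helpers are not shown in this diff, so pad_rows_to_multiple here is a hypothetical stand-in that assumes pad_tensor_to_multiple zero-pads the row dimension and returns the range of valid rows, and a dense weight stands in for the 2:4 sparse tensor.

import torch
import torch.nn.functional as F

def pad_rows_to_multiple(t: torch.Tensor, mult: int):
    # Hypothetical stand-in for magic_wand.semi_structured.pad_tensor_to_multiple:
    # zero-pad the row dimension up to the next multiple of `mult` and remember
    # which rows are real.
    rows = t.shape[0]
    padded = -(-rows // mult) * mult  # ceil(rows / mult) * mult
    if padded != rows:
        t = F.pad(t, (0, 0, 0, padded - rows))  # append zero rows at the bottom
    return t, range(rows)

x = torch.randn(3, 16)                             # 3 rows: not a multiple of 8
w = torch.randn(32, 16)                            # dense stand-in for the 2:4 weight
padded_x, valid_rows = pad_rows_to_multiple(x, 8)  # now 8 rows
out = F.linear(padded_x, w)                        # kernel-friendly shape
out = out[:len(valid_rows)]                        # extract_valid_rows equivalent
assert out.shape == (3, 32)

The explicit zero bias built inline in the branch (torch.nn.Parameter(torch.zeros(...)) cast to the activation's dtype and device) is presumably there because the sparse linear path being dispatched to expects a bias operand even when the layer has none; note it is allocated on every forward call, a small per-call cost that keeps the compressed weight storage untouched.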