From 289c2e6e90592d2f31c6db9c6a22811198543b94 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 14 Feb 2024 17:53:05 -0500 Subject: [PATCH] Sparse fused gemm integration (#12) Summary: Initial integration for the sparse-fused gemm. To achieve this, we need to ensure that we compress the weight matrix only once and never decompress it, as decompression is currently unsupported. Before this change, using `SparseParameter(SparseTensor)` meant that in `MergedColumnParallelLinear` and `QKVParallelLinear` every time a new shard was loaded by the `weight_loader` (e.g., the "q" portion of `QKVParallelLinear`), we would decompress the tensor in-order to use narrow to update the appropriate section of the weight tensor. With this change, `SparseParameter(SparseTensor)` is replaced with `LazyCompressedParameter`, which allows us to operate on `uncompressed_data` until we explicitly compress it. At that point, the `uncompressed_data` is compressed into `compressed_data` and freed. Currently, the detection of when to call compress is somewhat hacky. For `QKVParallelLinear`, we compress only after inserting "q", "k", and "v" shard ids, and for `MergedColumnParallelLinear`, we compress once we've inserted the same number of shards as outputs (determined by `len(output_sizes)`), which implicitly assumes one shard per output. Moving away from `SparseParameter(SparseTensor)` means that `SparseTensor` no longer handles dispatching to the custom ops; instead, this is handled by `SparseW16A16LinearMethod`. I believe this is a positive change overall. `SparseTensor` was an unnecessary extra layer of abstraction/indirection originally designed for the SLoRA work, not vLLM. This did result in the 2:4 sparse implementation breaking. However, it turns out it was already broken (i.e., it was decompressing and running dense within `SparseTensor`), so we "disable" it for now ("disable" meaning decompress and run dense instead). We should revisit all of this infrastructure post-MVP. --------- Co-authored-by: Andrew Feldman --- vllm/model_executor/layers/linear.py | 54 ++++++------- .../layers/parameters/__init__.py | 13 +--- .../layers/parameters/lazy_compressed.py | 78 +++++++++++++++++++ .../layers/parameters/sparsity.py | 66 ---------------- .../layers/sparsity/sparse_w16a16.py | 4 +- .../sparsity/sparse_w16a16_linear_method.py | 47 +++++++---- vllm/model_executor/weight_utils.py | 9 +-- 7 files changed, 148 insertions(+), 123 deletions(-) create mode 100644 vllm/model_executor/layers/parameters/lazy_compressed.py delete mode 100644 vllm/model_executor/layers/parameters/sparsity.py diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5f4e47a05403f..49e05922443d2 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,7 +13,7 @@ divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs from vllm.logger import init_logger -from vllm.model_executor.layers.parameters import SparseParameter, get_param_data +from vllm.model_executor.layers.parameters import LazyCompressedParameter logger = init_logger(__name__) @@ -192,7 +192,7 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) - param_data = get_param_data(param) + param_data = param.data if output_dim is not None: shard_size = param_data.shape[output_dim] @@ -202,9 +202,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If SparseParameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + if isinstance(param, LazyCompressedParameter): + param.compress() def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -253,6 +252,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): self.output_sizes = output_sizes + self.loaded_shards = set() tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, @@ -262,14 +262,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): - param_data = get_param_data(param) + param_data = param.data output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: - if isinstance(param, SparseParameter): - raise NotImplementedError( - "Passing loaded_shard_id=None not yet supported for SparseParameter" - ) - # Loaded weight is already packed. if output_dim is None: assert param_data.shape == loaded_weight.shape @@ -316,12 +311,17 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + + self.loaded_shards.add(loaded_shard_id) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If Parameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + # This is super hacky for now but we basically want to only compress once all + # of the shards are loaded, right now we just check if the number of shards + # loaded matches the number of outputs expected, assuming one shard per output + all_shards_loaded = (len(self.loaded_shards) == len(self.output_sizes)) + if all_shards_loaded and isinstance(param, LazyCompressedParameter): + param.compress() class QKVParallelLinear(ColumnParallelLinear): @@ -365,6 +365,7 @@ def __init__( if total_num_kv_heads is None: total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads + self.loaded_shards = set() # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() self.num_heads = divide(self.total_num_heads, tp_size) @@ -385,14 +386,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): - param_data = get_param_data(param) + param_data = param.data output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: - if isinstance(param, SparseParameter): - raise NotImplementedError( - "Passing loaded_shard_id=None not yet supported for SparseParameter" - ) - # Loaded weight is already packed. if output_dim is None: assert param_data.shape == loaded_weight.shape @@ -456,9 +452,14 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If SparseParameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + self.loaded_shards.add(loaded_shard_id) + + # This is super hacky for now but we basically want to only compress once + # all of the shards are loaded, for the QKV matrix this means + # loading shards "q", "k" and "v" + all_shards_loaded = (self.loaded_shards == set(["q", "k", "v"])) + if all_shards_loaded and isinstance(param, LazyCompressedParameter): + param.compress() class RowParallelLinear(torch.nn.Module): @@ -540,7 +541,7 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) - param_data = get_param_data(param) + param_data = param.data if input_dim is not None: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size @@ -549,9 +550,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If SparseParameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + if isinstance(param, LazyCompressedParameter): + param.compress() def forward(self, input_): # Set up backprop all-reduce. diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py index 2d41190087a0d..c05cdf56e27a4 100644 --- a/vllm/model_executor/layers/parameters/__init__.py +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -1,10 +1,5 @@ -import torch -from vllm.model_executor.layers.parameters.sparsity import SparseParameter +from vllm.model_executor.layers.parameters.lazy_compressed import LazyCompressedParameter - -def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: - """Gets parameter data in dense format.""" - if isinstance(param, SparseParameter): - return param.get_dense_data() - else: - return param.data +__all__ = [ + "LazyCompressedParameter", +] diff --git a/vllm/model_executor/layers/parameters/lazy_compressed.py b/vllm/model_executor/layers/parameters/lazy_compressed.py new file mode 100644 index 0000000000000..96e892a03d1fb --- /dev/null +++ b/vllm/model_executor/layers/parameters/lazy_compressed.py @@ -0,0 +1,78 @@ +import numpy +import torch +from torch.utils._pytree import tree_map + +from typing import Type +from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat) + + +class LazyCompressedParameter(torch.Tensor): + + @staticmethod + def __new__(cls, + uncompressed_data: torch.Tensor, + storage_format_cls: Type[ + CompressedStorageFormat] = SparseBitmaskStorageFormat, + compress_transposed: bool = False): + self = torch.Tensor._make_wrapper_subclass( + cls, + size=uncompressed_data.shape, + dtype=uncompressed_data.dtype, + requires_grad=False) + self.storage_format_cls = storage_format_cls + self.compressed_data = None + self.uncompressed_data = uncompressed_data + self.compress_transposed = compress_transposed + self._is_param = True + + return self + + @property + def has_compressed_data(self) -> bool: + return (self.compressed_data is not None) + + @property + def has_uncompressed_data(self) -> bool: + return (self.uncompressed_data is not None) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + ret_storage_format_cls = None + + def unwrap(e): + nonlocal ret_storage_format_cls + if isinstance(e, LazyCompressedParameter): + assert ret_storage_format_cls is None or ret_storage_format_cls == e.storage_format_cls + ret_storage_format_cls = e.storage_format_cls + return e.uncompressed_data if isinstance( + e, LazyCompressedParameter) else e + + rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + + def wrap(e): + if isinstance(e, + torch.Tensor) and ret_storage_format_cls is not None: + return LazyCompressedParameter( + e, storage_format_cls=ret_storage_format_cls) + return e + + rs = tree_map(wrap, rs) + return rs + + def compress(self) -> None: + density = torch.count_nonzero( + self.uncompressed_data).item() / numpy.prod(self.shape) + + # only compress if we have sufficient sparsity (>=45%), currently + # this applies globally across all formats including 2:4 + if (1 - density) < 0.45: + return + + if self.uncompressed_data is None: + raise ValueError( + "Called compress() but uncompressed_data does not exist.") + self.compressed_data = self.storage_format_cls.compress( + self.uncompressed_data.t( + ) if self.compress_transposed else self.uncompressed_data) + del self.uncompressed_data # free memory + self.uncompressed_data = None diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py deleted file mode 100644 index 017fb6b825965..0000000000000 --- a/vllm/model_executor/layers/parameters/sparsity.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch - -from typing import Type -from magic_wand import (SparseTensor, CompressedStorageFormat, - SparseBitmaskStorageFormat) - - -class SparseParameter(SparseTensor): - - @staticmethod - def __new__(cls, - shape: torch.Size, - dtype: torch.dtype, - storage_format_cls: Type[ - CompressedStorageFormat] = SparseBitmaskStorageFormat): - assert torch.__version__ > (1, - 10), "SparseTensor requires PyTorch 1.11+" - - self = torch.Tensor._make_wrapper_subclass(cls, - size=shape, - dtype=dtype, - requires_grad=False) - self.storage_format_cls = storage_format_cls - self.compressed_data = None - self.dense_data = None - self._is_param = True - - return self - - def has_compressed_data(self) -> bool: - return (self.compressed_data is not None) - - def get_dense_data(self) -> torch.Tensor: - if self.dense_data is not None: - raise ValueError( - "Called get_data_dense() but dense_data already exists.") - self.dense_data = self._unpack() - return self.dense_data - - def _unpack(self) -> torch.Tensor: - if self.has_compressed_data(): - return self.compressed_data.decompress() - else: - return torch.empty(size=self.shape, - dtype=self.dtype, - device="cuda") - - @classmethod - def _copy(cls, arg0, arg1): - assert arg0.shape == arg1.shape - - if arg0.has_compressed_data(): - arg0.compressed_data.copy_(arg1) - else: - arg0.compressed_data = arg0.storage_format_cls.compress(arg1) - - return arg0 - - def copy_(self, src, non_blocking=False): - return SparseParameter._copy(self, src) - - def pack(self) -> None: - if self.dense_data is None: - raise ValueError("Called pack() but dense_data does not exist.") - self.copy_(self.dense_data) - self.dense_data = None diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py index 69905eab0c0af..d3a93d9b1d945 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -5,7 +5,7 @@ from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from .sparse_w16a16_linear_method import SparseW16A16LinearMethod -from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat) +from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat) class SparseW16A16Config(SparsityConfig): @@ -23,7 +23,7 @@ def __repr__(self) -> str: @classmethod def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]: - return SparseBitmaskStorageFormat + return SparseBEGemmStorageFormat @classmethod def get_name(cls) -> str: diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index e2fecda663b60..65713a1bf15b3 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -5,9 +5,9 @@ from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig -from vllm.model_executor.layers.parameters import SparseParameter -from magic_wand import (CompressedStorageFormat, - SparseSemiStructuredStorageFormat) +from vllm.model_executor.layers.parameters import LazyCompressedParameter +from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat) +from magic_wand.ops import be_ds_gemm class SparseW16A16LinearMethod(LinearMethodBase): @@ -27,10 +27,15 @@ def create_weights(self, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - weight = SparseParameter(shape=torch.Size( - (output_size_per_partition, input_size_per_partition)), - dtype=params_dtype, - storage_format_cls=self.storage_format_cls) + supports_linear = (self.storage_format_cls != + SparseBEGemmStorageFormat) + weight = LazyCompressedParameter( + torch.empty((output_size_per_partition, input_size_per_partition), + dtype=params_dtype), + storage_format_cls=self.storage_format_cls, + # if we don't support F.linear or something analogous, + # transpose when we compress so we can use a basic matmul + compress_transposed=not supports_linear) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ -42,14 +47,28 @@ def apply_weights( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - sparse_weight = weights["weight"] + w: LazyCompressedParameter = weights["weight"] - if self.storage_format_cls == SparseSemiStructuredStorageFormat: - output = F.linear(x, sparse_weight, bias) - return output + # if we never compressed (likely due to insufficient sparsity), + # i.e. have uncompressed_data run normally + if w.has_uncompressed_data: + assert not w.has_compressed_data + output = F.linear(x, w.uncompressed_data, bias) + # The current 2:4 implementation was running dense so ignore it + # for now and instead just explicitly decompress as usual + # elif self.storage_format_cls == SparseSemiStructuredStorageFormat: + # assert bias is None + # raise NotImplementedError + elif self.storage_format_cls == SparseBEGemmStorageFormat: + assert bias is None + assert w.compress_transposed + out_shape = (x.shape[:-1] + (w.shape[0], )) + reshaped_x = x.reshape(-1, x.shape[-1]) + y = be_ds_gemm(reshaped_x, w.compressed_data) + return y.reshape(out_shape) else: - # Standard matrix multiply # Uncompress to dense - output = F.linear(x, sparse_weight.to_dense(), bias) - return output + assert not w.compress_transposed + output = F.linear(x, w.compressed_data.decompress(), bias) + return output diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index f29b70ac26051..23c352c664d4b 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -19,8 +19,7 @@ QuantizationConfig) from vllm.model_executor.layers.sparsity import (get_sparsity_config, SparsityConfig) -from vllm.model_executor.layers.parameters import (get_param_data, - SparseParameter) +from vllm.model_executor.layers.parameters import LazyCompressedParameter logger = init_logger(__name__) @@ -299,9 +298,9 @@ def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() - get_param_data(param).copy_(loaded_weight) - if isinstance(param, SparseParameter): - param.pack() + param.data.copy_(loaded_weight) + if isinstance(param, LazyCompressedParameter): + param.compress() def initialize_dummy_weights(