Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Loudly reject compression when the tensor isn't sparse enough #55

Merged
merged 2 commits into from
Feb 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions vllm/model_executor/layers/parameters/lazy_compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from typing import Type

from vllm.logger import init_logger

logger = init_logger(__name__)

is_magic_wand_available = importlib.util.find_spec("magic_wand") is not None

# These are types from magic_wand, but we only want to import if required
Expand Down Expand Up @@ -91,13 +95,37 @@ def wrap(e):
return rs

def compress(self) -> None:
density = torch.count_nonzero(
self.uncompressed_data).item() / numpy.prod(self.shape)

# only compress if we have sufficient sparsity (>=45%), currently
# this applies globally across all formats including 2:4
if (1 - density) < 0.45:
return
from magic_wand import SparseSemiStructuredStorageFormat

if self.storage_format_cls == SparseSemiStructuredStorageFormat:
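# Semi-structured sparsity assumes a 2:4 pattern, where each group of 4
# consecutive elements has at least 2 zeros. We need to validate that this
# pattern exists, so we check the whole tensor before committing to
# compression. NOTE: the check below is stricter than the definition and
# requires exactly 2 zeros in every group.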

# Count zeros in each group of 4
reshaped_tensor = self.uncompressed_data.view(-1, 4)
zeros = reshaped_tensor == 0
zeros_per_group = zeros.sum(dim=1)

# Check if each group has exactly 2 zeros
has_semi_structured_sparsity = torch.all(zeros_per_group == 2)

if not has_semi_structured_sparsity:
logger.warning(
f"Called compress() on tensor of shape {self.shape} but does not "
"have 2:4 sparsity, skipping compression")
return

else:
sparsity = 1 - (torch.count_nonzero(self.uncompressed_data).item()
/ numpy.prod(self.shape))

# Only compress if we have sufficient sparsity (>=45%)
if sparsity < 0.45:
logger.warning(
f"Called compress() on tensor of shape {self.shape} but only has "
f"{sparsity:.2}% sparsity, skipping compression")
return

if self.uncompressed_data is None:
raise ValueError(
Expand Down
Loading