
Commit 4c6188f

[sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity (#2242)
* [sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity

  Summary: We already have this gemm in torchao, but for weight sparsity. For activation sparsity, the weights need to be stored in column-major format so that we can use the selective weight-loading kernel for decode.

* remove cutlass compression
* ruff fix
* one more ruff fix
* don't build for CUDA 11.8
* fix formatting
* ifdef to avoid issues
1 parent f0f976c commit 4c6188f
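For orientation, here is a minimal sketch of how the new op is driven end to end, mirroring the random-tensor test added in test/sparsity/test_activation24.py below. It is a sketch, not part of the commit: the exactly-2:4 masking stands in for the test's create_semi_structured_tensor helper, and the import path of to_sparse_semi_structured_cutlass_sm9x_f8 is not shown in this diff, so both are assumptions here.

import torch

def to_fp8_rowwise(x, dtype):
    # Per-row absmax scaling into fp8, same recipe as the test's _to_fp8_rowwise helper.
    max_v = torch.finfo(dtype).max
    x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float()
    return (x / x_scale).to(dtype), x_scale

M, N, K = 512, 1024, 256
fp8 = torch.float8_e4m3fn

# Build an activation that is exactly 2:4 sparse (the CUTLASS compression kernel
# requires this); here we zero the two smallest entries in every group of four,
# standing in for the test's create_semi_structured_tensor helper.
A_dense = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
groups = A_dense.view(M, K // 4, 4)
drop = groups.abs().argsort(dim=-1)[..., :2]
A_dense = groups.scatter(-1, drop, 0).view(M, K)

A, a_scale = to_fp8_rowwise(A_dense, fp8)
B, b_scale = to_fp8_rowwise(torch.randn(N, K, device="cuda", dtype=torch.bfloat16), fp8)

# Pack the sparse activation, then run the new rowwise-scaled sparse gemm.
# to_sparse_semi_structured_cutlass_sm9x_f8 is the packing helper the test uses;
# its import is defined elsewhere in the test file, not in this diff.
A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)

# The weight is passed transposed (column-major), which is what enables the
# selective weight-loading path for decode mentioned in the summary.
out = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
    A_packed, A_mdata, B.T, a_scale=a_scale, b_scale=b_scale.T
)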

File tree

4 files changed: +443 −0 lines changed


setup.py

Lines changed: 1 addition & 0 deletions
@@ -433,6 +433,7 @@ def get_extensions():
                 "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
             ),
             os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
+            os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"),
         ]
         for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
             cutlass_90a_sources.append(
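The new kernel source joins cutlass_90a_sources, i.e. the group of CUTLASS files that is only compiled when the toolchain can target SM90a (per the commit notes, it is skipped for CUDA 11.8). Below is a simplified, hypothetical sketch of that gating pattern, not the literal setup.py logic; the extension name and exact nvcc flag are illustrative.

import os
import torch
from torch.utils.cpp_extension import CUDAExtension

extensions_cuda_dir = os.path.join("torchao", "csrc", "cuda")
cutlass_90a_sources = [
    os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
    os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"),  # added by this commit
]

# Only compile the SM90a CUTLASS kernels with a new enough CUDA toolkit
# (the commit explicitly skips CUDA 11.8).
cuda_version = torch.version.cuda  # e.g. "12.4"; None for CPU-only builds
use_90a = cuda_version is not None and tuple(map(int, cuda_version.split("."))) >= (12, 0)

sources, nvcc_flags = [], []
if use_90a:
    sources += cutlass_90a_sources
    nvcc_flags += ["-gencode=arch=compute_90a,code=sm_90a"]

ext = CUDAExtension(
    name="torchao._C",  # illustrative extension name
    sources=sources,
    extra_compile_args={"nvcc": nvcc_flags},
)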

test/sparsity/test_activation24.py

Lines changed: 66 additions & 0 deletions
@@ -8,6 +8,7 @@
     PerRow,
     quantize_,
 )
+from torchao.quantization.quant_api import _float8_cutlass_quant
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True
 
@@ -141,3 +142,68 @@ def srelu_linear(x):
     custom_output = reference_linear_copy(input_tensor)
 
     torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)
+
+
+@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
+def test_sparse24_fp8_sm90_cutlass_gemm_eye(
+    M=512, K=256, dtype=torch.float8_e4m3fn
+) -> None:
+    torch.manual_seed(0)
+
+    A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
+    A_aqt = _float8_cutlass_quant(A_dense, dtype)
+    A = A_aqt.tensor_impl.float8_data
+
+    # NOTE: CUTLASS compression kernel expects the input to be *exactly*
+    # 2:4 sparse already (eg it does not select the largest values)
+    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
+    assert torch.allclose(
+        A_packed.float().sum(), A.float().sum()
+    )  # Check all values are there
+
+    # Check MM without scale
+    eye = torch.eye(A.shape[1], device=A.device, dtype=A.dtype).T
+    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, eye
+    )
+    assert torch.allclose(A.float(), A_reconstructed.float())
+
+    # Check MM with scale
+    b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32)
+    a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32)
+    A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale
+    )
+    assert torch.allclose(
+        A.float() * b_scale * a_scale, A_reconstructed.float(), rtol=0.01
+    )
+
+
+@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
+def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor(
+    M=512, N=1024, K=256, dtype=torch.float8_e4m3fn
+) -> None:
+    def _to_fp8_rowwise(x: torch.Tensor, dtype):
+        max_v = torch.finfo(dtype).max
+        x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float()
+        x = (x / x_scale).to(dtype)
+        return x, x_scale
+
+    torch.manual_seed(0)
+    A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
+    A, a_scale = _to_fp8_rowwise(A_dense, dtype)
+
+    B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16)
+    B, b_scale = _to_fp8_rowwise(B_dense, dtype)
+
+    B = B.T
+    b_scale = b_scale.T
+
+    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
+    out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale
+    )
+    out_ref = torch._scaled_mm(
+        A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype
+    )
+    assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01)
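One detail worth calling out in the random-tensor test: the B = B.T and b_scale = b_scale.T steps are what put the fp8 weight into the column-major layout the commit summary motivates. A quick, hypothetical sanity check of that layout (not part of the commit):

import torch

N, K = 1024, 256
B = torch.randn([N, K], device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
B_t = B.T  # K x N view over the same storage

# The transpose makes K the fastest-moving dimension in memory, i.e. the
# K x N operand is stored column-major, which is what the sparse gemm expects.
assert B_t.shape == (K, N)
assert B_t.stride() == (1, K)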
