|
8 | 8 | PerRow,
|
9 | 9 | quantize_,
|
10 | 10 | )
|
| 11 | +from torchao.quantization.quant_api import _float8_cutlass_quant |
11 | 12 |
|
12 | 13 | torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True
|
13 | 14 |
|
@@ -141,3 +142,68 @@ def srelu_linear(x):
|
141 | 142 | custom_output = reference_linear_copy(input_tensor)
|
142 | 143 |
|
143 | 144 | torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)
|
| 145 | + |
| 146 | + |
@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_sparse24_fp8_sm90_cutlass_gemm_eye(
    M=512, K=256, dtype=torch.float8_e4m3fn
) -> None:
    """Round-trip a packed sparse fp8 matrix through the CUTLASS GEMM.

    A semi-structured bf16 tensor of shape (M, K) is quantized to ``dtype``,
    packed with ``to_sparse_semi_structured_cutlass_sm9x_f8`` and multiplied
    by an identity matrix, first without and then with row/column scales.
    The GEMM output is compared against the unpacked fp8 data.

    Args:
        M: Number of rows of the sparse operand.
        K: Number of columns of the sparse operand (and size of the identity).
        dtype: fp8 dtype used for quantization.
    """
    torch.manual_seed(0)

    A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
    A_aqt = _float8_cutlass_quant(A_dense, dtype)
    A = A_aqt.tensor_impl.float8_data

    # NOTE: CUTLASS compression kernel expects the input to be *exactly*
    # 2:4 sparse already (e.g. it does not select the largest values)
    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
    assert torch.allclose(
        A_packed.float().sum(), A.float().sum()
    )  # Check all values are there

    # Check MM without scale: A @ I must reproduce A.
    eye = torch.eye(A.shape[1], device=A.device, dtype=A.dtype).T
    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
        A_packed, A_mdata, eye
    )
    assert torch.allclose(A.float(), A_reconstructed.float())

    # Check MM with scale: the result must equal A scaled per row (a_scale)
    # and per column (b_scale).
    b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32)
    a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32)
    # Use the same op name as the unscaled call above; the scaled call
    # previously went through a differently named (underscore-prefixed) op.
    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
        A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale
    )
    assert torch.allclose(
        A.float() * b_scale * a_scale, A_reconstructed.float(), rtol=0.01
    )
| 180 | + |
| 181 | + |
@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor(
    M=512, N=1024, K=256, dtype=torch.float8_e4m3fn
) -> None:
    """Compare the sparse fp8 CUTLASS GEMM against ``torch._scaled_mm``.

    Both operands are quantized to ``dtype`` with one scale per row. The
    first operand comes from ``create_semi_structured_tensor`` and is packed
    with ``to_sparse_semi_structured_cutlass_sm9x_f8``; the second is a
    random (N, K) tensor passed transposed. The two results must agree
    within ``rtol=0.01, atol=0.01``.

    Args:
        M: Rows of the first operand.
        N: Rows of the second operand before it is transposed.
        K: Shared inner dimension.
        dtype: fp8 dtype used for quantization.
    """

    def quantize_rows(tensor: torch.Tensor, fp8_dtype):
        # Scale each row so its largest magnitude lands on the fp8 maximum.
        fp8_max = torch.finfo(fp8_dtype).max
        row_absmax = torch.max(tensor.abs(), dim=1, keepdim=True).values
        scale = (row_absmax / fp8_max).float()
        return (tensor / scale).to(fp8_dtype), scale

    torch.manual_seed(0)

    lhs_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
    A, a_scale = quantize_rows(lhs_dense, dtype)

    rhs_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16)
    rhs_fp8, rhs_scale = quantize_rows(rhs_dense, dtype)
    # The second operand and its per-row scale are both passed transposed.
    B, b_scale = rhs_fp8.T, rhs_scale.T

    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
    out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
        A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale
    )
    # Dense reference on the same quantized inputs and scales.
    out_ref = torch._scaled_mm(
        A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype
    )
    assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01)
0 commit comments