
Commit 16f8c5d

Merge remote-tracking branch 'origin/main' into fix_h100

2 parents: 82acc5b + 446f07d

32 files changed: +626 -441 lines

.github/workflows/regression_test.yml

Lines changed: 2 additions & 2 deletions
@@ -64,7 +64,7 @@ jobs:
            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
            gpu-arch-type: "cuda"
            gpu-arch-version: "12.6"
-           dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/"
+           dev-requirements-overrides: "s/^pytest.*$/pytest==7.4.0/"
          - name: CUDA 2.6
            runs-on: linux.g5.12xlarge.nvidia.gpu
            torch-spec: 'torch==2.6.0'
@@ -83,7 +83,7 @@ jobs:
            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu'
            gpu-arch-type: "cpu"
            gpu-arch-version: ""
-           dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/"
+           dev-requirements-overrides: "s/^pytest.*$/pytest==7.4.0/"
          - name: CPU 2.6
            runs-on: linux.4xlarge
            torch-spec: 'torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu'
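
Note: dev-requirements-overrides appears to be a sed substitution applied to dev-requirements.txt for older torch builds. Since dev-requirements.txt now pins pytest==8.3.4 (see below), the anchored pattern ^pytest$ no longer matches, so it is loosened to ^pytest.*$. A minimal sketch of the difference in plain Python (re.sub standing in for the CI's sed):

import re

line = "pytest==8.3.4"
print(re.sub(r"^pytest$", "pytest==7.4.0", line))    # no match: prints pytest==8.3.4
print(re.sub(r"^pytest.*$", "pytest==7.4.0", line))  # match: prints pytest==7.4.0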

README.md

Lines changed: 98 additions & 97 deletions
Large diffs are not rendered by default.

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # Test utilities
-pytest
+pytest==8.3.4
 unittest-xml-reporting
 parameterized
 packaging

e2e_fp8_sparse.csv

Lines changed: 0 additions & 8 deletions
This file was deleted.

rowwise_scaled_linear_sparse_cutlass_time_results.csv

Lines changed: 0 additions & 4 deletions
This file was deleted.

setup.py

Lines changed: 1 addition & 0 deletions
@@ -433,6 +433,7 @@ def get_extensions():
                 "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
             ),
             os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
+            os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"),
         ]
         for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
             cutlass_90a_sources.append(
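
Note: the new sparse_gemm.cu is appended to cutlass_90a_sources so it gets compiled with the SM90a kernels. A minimal, hypothetical sketch (names are illustrative, not the repo's actual get_extensions() logic) of how such a source list is typically handed to the PyTorch build machinery:

from torch.utils.cpp_extension import CUDAExtension

ext = CUDAExtension(
    name="torchao._C_cutlass_90a",  # hypothetical extension name
    sources=["torchao/csrc/cuda/activation24/sparse_gemm.cu"],  # from this diff
    extra_compile_args={"nvcc": ["-gencode=arch=compute_90a,code=sm_90a"]},
)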

test/dtypes/test_affine_quantized.py

Lines changed: 18 additions & 0 deletions
@@ -424,6 +424,24 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         # making sure param.data is updated
         assert param.data.dequantize()[0][0] != 0

+    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("dtype", [torch.bfloat16])
+    @skip_if_no_cuda()
+    @skip_if_rocm("ROCm enablement in progress")
+    def test_mm_int4wo(self, device, dtype):
+        weight = torch.randn(512, 1024).to(device).to(dtype)
+        weight = weight.t()
+
+        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
+        l.weight = torch.nn.Parameter(weight)
+        quantize_(l, Int4WeightOnlyConfig())
+        # weight shape: 1024 x 512
+        weight = l.weight
+
+        input = torch.randn(1, 512, device=device, dtype=dtype)
+        # make sure it runs
+        torch.nn.functional.linear(input, weight)
+

 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
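
Note: the new test hands nn.Linear an already-transposed weight, presumably so the quantized tensor subclass is exercised through the plain mm/addmm dispatch rather than the contiguous fast path. A minimal sketch (plain tensors, no quantization) of the identity it relies on:

import torch

x = torch.randn(1, 512)
W = torch.randn(1024, 512)
# F.linear(x, W) computes x @ W.T
assert torch.allclose(torch.nn.functional.linear(x, W), x @ W.t())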

test/dtypes/test_affine_quantized_float.py

Lines changed: 21 additions & 0 deletions
@@ -27,6 +27,7 @@

 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
+    Float8DynamicActivationFloat8WeightConfig,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -308,6 +309,26 @@ def test_fp8_weight_dimension_warning(self):
             f"Expected warning message containing: {expected}",
         )

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_mm_float8dq(self):
+        device = "cuda"
+        dtype = torch.bfloat16
+        weight = torch.randn(512, 1024).to(device).to(dtype)
+        weight = weight.t()
+
+        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
+        l.weight = torch.nn.Parameter(weight)
+        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+        # weight shape: 1024 x 512
+        weight = l.weight
+
+        input = torch.randn(1, 512, device=device, dtype=dtype)
+        # make sure it runs
+        torch.nn.functional.linear(input, weight)
+

 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
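
Note: Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) quantizes with one scale per weight row. A minimal sketch of per-row float8 scaling in that spirit (not torchao's actual implementation):

import torch

w = torch.randn(1024, 512, dtype=torch.bfloat16)
scale = w.abs().amax(dim=1, keepdim=True).float() / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / scale).to(torch.float8_e4m3fn)  # quantize: one scale per row
w_approx = w_fp8.to(torch.float32) * scale   # dequantize back to high precision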

test/quantization/test_config_serialization.py

Lines changed: 0 additions & 2 deletions
@@ -63,8 +63,6 @@
     GemliteUIntXWeightOnlyConfig(
         group_size=128,  # Optional, has default of 64
         bit_width=8,  # Optional, has default of 4
-        packing_bitwidth=8,  # Optional, has default of 32
-        contiguous=True,  # Optional, has default of None
     ),
     FPXWeightOnlyConfig(ebits=4, mbits=8),
     # Sparsity configs

test/sparsity/test_activation24.py

Lines changed: 66 additions & 0 deletions
@@ -8,6 +8,7 @@
     PerRow,
     quantize_,
 )
+from torchao.quantization.quant_api import _float8_cutlass_quant

 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True

@@ -141,3 +142,68 @@ def srelu_linear(x):
     custom_output = reference_linear_copy(input_tensor)

     torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)
+
+
+@unittest.skipIf(not is_sm_at_least_90(), "Need CUDA arch at least SM90")
+def test_sparse24_fp8_sm90_cutlass_gemm_eye(
+    M=512, K=256, dtype=torch.float8_e4m3fn
+) -> None:
+    torch.manual_seed(0)
+
+    A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
+    A_aqt = _float8_cutlass_quant(A_dense, dtype)
+    A = A_aqt.tensor_impl.float8_data
+
+    # NOTE: the CUTLASS compression kernel expects the input to be *exactly*
+    # 2:4 sparse already (e.g. it does not select the largest values)
+    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
+    assert torch.allclose(
+        A_packed.float().sum(), A.float().sum()
+    )  # Check all values are there
+
+    # Check MM without scale
+    eye = torch.eye(A.shape[1], device=A.device, dtype=A.dtype).T
+    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, eye
+    )
+    assert torch.allclose(A.float(), A_reconstructed.float())
+
+    # Check MM with scale
+    b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32)
+    a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32)
+    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale
+    )
+    assert torch.allclose(
+        A.float() * b_scale * a_scale, A_reconstructed.float(), rtol=0.01
+    )
+
+
+@unittest.skipIf(not is_sm_at_least_90(), "Need CUDA arch at least SM90")
+def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor(
+    M=512, N=1024, K=256, dtype=torch.float8_e4m3fn
+) -> None:
+    def _to_fp8_rowwise(x: torch.Tensor, dtype):
+        max_v = torch.finfo(dtype).max
+        x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float()
+        x = (x / x_scale).to(dtype)
+        return x, x_scale
+
+    torch.manual_seed(0)
+    A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
+    A, a_scale = _to_fp8_rowwise(A_dense, dtype)
+
+    B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16)
+    B, b_scale = _to_fp8_rowwise(B_dense, dtype)
+
+    B = B.T
+    b_scale = b_scale.T
+
+    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
+    out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale
+    )
+    out_ref = torch._scaled_mm(
+        A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype
+    )
+    assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01)
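
Note: the eye test validates the packed 2:4 layout by multiplying with an identity matrix, since A @ I = A round-trips the compression. A minimal pure-torch sketch of the two ingredients (assumes nothing from torchao):

import torch

x = torch.randn(512, 256)
groups = x.view(-1, 4)
_, idx = groups.abs().topk(2, dim=1)                      # keep 2 largest of every 4
mask = torch.zeros_like(groups, dtype=torch.bool).scatter(1, idx, True)
a = (groups * mask).view_as(x)                            # exactly 2:4 sparse now

eye = torch.eye(a.shape[1])
assert torch.allclose(a @ eye, a)                         # identity MM reconstructs a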
