Commit d63e657

Update
[ghstack-poisoned]
2 parents fac3263 + d3306b2 commit d63e657

File tree

13 files changed: +645 -80 lines changed

examples/sam2_amg_server/generate_data.py

Lines changed: 3 additions & 0 deletions

@@ -60,6 +60,8 @@ def latencies_statistics(data):
     mean = np.mean(data_array)
     # Calculate the median
     median = np.median(data_array)
+    # Calculate the 90th percentile
+    p90 = np.percentile(data_array, 90)
     # Calculate the 95th percentile
     p95 = np.percentile(data_array, 95)
     # Calculate the 99th percentile
@@ -74,6 +76,7 @@ def latencies_statistics(data):
         {
             "mean": mean,
             "median": median,
+            "p90": p90,
             "p95": p95,
             "p99": p99,
             "p999": p999,
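
For reference, a minimal sketch of the latency summary this hunk extends. The names mirror the diff (`latencies_statistics`, `data_array`, the `p*` keys); the `p999` computation and the exact return structure are not visible in the hunk, so treat those parts as assumptions.

import numpy as np

def latencies_statistics(data):
    # Summarize per-request latencies; `data` is assumed to be a flat list of floats.
    data_array = np.array(data)
    return {
        "mean": np.mean(data_array),
        "median": np.median(data_array),
        "p90": np.percentile(data_array, 90),     # the percentile added by this commit
        "p95": np.percentile(data_array, 95),
        "p99": np.percentile(data_array, 99),
        "p999": np.percentile(data_array, 99.9),  # assumed to mean the 99.9th percentile
    }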

examples/sam2_amg_server/result.csv

Lines changed: 70 additions & 70 deletions
Large diffs are not rendered by default.

setup.py

Lines changed: 5 additions & 4 deletions

@@ -215,10 +215,7 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-            "-t=0",
-        ],
+        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
     }
 
     if not IS_WINDOWS:
@@ -257,12 +254,16 @@ def get_extensions():
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
+        cutlass_tools_include_dir = os.path.join(
+            cutlass_dir, "tools", "util", "include"
+        )
         cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
         if use_cutlass:
             extra_compile_args["nvcc"].extend(
                 [
                     "-DTORCHAO_USE_CUTLASS",
                     "-I" + cutlass_include_dir,
+                    "-I" + cutlass_tools_include_dir,
                     "-I" + cutlass_extensions_include_dir,
                 ]
             )
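
For context, a hedged sketch of how nvcc flags and CUTLASS include paths like the ones above are typically wired into a `torch.utils.cpp_extension.CUDAExtension`. This is not the repo's full `get_extensions()`: the extension name and source path below are hypothetical, and only the flags shown in the diff are taken from the commit.

import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cwd = os.path.dirname(os.path.abspath(__file__))
cutlass_dir = os.path.join(cwd, "third_party", "cutlass")

ext = CUDAExtension(
    name="torchao_cutlass_sketch",                      # hypothetical extension name
    sources=["torchao/csrc/cuda/example_kernel.cu"],    # hypothetical source path
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": [
            "-O3",
            "-t=0",        # nvcc --threads: 0 = one compilation thread per CPU
            "-std=c++17",  # added in this commit; recent CUTLASS headers need C++17
            "-DTORCHAO_USE_CUTLASS",
            "-I" + os.path.join(cutlass_dir, "include"),
            # added in this commit: CUTLASS tools/util headers (host-side helpers)
            "-I" + os.path.join(cutlass_dir, "tools", "util", "include"),
        ],
    },
)

setup(
    name="torchao_cutlass_sketch",
    ext_modules=[ext],
    cmdclass={"build_ext": BuildExtension},
)
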
Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+import pytest
+import torch
+
+from torchao.float8.float8_utils import compute_error
+from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
+
+if not TORCH_VERSION_AT_LEAST_2_4:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+def run_matrix_test(M: int, K: int, N: int, format) -> float:
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    a = torch.rand((M, K), dtype=dtype, device=device)
+    b = torch.rand((N, K), dtype=dtype, device=device)
+
+    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)
+
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)
+
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
+
+    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
+        -1, -2
+    )
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)
+
+    return compute_error(out_hp, out).item()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+)
+@pytest.mark.parametrize(
+    "size",
+    [
+        (128, 128, 128),
+        (256, 256, 256),
+        (384, 384, 384),  # Small
+        (512, 512, 512),
+        (768, 768, 768),  # Medium
+        (1024, 1024, 1024),
+        (8192, 8192, 8192),  # Large
+        (128, 256, 384),
+        (256, 384, 512),  # Non-square
+        (129, 256, 384),
+        (133, 512, 528),  # Non-aligned
+    ],
+    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
+)
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
+    M, K, N = size
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert (
+        sqnr >= threshold
+    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"

third_party/cutlass

Submodule cutlass updated 361 files
