Skip to content

Commit 2dea315

Browse files
Fix failing FP6 benchmark (#931)
1 parent 7dff17a commit 2dea315

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

benchmarks/benchmark_fp6.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import torch
22
import pandas as pd
33
import torch.nn.functional as F
4-
from torchao.dtypes import to_affine_quantized_floatx
4+
from torchao.dtypes import to_affine_quantized_fpx
55
from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
66
from torchao.utils import benchmark_torch_function_in_microseconds
77
from tqdm import tqdm
88

99

1010
def benchmark(m: int, k: int, n: int):
1111
float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
12-
fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
12+
fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))
1313
fp16_weight = fp6_weight.dequantize(torch.half)
1414

1515
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")

0 commit comments

Comments
 (0)