16
16
get_name_to_shapes_iter ,
17
17
)
18
18
19
- from torchao .float8 .config import ScalingGranularity
20
19
from torchao .testing .float8 .roofline_utils import get_specs
21
20
22
21
@@ -53,15 +52,17 @@ def do_benchmarks(
53
52
@torch .inference_mode ()
54
53
def run (
55
54
n_limit : Optional [int ] = None ,
56
- shape_gen_name : str = "llama " ,
55
+ shape_gen_name : str = "pow2_extended " ,
57
56
out_filename : Optional [str ] = None ,
58
57
M : Optional [int ] = None ,
59
58
K : Optional [int ] = None ,
60
59
N : Optional [int ] = None ,
61
- use_gpu_kernel_time : bool = False ,
62
- scaling_granularity : str = "tensorwise" ,
60
+ use_gpu_kernel_time : bool = True ,
61
+ recipe : str = "tensorwise" ,
63
62
):
64
63
device = "cuda"
64
+ # TODO(future PR): this is ugly
65
+ assert recipe in ("tensorwise" , "rowwise" , "mxfp8_cublas" ), "unsupported"
65
66
66
67
specs = get_specs ()
67
68
bf16_peak_tops = specs ["bf16_peak_tops" ]
@@ -84,7 +85,6 @@ def run(
84
85
dtype = torch .bfloat16
85
86
name_to_shapes = get_name_to_shapes_iter (shape_gen_name , M , K , N )
86
87
fast_accum_vals = [True , False ]
87
- scaling_granularity = ScalingGranularity (scaling_granularity )
88
88
89
89
for idx , (fast_accum , (name , (M , K , N ))) in enumerate (
90
90
itertools .product (fast_accum_vals , name_to_shapes )
@@ -112,13 +112,17 @@ def run(
112
112
d1 , d2 , d3 = torch .float8_e4m3fn , torch .float8_e4m3fn , dtype
113
113
A = torch .zeros (M , K , device = device , dtype = d1 )
114
114
B = torch .zeros (K , N , device = device , dtype = d2 ).t ().contiguous ().t ()
115
- if scaling_granularity == ScalingGranularity . TENSORWISE :
115
+ if recipe == "tensorwise" :
116
116
scale_a = torch .tensor ([1.0 ], device = device )
117
117
scale_b = torch .tensor ([1.0 ], device = device )
118
- else :
119
- assert scaling_granularity == ScalingGranularity .AXISWISE , "unsupported"
118
+ elif recipe == "rowwise" :
120
119
scale_a = torch .ones (M , 1 , device = device )
121
120
scale_b = torch .ones (1 , N , device = device )
121
+ elif recipe == "mxfp8_cublas" :
122
+ scale_a = torch .ones (M , K // 32 , device = device , dtype = torch .float8_e8m0fnu )
123
+ scale_b = torch .ones (N , K // 32 , device = device , dtype = torch .float8_e8m0fnu )
124
+ else :
125
+ assert False , f"unknown recipe { recipe } "
122
126
123
127
def do_matmul (A , B ):
124
128
nonlocal scale_a
0 commit comments