# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import itertools
+from enum import IntEnum
from typing import Optional

import fire
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

-dtype_to_peak_tops = {
+# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
+# note: divided the numbers from ^ by 2 to undo the effects of sparsity
+# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
+# something seems funky
+b200_peak_flops_float32 = 600e12
+b200_peak_flops_fp16_tc = 18e15
+b200_peak_tops_float8_tc = 36e15
+b200_peak_tops_float4_tc = 72e15
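+# (possible explanation for the TODO above, unverified: the HGX page reports
+# aggregate numbers for an 8-GPU system, so per-GPU peaks would be roughly 8x
+# lower than the values above)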
+
+dtype_to_peak_tops_h100 = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}

+dtype_to_peak_tops_b200 = {
+    torch.float32: b200_peak_flops_float32,
+    torch.float16: b200_peak_flops_fp16_tc,
+    torch.bfloat16: b200_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: b200_peak_tops_float8_tc,
+    torch.float8_e5m2: b200_peak_tops_float8_tc,
+    # TODO float4
+}
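+# (note: there is currently no torch float4 dtype to use as a key above, which
+# is presumably why FP4 is left as a TODO in this table; see also the fp4 TODO
+# below)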
+
+# TODO(this PR): switch automatically by detected hardware type
+# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
+dtype_to_peak_tops = dtype_to_peak_tops_b200
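+
+# A minimal sketch for the hardware-detection TODO above (illustrative only and
+# not wired in below; assumes the CUDA device name is enough to tell B200 from
+# H100):
+def _detect_dtype_to_peak_tops():
+    # pick the peak-TOPS table that matches the current GPU
+    device_name = torch.cuda.get_device_name(0)
+    if "B200" in device_name:
+        return dtype_to_peak_tops_b200
+    return dtype_to_peak_tops_h100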
+
+
+# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
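+    # (assumption: these mirror the prototype scaled-mm's dtype codes, i.e.
+    # E8M0 = exponent-only 8-bit block-scale format used by MX dtypes,
+    # FP4 = 4-bit float (e2m1), UFP8 = unsigned 8-bit float)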
+

def benchmark_fn_in_sec(f, *args, **kwargs):
    # Manual warmup
@@ -75,6 +106,7 @@ def run(
    N: Optional[int] = None,
    use_gpu_kernel_time: bool = False,
    scaling_granularity: str = "tensorwise",
+    blockwise_dtype: Optional[str] = None,
):
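+    # blockwise_dtype selects the element dtype for the "blockwise" scaling path
+    # below ("float8_e4m3" or "float4") and is ignored for other granularities.
+    # Hypothetical invocation (script name and flag spelling assumed from this
+    # file's fire-based CLI):
+    #   python bench_matmul.py --scaling_granularity blockwise --blockwise_dtype float4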
    device = "cuda"

@@ -85,15 +117,17 @@ def run(
        "K",
        "N",
        "ref_time_s",
-        "fp8_time_s",
-        "fp8_speedup",
+        "lowp_time_s",
+        "lowp_speedup",
    )
    results = []

    dtype = torch.bfloat16
    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
    fast_accum_vals = [True, False]
-    scaling_granularity = ScalingGranularity(scaling_granularity)
+    # Note: "blockwise" is not in the ScalingGranularity enum because blockwise
+    # scaling is still a prototype feature
+    if scaling_granularity != "blockwise":
+        scaling_granularity = ScalingGranularity(scaling_granularity)

    for idx, (fast_accum, (name, (M, K, N))) in enumerate(
        itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@ def run(
        # raw float8 matmul (upper bound for what we can achieve in eager mode)
        # TODO(future): add e5m2
        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A = torch.randn(M, K, device=device).to(d1)
+        B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
        if scaling_granularity == ScalingGranularity.TENSORWISE:
            scale_a = torch.tensor([1.0], device=device)
            scale_b = torch.tensor([1.0], device=device)
-        else:
-            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+        elif scaling_granularity == ScalingGranularity.AXISWISE:
            scale_a = torch.ones(M, 1, device=device)
            scale_b = torch.ones(1, N, device=device)
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 32
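+            # A and B are random bytes viewed as e4m3; the scales are one uint8
+            # per 32-element block along K, interpreted as e8m0 by the matmul
+            # call below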
+            A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 16
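+            # fp4 has no torch dtype here, so A and B pack two 4-bit values per
+            # byte (hence K // 2 columns) in an 8-bit container dtype; the scales
+            # are one uint8 (e8m0) per 16-element block of the logical K dimension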
+            A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        else:
+            raise AssertionError(f"unsupported granularity {scaling_granularity}")
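+        # note: in the blockwise paths the data and scales are random bits, so the
+        # output is not numerically meaningful; this only measures kernel time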

        def do_matmul(A, B):
            nonlocal scale_a
            nonlocal scale_b
-            return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-            )
+
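+            # note: the a_dtype/b_dtype/scale_dtype keyword arguments used below
+            # are not part of the stock torch._scaled_mm signature; they match the
+            # prototype referenced in the "not for land" comment near the DataType
+            # enum above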
+            if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=None,  # inferred from A
+                    b_dtype=None,  # inferred from B
+                    scale_dtype=DataType.E8M0,
+                )
+            elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=DataType.FP4,
+                    b_dtype=DataType.FP4,
+                    scale_dtype=DataType.E8M0,
+                )
+
+            else:
+                return torch._scaled_mm(
+                    A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+                )
+
+        # test
+        # res = do_matmul(A, B)

        fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
            tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
        )
        print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
        )

        del A, B, scale_a, scale_b