@@ -48,39 +48,91 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
     return time_sec, tops_sec, pct_top_peak


+def get_name_to_shapes_iter(
+    shape_gen_name: str,
+    M: Optional[int],
+    K: Optional[int],
+    N: Optional[int],
+):
+    if shape_gen_name == 'llama':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        bsz, seq_len = 4, 4096
+        M = bsz * seq_len
+        # LLaMa 2 70B single-node weight shapes
+        # assumes fused attn.wqkv and ffn.w13
+        # source: https://fburl.com/gsheet/g8onr7rh
+        name_to_shapes_70b = {
+            "attn.wqkv": (M, 8192, 1280),
+            "attn.w0": (M, 1024, 8192),
+            "ffn.w13": (M, 8192, 7168),
+            "ffn.w2": (M, 3584, 8192),
+        }
+        return name_to_shapes_70b.items()
+
+    elif shape_gen_name == 'square':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        name_to_shapes = {}
+        min_power_of_2 = 5   # 32
+        max_power_of_2 = 16  # 65,536
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val = 2 ** power_of_2
+            name_to_shapes[idx] = val, val, val
+        return name_to_shapes.items()
+
+    elif shape_gen_name == 'sweep':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        name_to_shapes = {}
+        min_p2 = 5   # 32
+        max_p2 = 16  # 65,536
+        counter = 0
+        for M_p2 in range(min_p2, max_p2 + 1):
+            M = 2 ** M_p2
+            for K_p2 in range(min_p2, max_p2 + 1):
+                K = 2 ** K_p2
+                for N_p2 in range(min_p2, max_p2 + 1):
+                    N = 2 ** N_p2
+                    name_to_shapes[counter] = M, K, N
+                    counter += 1
+        return name_to_shapes.items()
+
+    elif shape_gen_name == 'custom':
+        assert M is not None and K is not None and N is not None, \
+            'M, K, N must be specified for custom shape_gen'
+        name_to_shapes = {
+            1: (M, K, N),
+        }
+        return name_to_shapes.items()
+
+    raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')
+
+
 @torch.inference_mode()
-def run(n_limit: Optional[int] = None):
+def run(
+    n_limit: Optional[int] = None,
+    shape_gen_name: str = 'llama',
+    out_filename: Optional[str] = None,
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
+):
     device = "cuda"

-    # LLaMa 2 70B single-node weight shapes
-    # assumes fused attn.wqkv and ffn.w13
-    # source: https://fburl.com/gsheet/g8onr7rh
-    name_to_shapes_70b = {
-        "attn.wqkv": (8192, 1280),
-        "attn.w0": (1024, 8192),
-        "ffn.w13": (8192, 7168),
-        "ffn.w2": (3584, 8192),
-    }
-
-    headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup")
+    headers = ("fast_accum", "name", "M", "K", "N", "ref_time_s", "fp8_time_s", "fp8_speedup")
     results = []

-    name_to_shapes = name_to_shapes_70b
-    dtypes = torch.bfloat16, torch.float16
+    dtype = torch.bfloat16
+    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
+    fast_accum_vals = [True, False]

-    for idx, (dtype, (name, (K, N))) in enumerate(
-        itertools.product(dtypes, name_to_shapes.items())
-    ):
+    for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
             break

-        # source: Xiao Sun, these are realistic for LLaMa 70B training
-        bsz, seq_len = 4, 4096
-
-        M = bsz * seq_len
-        print("M, K, N:", M, K, N)
         tops = 2 * M * N * K
-        print(f"tops: {tops:.2E}")
+        print("M, K, N:", M, K, N, f"tops: {tops:.2E}")

         # raw torch.mm
         A = torch.randn(M, K, device=device, dtype=dtype)
@@ -99,12 +151,12 @@ def run(n_limit: Optional[int] = None):
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
         A = torch.zeros(M, K, device=device, dtype=d1)
         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        scale_a = torch.tensor([1.0], device=device)
+        scale_b = torch.tensor([1.0], device=device)

         def do_matmul(A, B):
-            scale_a = torch.tensor([1.0], device=device)
-            scale_b = torch.tensor([1.0], device=device)
             return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
             )

         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
@@ -114,22 +166,26 @@ def do_matmul(A, B):
             f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
         )

-        del A, B
+        del A, B, scale_a, scale_b

         results.append(
             [
+                fast_accum,
                 name,
-                (M, K, N),
-                dtype,
+                M,
+                K,
+                N,
                 ref_time_sec,
                 fp8_time_sec,
                 ref_time_sec / fp8_time_sec,
             ]
         )

-    data_pd = pd.DataFrame(results, columns=headers)
-    print(data_pd)
+    data_df = pd.DataFrame(results, columns=headers)
+    print(data_df)

+    if out_filename is not None:
+        data_df.to_csv(out_filename)

 def main() -> None:
     fire.Fire(run)
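
Usage sketch (illustration only, not part of this diff): since main() wraps run with fire.Fire, the keyword arguments of run double as command-line flags, so the new shape_gen_name / M / K / N / out_filename options can be driven either from Python or from the shell. The script filename below is a placeholder, since the diff does not show it.

    # hypothetical invocations, assuming the script is importable as bench_matmul.py
    run(shape_gen_name='llama', out_filename='llama_fp8.csv')
    run(shape_gen_name='custom', M=16384, K=8192, N=1280)
    # equivalent CLI form via python-fire:
    #   python bench_matmul.py --shape_gen_name custom --M 16384 --K 8192 --N 1280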