 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import torch
 
 BYTES_PER_EL_FLOAT8 = 1
@@ -55,29 +57,67 @@ def get_specs():
 def get_tensor_memory_traffic_bytes(
     dim0,
     dim1,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
     fuse_with_prev=False,
 ):
     # assumes input bf16, output f8
     numel = dim0 * dim1
 
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
-    # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-    # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+    if float8_recipe_name == "tensorwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+        # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
+        # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+
+        if fuse_with_prev:
+            kernel_1_rw = 0
+        else:
+            # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel
+
+        # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
+        kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
+
+        return kernel_1_rw + kernel_3_rw
+
+    elif float8_recipe_name == "rowwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_float8_dim0
+        # kernel 2: x_bf16 -> x_float8_dim1
+
+        # assume that we can't fuse 1 and 2 because that would require loading
+        # the entire tensor to shared memory
+
+        if fuse_with_prev:
+            # assume we can fuse one of the reads with previous op
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        return kernel_1_rw + kernel_2_rw
 
-    if fuse_with_prev:
-        kernel_1_rw = 0
     else:
-        # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
-        kernel_1_rw = BYTES_PER_EL_BF16 * numel
+        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
 
-    # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
-    kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1
 
-    return kernel_1_rw + kernel_3_rw
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2
+
+        return kernel_1_rw
 
 
 def get_gemm_time_sympy(M, K, N, dtype):
+    # currently this assumes gemm is compute bound
+    # TODO(future): maybe make more accurate for small shapes by taking max of
+    # time to read/write and time to do the dot product, this might also
+    # slightly differ for MX since scales are larger
     specs = get_specs()
     gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
     if dtype is torch.bfloat16:
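
For a concrete sense of what the per-recipe branches above predict, the sketch below re-computes the same byte counts standalone for one illustrative 4096 x 4096 tensor with `fuse_with_prev=False` (the shape is an arbitrary choice, and this does not import the module being changed):

```python
# Standalone re-computation of the per-recipe memory traffic modeled above,
# for one illustrative shape; not an import of the module under change.
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2

dim0, dim1 = 4096, 4096
numel = dim0 * dim1

# tensorwise: kernel 1 reads bf16 for max-abs, kernel 3 reads bf16 and writes fp8 twice
tensorwise_bytes = (BYTES_PER_EL_BF16 * numel) + (
    BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
)

# rowwise: two unfused cast kernels, each reads bf16 and writes fp8 once
rowwise_bytes = 2 * (BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel)

# mxfp8: one kernel reads bf16 and writes fp8 twice (dim0 and dim1 layouts)
mxfp8_bytes = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel

for name, b in [
    ("tensorwise", tensorwise_bytes),
    ("rowwise", rowwise_bytes),
    ("mxfp8", mxfp8_bytes),
]:
    print(f"{name}: {b / numel:.0f} bytes/element, {b / 2**20:.1f} MiB total")
# tensorwise and rowwise land at 6 bytes/element, mxfp8 at 4; fuse_with_prev=True
# drops one bf16 read (2 bytes/element) from each recipe.
```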
@@ -89,9 +129,7 @@ def get_gemm_time_sympy(M, K, N, dtype):
 
 
 def get_float8_mem_sympy(
-    M,
-    K,
-    N,
+    M, K, N, float8_recipe_name: Optional[str], mx_recipe_name: Optional[str]
 ):
     specs = get_specs()
 
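
The comment added to `get_gemm_time_sympy` in the previous hunk states the modeling assumption: gemm time is taken as total gemm FLOPs divided by peak throughput, ignoring memory traffic. Below is a minimal standalone sympy sketch of that assumption; the peak-throughput constant is an assumed placeholder rather than the value returned by `get_specs()`:

```python
# Minimal sketch of the compute-bound gemm time assumption noted in
# get_gemm_time_sympy; peak_flops is an assumed placeholder value.
import sympy

M, K, N = sympy.symbols("M K N", positive=True)

# forward (out = input @ weight), grad_input, and grad_weight gemms: 2*M*K*N FLOPs each
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N

peak_flops = 1.0e15  # assumed placeholder peak throughput, FLOP/s
gemm_time_s = gemm_ops / peak_flops

# evaluate for one concrete shape
print(gemm_time_s.subs({M: 16384, K: 8192, N: 8192}))
```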
@@ -112,11 +150,15 @@ def get_float8_mem_sympy(
     fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
         M,
         K,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=True,
     )
     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=False,
     )
     fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
@@ -127,6 +169,8 @@ def get_float8_mem_sympy(
     gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
         M,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=True,
     )
     # already casted, assuming that we save weight from fw to bw
@@ -158,12 +202,20 @@ def get_float8_mem_sympy(
     # kernel overhead in the units of seconds, and the per-gemm-input memory
     # estimations are in the units of bytes.
     num_extra_kernels = 0
-    # second stage of max-abs reduction for input
-    num_extra_kernels += 1
-    # second stage of max-abs reduction for weight
-    num_extra_kernels += 1
-    # second stage of max-abs reduction for grad_output
-    num_extra_kernels += 1
+    if float8_recipe_name == "tensorwise":
+        # second stage of max-abs reduction for input
+        num_extra_kernels += 1
+        # second stage of max-abs reduction for weight
+        num_extra_kernels += 1
+        # second stage of max-abs reduction for grad_output
+        num_extra_kernels += 1
+    elif float8_recipe_name == "rowwise":
+        # for simplicity, assume all rowwise kernels are large and bandwidth bound
+        pass
+    else:
+        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
+        # for simplicity, assume all mxfp8 kernels are large and bandwidth bound
+        pass
 
 
     extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC