@@ -55,14 +55,12 @@ def get_specs():
 def get_tensor_memory_traffic_bytes(
     dim0,
     dim1,
-    scaling_type: str,
     fuse_with_prev=False,
     model_torch_compile_limitations=False,
 ):
     # assumes input bf16, output f8
     numel = dim0 * dim1
 
-    assert scaling_type == "dynamic", "unsupported"
     # x_bf16 = ...
     # kernel 1:               x_bf16 -> max_abs_stage_1 -> tmp
     # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
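For intuition, here is a minimal standalone sketch of the traffic this model counts for dynamic scaling. The bf16/float8 byte sizes are real dtype sizes; the kernel split and fusion accounting follow the comments above, but the file's exact arithmetic may differ.

```python
# Minimal sketch of the dynamic-scaling traffic model described above.
# BYTES_PER_EL_* are real dtype sizes; the fusion accounting is an
# assumption for illustration, not necessarily the file's exact logic.
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1

def sketch_tensor_memory_traffic_bytes(dim0, dim1, fuse_with_prev=False):
    numel = dim0 * dim1
    if fuse_with_prev:
        # kernel 1's read of x_bf16 is fused into the producing op
        kernel_1_rw = 0
    else:
        # kernel 1: read x_bf16 for the first max-abs reduction stage
        kernel_1_rw = BYTES_PER_EL_BF16 * numel
    # kernel 2 (second max-abs stage) is modeled as fixed kernel overhead,
    # not as memory traffic
    # kernel 3: read x_bf16, write x_float8
    kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
    return kernel_1_rw + kernel_3_rw
```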
@@ -104,14 +102,7 @@ def get_float8_mem_sympy(
     K,
     N,
     model_torch_compile_limitations: bool = False,
-    scaling_type_input: str = "dynamic",
-    scaling_type_weight: str = "dynamic",
-    scaling_type_grad_output: str = "dynamic",
 ):
-    assert scaling_type_input in ("dynamic",), "unsupported"
-    assert scaling_type_weight in ("dynamic",), "unsupported"
-    assert scaling_type_grad_output in ("dynamic",), "unsupported"
-
     specs = get_specs()
 
     # there are three gemms in the fwd/bwd of a linear:
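For reference, the three gemms that comment refers to, with shapes for x (M, K), w_t (K, N), and grad_output (M, N); this helper and its naming are illustrative, not part of the file:

```python
def linear_fwd_bwd_gemm_shapes(M, K, N):
    # Standard shapes for a linear layer's three gemms (illustrative helper).
    return {
        "out = x @ w_t": ((M, K), (K, N)),
        "grad_x = grad_out @ w": ((M, N), (N, K)),
        "grad_w = x_t @ grad_out": ((K, M), (M, N)),
    }
```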
@@ -131,14 +122,12 @@ def get_float8_mem_sympy(
     fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
         M,
         K,
-        scaling_type_input,
         fuse_with_prev=True,
         model_torch_compile_limitations=model_torch_compile_limitations,
     )
     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
         K,
         N,
-        scaling_type_weight,
         fuse_with_prev=False,
         model_torch_compile_limitations=model_torch_compile_limitations,
     )
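A quick worked check of these two calls under the sketch assumptions above (sizes are hypothetical): with the input's max-abs read fused into the previous op, the input contributes 3 bytes per element, while the unfused weight contributes 5.

```python
# Worked example under the sketch assumptions above (hypothetical sizes).
M = K = N = 4096
# input (M, K), max-abs read fused with previous op: cast read + fp8 write
input_bytes = (2 + 1) * M * K
# weight (K, N), not fused: adds the bf16 max-abs read
weight_bytes = (2 + 2 + 1) * K * N
print(input_bytes, weight_bytes)  # 50331648 83886080
```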
@@ -150,7 +139,6 @@ def get_float8_mem_sympy(
     gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
         M,
         N,
-        scaling_type_grad_output,
         fuse_with_prev=True,
         model_torch_compile_limitations=model_torch_compile_limitations,
     )
@@ -183,15 +171,12 @@ def get_float8_mem_sympy(
     # kernel overhead in the units of seconds, and the per-gemm-input memory
     # estimations are in the units of bytes.
     num_extra_kernels = 0
-    if scaling_type_input == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    if scaling_type_weight == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    if scaling_type_grad_output == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
+    # second stage of max-abs reduction for input
+    num_extra_kernels += 1
+    # second stage of max-abs reduction for weight
+    num_extra_kernels += 1
+    # second stage of max-abs reduction for grad_output
+    num_extra_kernels += 1
 
 
     extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
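With the branches removed, num_extra_kernels is unconditionally 3, so the overhead term reduces to a constant. A numeric illustration with an assumed per-kernel time (the file's TRITON_KERNEL_1_ELEMENT_TIME_SEC constant is defined elsewhere and may differ):

```python
# Illustration only: 5e-6 s is an assumed placeholder for the file's
# TRITON_KERNEL_1_ELEMENT_TIME_SEC constant, which is defined elsewhere.
TRITON_KERNEL_1_ELEMENT_TIME_SEC = 5e-6
num_extra_kernels = 3  # input, weight, grad_output second-stage max-abs
extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
print(extra_kernel_overhead_s)  # 1.5e-05 s under the assumed constant
```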