Commit e6706ca

roofline estimation: delete scaling type (#1781)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent c788ee7 commit e6706ca

2 files changed: 6 additions & 30 deletions


benchmarks/float8/float8_roofline.py

Lines changed: 0 additions & 9 deletions
@@ -176,9 +176,6 @@ def run(
     outfile: str,
     gemm_time_strategy: str = "benchmarks",
     model_torch_compile_limitations: bool = False,
-    scaling_type_input: str = "dynamic",
-    scaling_type_weight: str = "dynamic",
-    scaling_type_grad_output: str = "dynamic",
     shape_gen_name: str = "square",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
@@ -208,18 +205,12 @@ def run(
         K,
         N,
         model_torch_compile_limitations=True,
-        scaling_type_input="dynamic",
-        scaling_type_weight="dynamic",
-        scaling_type_grad_output="dynamic",
     )
     fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
         M,
         K,
         N,
         model_torch_compile_limitations=False,
-        scaling_type_input="dynamic",
-        scaling_type_weight="dynamic",
-        scaling_type_grad_output="dynamic",
     )
 
     if gemm_time_strategy == "roofline":
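For reference, after this commit get_float8_mem_sympy no longer takes per-tensor scaling-type arguments; dynamic scaling is the only modeled option. A minimal sketch of the simplified call, assuming the import path torchao.testing.float8.roofline_utils and sympy symbols for the shapes (as the benchmark above uses):

import sympy

from torchao.testing.float8.roofline_utils import get_float8_mem_sympy

# Symbolic shapes for the sympy-based memory/time model (assumed setup, mirroring the benchmark).
M, K, N = sympy.symbols("M K N")

# Only the shapes and the compile-limitations flag are passed now;
# dynamic scaling is implied rather than selected per tensor.
fp8_mem_time_sympy = get_float8_mem_sympy(
    M,
    K,
    N,
    model_torch_compile_limitations=False,
)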

torchao/testing/float8/roofline_utils.py

Lines changed: 6 additions & 21 deletions
@@ -55,14 +55,12 @@ def get_specs():
 def get_tensor_memory_traffic_bytes(
     dim0,
     dim1,
-    scaling_type: str,
     fuse_with_prev=False,
     model_torch_compile_limitations=False,
 ):
     # assumes input bf16, output f8
     numel = dim0 * dim1
 
-    assert scaling_type == "dynamic", "unsupported"
     # x_bf16 = ...
     # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
     # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
@@ -104,14 +102,7 @@ def get_float8_mem_sympy(
     K,
     N,
     model_torch_compile_limitations: bool = False,
-    scaling_type_input: str = "dynamic",
-    scaling_type_weight: str = "dynamic",
-    scaling_type_grad_output: str = "dynamic",
 ):
-    assert scaling_type_input in ("dynamic",), "unsupported"
-    assert scaling_type_weight in ("dynamic",), "unsupported"
-    assert scaling_type_grad_output in ("dynamic",), "unsupported"
-
     specs = get_specs()
 
     # there are three gemms in the fwd/bwd of a linear:
@@ -131,14 +122,12 @@ def get_float8_mem_sympy(
     fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
         M,
         K,
-        scaling_type_input,
         fuse_with_prev=True,
         model_torch_compile_limitations=model_torch_compile_limitations,
     )
     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
         K,
         N,
-        scaling_type_weight,
         fuse_with_prev=False,
         model_torch_compile_limitations=model_torch_compile_limitations,
     )
@@ -150,7 +139,6 @@ def get_float8_mem_sympy(
     gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
         M,
         N,
-        scaling_type_grad_output,
         fuse_with_prev=True,
         model_torch_compile_limitations=model_torch_compile_limitations,
     )
@@ -183,15 +171,12 @@ def get_float8_mem_sympy(
     # kernel overhead in the units of seconds, and the per-gemm-input memory
     # estimations are in the units of bytes.
     num_extra_kernels = 0
-    if scaling_type_input == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    if scaling_type_weight == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    if scaling_type_grad_output == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
+    # second stage of max-abs reduction for input
+    num_extra_kernels += 1
+    # second stage of max-abs reduction for weight
+    num_extra_kernels += 1
+    # second stage of max-abs reduction for grad_output
+    num_extra_kernels += 1
 
     extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
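With only dynamic scaling modeled, the extra-kernel count per linear layer is now a constant three: one second-stage max-abs reduction each for input, weight, and grad_output. A minimal sketch of the resulting overhead arithmetic; the TRITON_KERNEL_1_ELEMENT_TIME_SEC value below is a placeholder for illustration, not the constant defined in roofline_utils.py:

TRITON_KERNEL_1_ELEMENT_TIME_SEC = 3.0e-6  # placeholder value, for illustration only

# One second-stage max-abs reduction kernel each for input, weight, and grad_output.
num_extra_kernels = 3

extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
print(extra_kernel_overhead_s)  # 9e-06 with the placeholder value above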
