Skip to content

Commit 6b46fb9

Browse files
vkuzoliangel-02
authored andcommitted
mx roofline: adjust mxfp8 formulas (#1953)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 8b39a0b commit 6b46fb9

File tree

1 file changed

+9
-20
lines changed

1 file changed

+9
-20
lines changed

torchao/testing/float8/roofline_utils.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -183,27 +183,16 @@ def get_tensor_memory_traffic_ovhd_s(
183183
"mxfp8_cutlass",
184184
"mxfp8_cublas",
185185
), "unsupported"
186-
187-
if tensor_role == "weight":
188-
# x_bf16 = ...
189-
# kernel 1: x_bf16 -> x_mxfp8_dim0
190-
# kernel 2: x_bf16 -> x_mxfp8_dim1
191-
if fuse_with_prev:
192-
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
193-
else:
194-
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
195-
kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
196-
res_bytes = [kernel_1_rw, kernel_2_rw]
186+
# For now, assume that we can't profitably fuse kernel 1 and kernel 2
187+
# x_bf16 = ...
188+
# kernel 1: x_bf16 -> x_mxfp8_dim0
189+
# kernel 2: x_bf16 -> x_mxfp8_dim1
190+
if fuse_with_prev:
191+
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
197192
else:
198-
# x_bf16 = ...
199-
# kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1
200-
if fuse_with_prev:
201-
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2
202-
else:
203-
kernel_1_rw = (
204-
BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2
205-
)
206-
res_bytes = [kernel_1_rw]
193+
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
194+
kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
195+
res_bytes = [kernel_1_rw, kernel_2_rw]
207196

208197
# convert from bytes to seconds
209198
res_s = [

0 commit comments

Comments
 (0)