Skip to content

Commit 7ecd79f

Browse files
committed
Update
[ghstack-poisoned]
1 parent e341c2e commit 7ecd79f

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,15 +1090,16 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
10901090

10911091
@triton.jit
10921092
def _triton_calculate_scale(x, axis):
1093-
# We use a small epsilon to avoid division by zero
1094-
epsilon = 1e-10
1095-
1096-
# TODO(before land): reuse the constants below instead of hardcoding
1093+
# There is no good support for accessing globals from a jit'ed triton
1094+
# function, so we redefine them here. Since this is prototype code which
1095+
# we plan to remove after torch.compile catches up, this is fine.
10971096
target_max_pow2 = 8
10981097
e8m0_exponent_bias = 127
10991098
bf16_mbits = 7
11001099
bf16_exp_bias = 127
11011100
fp32_mbits = 23
1101+
# We use a small epsilon to avoid division by zero
1102+
epsilon = 1e-10
11021103

11031104
# Find the maximum absolute value for each row
11041105
max_abs = tl.max(x, axis=axis)

0 commit comments

Comments
 (0)