File tree 1 file changed +5
-4
lines changed
torchao/prototype/mx_formats 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -1090,15 +1090,16 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
1090
1090
1091
1091
@triton .jit
1092
1092
def _triton_calculate_scale (x , axis ):
1093
- # We use a small epsilon to avoid division by zero
1094
- epsilon = 1e-10
1095
-
1096
- # TODO(before land): reuse the constants below instead of hardcoding
1093
+ # There is no good support for accessing globals from a jit'ed triton
1094
+ # function, so we redefine them here. Since this is prototype code which
1095
+ # we plan to remove after torch.compile catches up, this is fine.
1097
1096
target_max_pow2 = 8
1098
1097
e8m0_exponent_bias = 127
1099
1098
bf16_mbits = 7
1100
1099
bf16_exp_bias = 127
1101
1100
fp32_mbits = 23
1101
+ # We use a small epsilon to avoid division by zero
1102
+ epsilon = 1e-10
1102
1103
1103
1104
# Find the maximum absolute value for each row
1104
1105
max_abs = tl .max (x , axis = axis )
You can’t perform that action at this time.
0 commit comments