Description
🐛 Describe the bug
We should make the following kernel fast with torch.compile + inductor. This is important for generating the dim1 cast to MX formats efficiently.
import torch
from typing import Tuple

def scale_dim1_reference(x_hp: torch.Tensor, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    # normalize across dim1
    x_hp_d1 = x_hp.t().contiguous()
    x_hp_d1_block = x_hp_d1.reshape(-1, block_size)
    x_hp_d1_block_abs = x_hp_d1_block.abs()
    amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1)
    x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1
    x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape)
    return x_hp_d1_normalized.t(), amax_dim1
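A minimal sketch of how this reference might be compiled and exercised (shapes and BLOCK_SIZE taken from the repro commands below; the actual benchmarking harness lives in the gist):

# sketch only: compile the reference and run it on a bf16 input of the benchmarked shape
M, K, BLOCK_SIZE = 4096, 4096, 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
scale_dim1_compiled = torch.compile(scale_dim1_reference)
x_d1_normalized, amax_dim1 = scale_dim1_compiled(x, BLOCK_SIZE)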
Currently, I am only hitting 0.6 to 0.7 TB/s on NVIDIA H100. If the reduction and write are across dim0 instead of dim1, I see 2.0-2.2 TB/s. From discussions with @eellison, this is due to uncoalesced reads, and we can fix this; a sketch of the dim0 variant is included below for comparison.
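For comparison, here is what I assume the fast dim0 path looks like (my reconstruction, not necessarily the exact code in the gist): the same blocked amax-and-normalize, but without the transpose, so each block of block_size elements is a contiguous run in memory and reads/writes coalesce.

def scale_dim0_reference(x_hp: torch.Tensor, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    # hypothetical dim0 counterpart to scale_dim1_reference: no transpose,
    # so each block is block_size consecutive elements in memory
    x_hp_block = x_hp.reshape(-1, block_size)
    x_hp_block_abs = x_hp_block.abs()
    amax_dim0 = torch.amax(x_hp_block_abs, dim=1).unsqueeze(1)
    x_hp_block_normalized = x_hp_block / amax_dim0
    x_hp_normalized = x_hp_block_normalized.reshape(x_hp.shape)
    return x_hp_normalized, amax_dim0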
Repro script: https://gist.github.com/vkuzo/9eff0d27691be483e45bb10edf66d82c
Repro results on NVIDIA H100:
(pytorch) [vasiliy@devgpu006.vll6 ~/local/pytorch_scripts/mx_cast_poc (20250325_dim1_cast)]$ python 20250325_dim1_cast.py --M 4096 --K 4096
M 4096 K 4096 BLOCK_SIZE 32
GPU: NVIDIA H100
torch version: 2.8.0a0+gitdd94e94
triton version: 3.2.0
time_reference_compile_us 107.69072608695663
mem_bw_gbps 632.8998092645895
(pytorch) [vasiliy@devgpu006.vll6 ~/local/pytorch_scripts/mx_cast_poc (20250325_dim1_cast)]$ python 20250325_dim1_cast.py --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA H100
torch version: 2.8.0a0+gitdd94e94
triton version: 3.2.0
time_reference_compile_us 1612.7510689655173
mem_bw_gbps 676.1855942836252
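For context on how mem_bw_gbps relates to these times, one plausible accounting (my assumption, not copied from the gist) is a bf16 read of x plus bf16 writes of the normalized output and the per-block amax, divided by the measured time; this roughly reproduces the ~633 GB/s figure for the 4096x4096 run:

# hypothetical bandwidth accounting, assuming bf16 input and outputs
M, K, BLOCK_SIZE = 4096, 4096, 32
bytes_per_el = 2  # bf16
bytes_moved = 2 * M * K * bytes_per_el + (M * K // BLOCK_SIZE) * bytes_per_el
time_s = 107.69e-6  # time_reference_compile_us from the first run above
print(bytes_moved / time_s / 1e9)  # ~632.9 GB/s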
TORCH_LOGS=output_code results: https://gist.github.com/vkuzo/4420c5b508ddd560e5d4620758b5936a
Versions
main branch
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov