request for faster inductor kernels for blockwise reduction across dim1 -> write #149982

Open
@vkuzo

Description

🐛 Describe the bug

We should make the following kernel fast with compile + inductor. This is important for generating the dim1 cast to MX formats.

import torch
from typing import Tuple

def scale_dim1_reference(x_hp: torch.Tensor, block_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # normalize across dim1
    x_hp_d1 = x_hp.t().contiguous()
    x_hp_d1_block = x_hp_d1.reshape(-1, block_size)
    x_hp_d1_block_abs = x_hp_d1_block.abs()
    amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1)
    x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1
    x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape)
    return x_hp_d1_normalized.t(), amax_dim1

Currently, I am only hitting 0.6 to 0.7 TB/s on NVIDIA H100. If the reduction and write are across dim0 instead of dim1, I see 2.0-2.2 TB/s. From discussions with @eellison , this is due to uncoalesced reads, and we can fix this.
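For reference, the faster dim0 case corresponds to a variant that skips the transpose, so blocks lie along contiguous rows and reads/writes are coalesced. A minimal sketch of such a variant (the function name and exact shape conventions here are assumptions for illustration, not from the repro script):

```python
import torch
from typing import Tuple

def scale_dim0_reference(x_hp: torch.Tensor, block_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # hypothetical no-transpose variant: blocks are contiguous along the
    # innermost dim, so memory access is coalesced (the 2.0-2.2 TB/s case)
    x_hp_block = x_hp.reshape(-1, block_size)
    amax = torch.amax(x_hp_block.abs(), dim=1, keepdim=True)
    x_hp_normalized = (x_hp_block / amax).reshape(x_hp.shape)
    return x_hp_normalized, amax
```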

Repro script: https://gist.github.com/vkuzo/9eff0d27691be483e45bb10edf66d82c
Repro results on NVIDIA H100:

(pytorch) [vasiliy@devgpu006.vll6 ~/local/pytorch_scripts/mx_cast_poc (20250325_dim1_cast)]$ python 20250325_dim1_cast.py --M 4096 --K 4096
M 4096 K 4096 BLOCK_SIZE 32
GPU: NVIDIA H100
torch version: 2.8.0a0+gitdd94e94
triton version: 3.2.0
time_reference_compile_us 107.69072608695663
mem_bw_gbps 632.8998092645895
(pytorch) [vasiliy@devgpu006.vll6 ~/local/pytorch_scripts/mx_cast_poc (20250325_dim1_cast)]$ python 20250325_dim1_cast.py --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA H100
torch version: 2.8.0a0+gitdd94e94
triton version: 3.2.0
time_reference_compile_us 1612.7510689655173
mem_bw_gbps 676.1855942836252
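As a sanity check, the reported bandwidth is roughly consistent with one read plus one write of the tensor, assuming bf16 (2 bytes/element) I/O — a dtype assumption on my part, since the repro output does not state it — with the small amax output ignored:

```python
# Back-of-the-envelope bandwidth check for the M=K=4096 run above.
# Assumes bf16 (2 bytes/element); the amax writes (one value per
# 32-element block) are comparatively tiny and ignored here.
M = K = 4096
bytes_per_elem = 2                         # bf16 assumption
bytes_moved = 2 * M * K * bytes_per_elem   # one read + one write of x
time_s = 107.69e-6                         # time_reference_compile_us
bw_gbps = bytes_moved / time_s / 1e9
print(round(bw_gbps))  # -> 623, close to the reported 632.9
```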

TORCH_LOGS=output_code results: https://gist.github.com/vkuzo/4420c5b508ddd560e5d4620758b5936a

Versions

main branch

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

Metadata

Labels

module: inductor
oncall: pt2
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
