What this cast is doing:
- reshape the tensor into shape (-1, block_size), where block_size is typically 32 or 16
- for each block, compute a single scale, then cast that block to torch.float8_e4m3fn
- do the above across both dim0 and dim1 (a sketch of the per-block math follows this list)
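
For reference, here is a minimal pure-PyTorch sketch of the per-block math. This is not torchao's actual implementation: the real MX format uses a power-of-two (E8M0) scale per block, while this sketch uses a plain amax-derived float scale for readability.

```python
import torch

def blockwise_cast_to_fp8(x: torch.Tensor, block_size: int = 32):
    """Cast x to float8_e4m3fn with one scale per block of `block_size` elements.

    Simplified sketch: real MX scales are power-of-two (E8M0). Assumes
    x.numel() is a multiple of block_size.
    """
    blocks = x.reshape(-1, block_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    # one scale per block, chosen so the block's max magnitude maps to fp8_max
    scale = blocks.abs().amax(dim=1, keepdim=True) / fp8_max
    data = (blocks / scale.clamp(min=1e-12)).to(torch.float8_e4m3fn)
    return data, scale

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
data_d1, scale_d1 = blockwise_cast_to_fp8(x)                   # blocks contiguous along dim1
data_d0, scale_d0 = blockwise_cast_to_fp8(x.t().contiguous())  # blocks contiguous along dim0
```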
What we currently see from inductor is two kernels, one for dim0 and one for dim1:
```
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only_dim0_dim1
```
Output: https://gist.github.com/vkuzo/7a9f104872790e58b316c7ba477fcbf5
An mx-compliant 32x32 block of a bfloat16 tensor occupies 2 KiB of memory (1024 elements x 2 bytes), so it should easily fit into the shared memory of an SM on a modern GPU. We should explore doing the dim0 and dim1 casts in a tiled fashion, so that each tile is loaded into shared memory only once; see the sketch below.
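
As a shape-level illustration of the access pattern, here is a Python sketch, not a kernel: a real implementation would be a single fused Triton/inductor kernel with the tile staged in shared memory, and the function and variable names here are hypothetical. Each tile of the input is read once and used to produce both casts (per-block scales are omitted for brevity):

```python
def tiled_cast_dim0_dim1(x: torch.Tensor, block_size: int = 32):
    """Produce both the dim0 and dim1 blockwise fp8 casts while reading each
    (block_size x block_size) tile of x exactly once.

    Sketch only: the loops model what a fused kernel would do per thread
    block, with `tile` standing in for data staged in shared memory.
    Assumes both dims of x are multiples of block_size.
    """
    M, N = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    out_dim1 = torch.empty(M, N, dtype=torch.float8_e4m3fn, device=x.device)
    out_dim0 = torch.empty(N, M, dtype=torch.float8_e4m3fn, device=x.device)
    for i in range(0, M, block_size):
        for j in range(0, N, block_size):
            tile = x[i:i + block_size, j:j + block_size]  # single load per tile
            # cast with one scale per row of the tile (blocks along dim1)
            s1 = tile.abs().amax(dim=1, keepdim=True) / fp8_max
            out_dim1[i:i + block_size, j:j + block_size] = (
                tile / s1.clamp(min=1e-12)
            ).to(torch.float8_e4m3fn)
            # cast with one scale per column of the tile (blocks along dim0),
            # written out transposed so it matches casting x.t()
            s0 = tile.abs().amax(dim=0, keepdim=True) / fp8_max
            out_dim0[j:j + block_size, i:i + block_size] = (
                tile / s0.clamp(min=1e-12)
            ).t().to(torch.float8_e4m3fn)
    return out_dim0, out_dim1
```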