What this cast is doing
- reshape the tensor to shape (-1, block_size), where block_size is usually 32 or 16
- for each block, compute a single scale, then cast that block to torch.float8_e4m3fn
- return the cast elements and the scales
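The steps above can be sketched in NumPy (a minimal emulation, not the actual PyTorch implementation; the final cast to torch.float8_e4m3fn is approximated here by clipping to the e4m3fn max of 448, and the block size of 32 is an assumption):

```python
import numpy as np

E4M3_MAX = 448.0  # largest representable value of torch.float8_e4m3fn
BLOCK_SIZE = 32   # typical mx block size (16 is also used)

def blockwise_cast(x: np.ndarray):
    # step 1: reshape into (-1, block_size) blocks
    blocks = x.reshape(-1, BLOCK_SIZE)
    # step 2: one scale per block, derived from the block's absolute maximum
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    scale = np.where(amax == 0, 1.0, amax / E4M3_MAX)
    # emulate the float8 cast by scaling and clipping into e4m3fn range;
    # real code would do blocks.div(scale).to(torch.float8_e4m3fn)
    casted = np.clip(blocks / scale, -E4M3_MAX, E4M3_MAX)
    # step 3: return the cast elements and the per-block scales
    return casted, scale

x = np.arange(64, dtype=np.float32)
casted, scale = blockwise_cast(x)
print(casted.shape, scale.shape)  # (2, 32) (2, 1)
```

Each of the three steps (reshape, per-block amax/scale, scaled cast) is elementwise or a small reduction, which is why one would expect them to fuse into a single kernel.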
We really should do all of this in a single fused kernel, but today we see two kernels.
How to reproduce (requires latest main branch)
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only
Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 - we currently see two kernels