Closed
Description
What this cast is doing
- reshape the tensor to shape (-1, block_size), where block_size is usually 32 or 16
- for each block, calculate a single scale, and then cast that block to torch.float8_e4m3fn
- return the cast elements and the scale
We really should do all of this in one kernel, but today we see two kernels.
How to reproduce (requires latest main branch)
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only
Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 - we currently see two kernels