
torch.compile cast to mxfp8 should only require one kernel #1769

Closed
@vkuzo

Description

What this cast is doing

  • reshape the tensor into shape (-1, block_size), where block_size is usually 32 or 16
  • for each block, calculate a single scale, then cast that block to torch.float8_e4m3fn
  • return the cast elements and the scales

We really should do this all in one kernel, but today we see two kernels.

How to reproduce (requires latest main branch)

```
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only
```

Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 - we currently see two kernels.
