
mx cast to mxfp8 across dim0 and dim1 should be performant #1788

Open

Description

@vkuzo

What this cast is doing

  • reshape the tensor into shape (-1, block_size), where block_size is usually 32 or 16
  • for each block, calculate a single scale, then cast that block to torch.float8_e4m3fn
  • do the above across both dim0 and dim1 (see the sketch after this list)
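
Here is a minimal sketch of the per-block cast described above, in pure PyTorch. It is illustrative only: the helper name `mx_cast_emulated` is hypothetical, and it uses plain float32 scales for brevity, whereas real MX scales are E8M0 power-of-two values.

```python
import torch

def mx_cast_emulated(x: torch.Tensor, block_size: int = 32):
    # hypothetical helper: reshape into (-1, block_size) blocks, compute one
    # scale per block, and cast each scaled block to float8_e4m3fn
    x_blocks = x.reshape(-1, block_size).to(torch.float32)
    # choose the scale so the block's amax maps to the float8_e4m3fn max (448.0)
    amax = x_blocks.abs().amax(dim=1, keepdim=True)
    scale = torch.where(amax == 0, torch.ones_like(amax), amax / 448.0)
    x_fp8 = (x_blocks / scale).to(torch.float8_e4m3fn)
    return x_fp8.reshape(x.shape), scale

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
fp8_dim1, scale_dim1 = mx_cast_emulated(x)                   # blocks contiguous along dim1
fp8_dim0, scale_dim0 = mx_cast_emulated(x.t().contiguous())  # blocks contiguous along dim0
```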

What we currently see from inductor is two separate kernels, one for the dim0 cast and one for the dim1 cast:

TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only_dim0_dim1

Output: https://gist.github.com/vkuzo/7a9f104872790e58b316c7ba477fcbf5

An MX-compliant 32x32 block of a bfloat16 tensor occupies 2 KiB of memory (1024 elements x 2 bytes), so it easily fits in the shared memory of an SM on a modern GPU. We should explore doing this cast across dim0 and dim1 in a tiled fashion, so that each tile is loaded from global memory only once; a sketch of the idea is below.
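
Here is a sketch of what such a tiled kernel could look like, written in Triton. This is an assumption-laden illustration, not a proposed implementation: the names `mx_cast_tile_kernel` and `mx_cast_dim0_dim1` are hypothetical, the kernel expects a contiguous row-major 2D input, and it emits float32 scales rather than the E8M0 scales an MX-compliant cast would produce. What it demonstrates is the data movement: each program loads one BLOCK x BLOCK tile exactly once and produces both the dim0-scaled and dim1-scaled fp8 outputs from it.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def mx_cast_tile_kernel(
    x_ptr, out_row_ptr, out_col_ptr, scale_row_ptr, scale_col_ptr,
    M, N,
    BLOCK: tl.constexpr,  # the mx block size, e.g. 32
):
    # each program owns one BLOCK x BLOCK tile of a row-major (M, N) tensor
    # and loads it from global memory exactly once
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK + tl.arange(0, BLOCK)
    offs_n = pid_n * BLOCK + tl.arange(0, BLOCK)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs_m[:, None] * N + offs_n[None, :],
                mask=mask, other=0.0).to(tl.float32)

    fp8_max: tl.constexpr = 448.0  # max finite value of torch.float8_e4m3fn

    # one scale per 1 x BLOCK block (reduction along dim1)
    amax_row = tl.max(tl.abs(x), axis=1)
    scale_row = tl.where(amax_row == 0.0, 1.0, amax_row / fp8_max)
    out_row = (x / scale_row[:, None]).to(tl.float8e4nv)
    tl.store(out_row_ptr + offs_m[:, None] * N + offs_n[None, :], out_row, mask=mask)
    tl.store(scale_row_ptr + offs_m * tl.cdiv(N, BLOCK) + pid_n,
             scale_row, mask=offs_m < M)

    # one scale per BLOCK x 1 block (reduction along dim0), from the same tile
    amax_col = tl.max(tl.abs(x), axis=0)
    scale_col = tl.where(amax_col == 0.0, 1.0, amax_col / fp8_max)
    out_col = (x / scale_col[None, :]).to(tl.float8e4nv)
    tl.store(out_col_ptr + offs_m[:, None] * N + offs_n[None, :], out_col, mask=mask)
    tl.store(scale_col_ptr + pid_m * N + offs_n, scale_col, mask=offs_n < N)

def mx_cast_dim0_dim1(x: torch.Tensor, block_size: int = 32):
    # hypothetical wrapper; returns both casts plus their per-block scales
    M, N = x.shape
    out_row = torch.empty(M, N, dtype=torch.float8_e4m3fn, device=x.device)
    out_col = torch.empty(M, N, dtype=torch.float8_e4m3fn, device=x.device)
    scale_row = torch.empty(M, triton.cdiv(N, block_size),
                            dtype=torch.float32, device=x.device)
    scale_col = torch.empty(triton.cdiv(M, block_size), N,
                            dtype=torch.float32, device=x.device)
    grid = (triton.cdiv(M, block_size), triton.cdiv(N, block_size))
    mx_cast_tile_kernel[grid](x, out_row, out_col, scale_row, scale_col,
                              M, N, BLOCK=block_size)
    return out_row, scale_row, out_col, scale_col
```

Compared with the two inductor kernels, this structure halves the global-memory reads of the input tensor, which matters because the cast is memory-bandwidth bound.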
