This is a brain dump of what is missing from `torchao.float8` to support training with rowwise scaling, to help anyone who wants to jump in and build this.
already done

- `torch._scaled_mm` supports rowwise scaling
- inductor supports rowwise scaled gemms in `max-autotune` mode (I haven't personally tested this yet)
needed

- we need `Float8Tensor` to work with rowwise scales. We had an unlanded PR on float8_experimental doing that here ([wip] add axiswise granularity to Float8Tensor pytorch-labs/float8_experimental#352); we just never got the time to land it. You can reuse that PR or do something similar. Note that [Float8Quant] Add rowwise scaling option to float8 dynamic quant #819 landed recently, adding float8 rowwise scaling to inference, so being consistent with that where applicable would be nice.
- we need
`Float8Linear` to be configurable with rowwise scales for each argument, and for the scaling to respect the config, validated by tests + benchmarks; this would require changes to `torchao.float8.config.py` and `torchao.float8.float8_linear.py`
- after (1) and (2), we could make each gemm individually configurable, to enable leaving some of the gemms in high precision
- performance fixes throughout `torchao.float8` and inductor, if needed based on how well inductor generates the scaling code
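To make the config side of (2) and (3) concrete, one possible shape is a per-argument scaling-granularity knob on the linear config. This is a hypothetical sketch of what could go into `torchao.float8.config.py`, not the current API:

```python
from dataclasses import dataclass, field
from enum import Enum, auto

class ScalingGranularity(Enum):
    # hypothetical names for illustration
    TENSORWISE = auto()  # one scale per tensor (current behavior)
    AXISWISE = auto()    # one scale per row (rowwise scaling)

@dataclass
class CastConfig:
    granularity: ScalingGranularity = ScalingGranularity.TENSORWISE

@dataclass
class Float8LinearConfig:
    # one cast config per gemm argument: input, weight, grad_output
    cast_config_input: CastConfig = field(default_factory=CastConfig)
    cast_config_weight: CastConfig = field(default_factory=CastConfig)
    cast_config_grad_output: CastConfig = field(default_factory=CastConfig)

# example: rowwise-scale input and weight, leave grad_output tensorwise,
# i.e. one of the three gemms stays closer to high precision
cfg = Float8LinearConfig(
    cast_config_input=CastConfig(ScalingGranularity.AXISWISE),
    cast_config_weight=CastConfig(ScalingGranularity.AXISWISE),
)
```

A per-argument config like this also naturally extends to (3): leaving a given gemm in high precision is just another option next to the granularity choice.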