
float8 training with rowwise scaling #889

Closed
@vkuzo

Description

This is a brain dump of what is missing from torchao.float8 to support training with rowwise scaling, to help anyone who wants to jump in and build this.

already done

  • torch._scaled_mm supports rowwise scaling
  • inductor supports rowwise-scaled gemms in max-autotune mode (I haven't personally tested this yet)
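
As a rough illustration of the rowwise-scaling math that torch._scaled_mm's rowwise mode performs, here is a numpy emulation. The fp8 cast is stood in for by clipping (real kernels round to float8_e4m3fn), and the helper names are made up for this sketch:

```python
import numpy as np

E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def rowwise_scales(x, axis):
    # one scale per row (or per column): amax along `axis` over fp8 max
    amax = np.max(np.abs(x), axis=axis, keepdims=True)
    return np.maximum(amax, 1e-12) / E4M3_MAX

def emulated_scaled_mm_rowwise(a, b):
    # a: (M, K), b: (K, N)
    scale_a = rowwise_scales(a, axis=1)  # shape (M, 1), one scale per row of a
    scale_b = rowwise_scales(b, axis=0)  # shape (1, N), one scale per column of b
    # stand-in for the fp8 cast: scale into fp8 range, then clip
    a_lp = np.clip(a / scale_a, -E4M3_MAX, E4M3_MAX)
    b_lp = np.clip(b / scale_b, -E4M3_MAX, E4M3_MAX)
    # low-precision gemm, then dequantize with the outer product of the scales
    return (a_lp @ b_lp) * (scale_a * scale_b)

a = np.random.randn(4, 8).astype(np.float32)
b = np.random.randn(8, 3).astype(np.float32)
out = emulated_scaled_mm_rowwise(a, b)
```

Because the emulation skips the actual rounding to fp8, the result matches a plain `a @ b` up to float32 rounding; the point is the shape of the scales ((M, 1) for the first operand, (1, N) for the second) and where the dequantization happens.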

needed

  1. we need Float8Tensor to work with rowwise scales. There was an unlanded PR on float8_experimental doing that ([wip] add axiswise granularity to Float8Tensor, pytorch-labs/float8_experimental#352), but it never got the time to land. You can reuse that PR or do something similar. Note that [Float8Quant] Add rowwise scaling option to float8 dyanmic quant #819 landed recently, adding float8 rowwise scaling to inference, so being consistent with that where applicable would be nice.
  2. we need Float8Linear to be configurable with rowwise scales for each argument, and the scaling must respect the config, validated by tests + benchmarks. This would require changes to torchao.float8.config.py and torchao.float8.float8_linear.py.
  3. after (1) and (2), we could make each gemm individually configurable, enabling some of them to be left in high precision
  4. performance fixes throughout torchao.float8 and inductor, if needed based on how well inductor generates the scaling code
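
To make item (2) concrete, here is one possible shape for such a config. Everything here is hypothetical, not the actual torchao.float8 API: the names ScalingGranularity and Float8GemmConfig, the enum values, and the field names are all illustrative only.

```python
from dataclasses import dataclass
from enum import Enum, auto

class ScalingGranularity(Enum):
    # hypothetical enum: how scales are computed for a tensor
    TENSORWISE = auto()  # one scale for the whole tensor
    ROWWISE = auto()     # one scale per row (axiswise)

@dataclass(frozen=True)
class Float8GemmConfig:
    # hypothetical per-gemm config; a linear layer has three gemms
    # (output, grad_input, grad_weight), each of which could get one of these
    input_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
    weight_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
    use_fp8: bool = True  # item (3): allow leaving this gemm in high precision

# e.g. rowwise scaling for the forward gemm only:
fwd_cfg = Float8GemmConfig(
    input_granularity=ScalingGranularity.ROWWISE,
    weight_granularity=ScalingGranularity.ROWWISE,
)
```

Keeping the granularity per operand (rather than one global switch) is what allows item (3) to fall out naturally: each of the three gemms can independently mix rowwise, tensorwise, or high-precision operands.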
