This repository is a prototype of a float8 training UX written in native PyTorch. For now, the goal is to move quickly and validate our design. Production readiness, backwards compatibility, etc. are explicit non-goals at this point. Once we are farther along, we will discuss how to make this public.
```
# install requirements
pip install -r requirements.txt
```
- [done] float8 dtypes in core
- [in progress] `torch._scaled_mm` in core
- [not started] saturated casts to float8 in core
- [done] `Float8Linear` with emulation and just-in-time scaling
- [in progress] swap to real fp8 compute
- [in progress] swap to delayed scaling
Note that performance is a non-goal for this milestone.
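For intuition, here is a hedged sketch of the difference between the two scaling flavors; the function names and single-tensor scope are illustrative, not this repository's API.

```python
import torch

E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

def just_in_time_scale(x: torch.Tensor) -> torch.Tensor:
    # compute the scale from the current tensor's amax, paying for an
    # extra reduction over x in the hot path
    return E4M3_MAX / x.abs().max()

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # reuse amax values recorded in previous iterations, avoiding the
    # reduction at cast time at some cost in accuracy
    return E4M3_MAX / amax_history.max()
```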
- [starting] PT2.0 compatibility of this repository
- [not started] inductor support for fp8 matmul
- [not started] inductor support for fusing amax calculations into surrounding ops
- [not started] e2e benchmarking
- [in progress] validate that FSDP with fp16 weight all-gather still works
- [in progress] design for FSDP with fp8 weight all-gather
- [not started] implementation for FSDP with fp8 weight all-gather
- `Float8Linear` owns casting inputs/weights/outputs/grads to float8 and keeps track of the relevant buffers
- user is responsible for applying `Float8Linear` to the right parts of their model with module swaps, as sketched below
- eager mode performance is a non-goal; PT2.0 graph capture -> inductor graph lowering to fp8 enabled fused kernels is the blessed path for competitive performance
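For example, a module swap could look like the sketch below; the `from_float` constructor name is an assumption about the API, see `float8_playground/float8_linear.py` for the actual entry point.

```python
import torch.nn as nn
from float8_playground.float8_linear import Float8Linear  # user facing entry point

def swap_linear_with_float8_linear(model: nn.Module) -> nn.Module:
    # recursively replace every nn.Linear with Float8Linear;
    # `from_float` is a hypothetical constructor name
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, Float8Linear.from_float(child))
        else:
            swap_linear_with_float8_linear(child)
    return model
```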
FSDP with fp16 weight all-gather requires no change from single-GPU code. For FSDP with fp8 weight all-gather, the current design is:
- user code is responsible for making the model fp8 aware and adding the right buffers
- user code is responsible for passing FSDP a data structure with all the information necessary to cast weights to fp8
- FSDP is responsible for performing the fp8 cast and providing the unsharded fp8 weight to each worker
- user code is responsible for syncing amax metadata across workers

More details TBD.
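As a rough illustration of the data structure in the second bullet above, the per-weight cast metadata could look like the sketch below; all names and fields are hypothetical, not the actual design.

```python
from dataclasses import dataclass
import torch

@dataclass
class WeightCastConfig:
    # hypothetical per-weight metadata handed to FSDP so it can cast the
    # weight to fp8 before the all-gather
    fqn: str                   # fully qualified name of the weight in the model
    float8_dtype: torch.dtype  # e.g. torch.float8_e4m3fn
    scale_buffer_name: str     # buffer with the scale to apply at cast time
    amax_buffer_name: str      # buffer with amax metadata, synced across workers
```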
For now, we plan to start with:
- moving the input/output casts out of `Float8Linear` and into module hooks
- asking the user to apply the hooks to the right places in their model, to compose with the distributed primitives on activations
More details TBD.
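A minimal sketch of the hook-based direction, assuming an emulated cast helper (the real conversion lives in `float8_playground/float8_tensor.py`):

```python
import torch
import torch.nn as nn

def cast_to_float8_emulated(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the real Float8Tensor conversion: round-trip through
    # float8 to emulate the cast while staying in the original dtype
    return x.to(torch.float8_e4m3fn).to(x.dtype)

def float8_input_cast_hook(module: nn.Module, args: tuple):
    # returning a new args tuple replaces the module's forward inputs
    return (cast_to_float8_emulated(args[0]),) + args[1:]

linear = nn.Linear(64, 64)
linear.register_forward_pre_hook(float8_input_cast_hook)
```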
- `float8_playground/float8_linear.py` - `Float8Linear` (user facing entry point), and the custom fw/bw logic
- `float8_playground/float8_tensor.py` - `Float8Tensor`, which contains syntactic sugar for passing float8 data + scale around and converting to/from fp8
- `float8_playground/float8_python_api.py` - interface between Python functions which know about `Float8Tensor` and aten functions which know about raw data + scale
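To make the data + scale pairing concrete, here is a conceptual sketch (not the real class) of what `Float8Tensor` wraps:

```python
import torch

class Float8TensorSketch:
    # conceptual stand-in for Float8Tensor: raw float8 bits plus the scale
    # used at cast time, so the original value is roughly _data / _scale
    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data    # torch.float8_e4m3fn payload
        self._scale = scale  # scale applied before the cast

    @classmethod
    def from_float32(cls, x: torch.Tensor) -> "Float8TensorSketch":
        # real code would clamp amax away from zero before dividing
        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
        return cls((x * scale).to(torch.float8_e4m3fn), scale)

    def to_float32(self) -> torch.Tensor:
        return self._data.to(torch.float32) / self._scale
```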
```
# run single-GPU unit tests
python tests/test.py

# run a single-GPU integration test on SAM
python tests/test_sam.py

# run a two-GPU integration test on FSDP
./tests/test_fsdp.sh

# run all of these tests
./tests/run_everything.sh
```