This repo contains exploration work on quantized training. Inspirations:
- Q-GaLore: [paper] [code]
- AQT: [related paper] [code]
- SwitchBack: [paper] [code]
- Jetfire: [paper] [code]
Eventually, some of these will be upstreamed to torchao.
```bash
# Include submodules to clone cutlass
git clone --recurse-submodules https://github.com/gau-nernst/quantized-training

# Install PyTorch from https://pytorch.org/. The nightly version is recommended.
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch-nightly -c nvidia

# Install other deps (might not be up to date).
pip install -r requirements.txt

# Apply the cutlass patch to fix a compile error with scaled INT4 matmul.
git apply kernels/cutlass.patch --directory kernels/cutlass
```
## LLM pre-training

```bash
# Using a HF streaming dataset, tokenizing on-the-fly
python llm_pretrain.py --train_ds '{"type":"hf_text","dataset":"allenai/c4","subset":"en","split":"train","tokenizer":"llama2"}' --seed 2024

# Using a pre-tokenized local dataset (see below). "dataset_dir" should contain .bin files.
python llm_pretrain.py --train_ds '{"type":"token","dataset_dir":"tinystories_train"}' --seed 2024
```
To obtain pre-tokenized datasets, either download them from gaunernst/tokenized-datasets or run:

```bash
python tokenize_data.py --dataset tinystories --split train
```
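The exact on-disk layout of these shards is defined by `tokenize_data.py`. Assuming they are flat arrays of uint16 token IDs (an assumption, not verified here; the filename below is also just a placeholder), a shard can be inspected with `numpy.memmap`:

```python
# Sketch: peek at a pre-tokenized .bin shard.
# ASSUMPTION: tokens are stored as a flat uint16 array; check tokenize_data.py for the real layout.
import numpy as np

tokens = np.memmap("tinystories_train/shard_0000.bin", dtype=np.uint16, mode="r")  # placeholder filename
print(f"{len(tokens):,} tokens; first 10 IDs: {tokens[:10].tolist()}")
```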
## LLM fine-tuning on MetaMathQA

TODO: update command

```bash
python llm_finetune.py --model HuggingFaceTB/SmolLM-1.7B --freeze_embedding_layer --batch_size 4 --n_steps 100_000 --ckpt_interval 10_000 --seed 2024 --compile
```
## ViT supervised training

TODO: ImageNet
## ViT fine-tuning on RESISC45

TODO: update command

```bash
python timm_finetune.py --model timm/vit_giant_patch14_dinov2.lvd142m --n_epochs 2 --batch_size 64 --model_kwargs '{"img_size":224}' --seed 2024 --compile
```
## Speed benchmarks

Speedup over PyTorch BF16 matmul; see `benchmark_mm.py` (a minimal timing sketch is also shown after the tables). Benchmarked on a 4070Ti SUPER; the first table additionally includes A100 and H100 NVL results. FP16 might need better tuning configs; the default Cutlass INT4 GEMM is used.
**Row-major x Column-major (`A @ B.T`)**

| Kernel | 1024 | 2048 | 4096 |
|---|---|---|---|
| **4070Ti SUPER** | | | |
| CuBLAS INT8 | 1.95 | 2.01 | 2.90 |
| Triton INT8 | 2.72 | 2.87 | 3.14 |
| Cutlass INT4 | 2.56 | 3.81 | 5.89 |
| Triton FP8 | 1.78 | 1.65 | 1.64 |
| Triton FP16 w/ FP16 accumulate | 1.86 | 1.76 | 1.29 |
| **A100** | | | |
| CuBLAS INT8 | 1.28 | 1.87 | 1.53 |
| Triton INT8 | 1.21 | 1.96 | 1.72 |
| Cutlass INT4 | 1.15 | 2.33 | 2.96 |
| Triton FP16 w/ FP16 accumulate | 0.92 | 1.26 | 0.98 |
| **H100 NVL** | | | |
| CuBLAS INT8 | 1.57 | 1.01 | 1.37 |
| Triton INT8 | 1.56 | 1.41 | 1.75 |
| Cutlass INT4 | 0.08 | 0.08 | 0.07 |
| Triton FP8 | 0.67 | 0.28 | 0.28 |
| Triton FP16 w/ FP16 accumulate | 1.19 | 0.67 | 0.69 |

**Row-major x Row-major (`A @ B`)**

| Kernel | 1024 | 2048 | 4096 |
|---|---|---|---|
| CuBLAS INT8 | 1.03 | 0.94 | 0.92 |
| Triton INT8 | 1.62 | 1.98 | 2.18 |
| Triton FP8 | 1.70 | 1.63 | 1.71 |
| Triton FP16 w/ FP16 accumulate | 1.64 | 1.77 | 1.38 |

**Column-major x Row-major (`A.T @ B`)**

| Kernel | 1024 | 2048 | 4096 |
|---|---|---|---|
| CuBLAS INT8 | 0.87 | 0.93 | 0.88 |
| Triton INT8 | 1.31 | 1.43 | 1.54 |
| Triton FP8 | 1.48 | 1.61 | 1.70 |
| Triton FP16 w/ FP16 accumulate | 1.42 | 1.79 | 1.35 |
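
The CuBLAS INT8 rows can be roughly sanity-checked with plain PyTorch, since `torch._int_mm()` dispatches to a cuBLAS INT8 GEMM. The snippet below is only a minimal timing sketch for the row-major x column-major case, not `benchmark_mm.py` itself:

```python
# Minimal timing sketch: BF16 matmul vs cuBLAS INT8 matmul (torch._int_mm).
# Illustrative only; benchmark_mm.py is the actual benchmark script.
import torch
from torch.utils.benchmark import Timer

def bench(fn, *args):
    # Median runtime in seconds (Timer handles CUDA synchronization).
    return Timer("fn(*args)", globals={"fn": fn, "args": args}).blocked_autorange().median

for n in (1024, 2048, 4096):
    a_bf16 = torch.randn(n, n, device="cuda", dtype=torch.bfloat16)
    b_bf16 = torch.randn(n, n, device="cuda", dtype=torch.bfloat16)
    a_i8 = torch.randint(-128, 128, (n, n), device="cuda", dtype=torch.int8)
    b_i8 = torch.randint(-128, 128, (n, n), device="cuda", dtype=torch.int8)

    t_bf16 = bench(torch.mm, a_bf16, b_bf16.t())   # A @ B.T in BF16
    t_i8 = bench(torch._int_mm, a_i8, b_i8.t())    # INT8 x INT8 -> INT32
    print(f"{n}x{n}: INT8 speedup over BF16 = {t_bf16 / t_i8:.2f}x")
```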
Llama2-1B pre-training on a 4070Ti SUPER, bs=16, seq_len=2048. INT8 means dynamic row-wise quantization + scaled INT8 matmul. The LM head is excluded.

| Forward | Backward grad input | Backward grad weight | Stochastic rounding | tok/s | Speedup |
|---|---|---|---|---|---|
| BF16 | BF16 | BF16 | - | 9,223 | 1.00 |
| INT8 | BF16 | BF16 | ❌ | 11,751 | 1.27 |
| INT8 | BF16 | BF16 | ✅ | 10,944 | 1.19 |
| INT8 | INT8 | BF16 | ❌ | 13,678 | 1.48 |
| INT8 | INT8 | BF16 | ✅ | 12,028 | 1.30 |
| INT8 | INT8 | INT8 | ❌ | 15,517 | 1.68 |
| INT8 | INT8 | INT8 | ✅ | OOM | - |
When stochastic rounding is used and the backward pass also uses INT8 matmul, there is a significant increase in memory usage. To be investigated.
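
For reference, "dynamic row-wise quantization + scaled INT8 matmul" can be sketched in plain PyTorch as below, including the stochastic-rounding variant. This is only an illustrative outline of the idea, not the actual kernels used by the training code.

```python
# Illustrative sketch: dynamic row-wise INT8 quantization + scaled INT8 matmul.
# Not the repo's actual kernels; a plain-PyTorch reference for the idea.
import torch

def quantize_rowwise_int8(x: torch.Tensor, stochastic: bool = False):
    # Per-row absmax scale so each row maps into [-127, 127].
    scale = x.float().abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
    x_scaled = x.float() / scale
    if stochastic:
        # Stochastic rounding: round up with probability equal to the fractional part.
        x_scaled = torch.floor(x_scaled + torch.rand_like(x_scaled))
    else:
        x_scaled = torch.round(x_scaled)
    return x_scaled.clamp_(-127, 127).to(torch.int8), scale

def int8_mm(a: torch.Tensor, b: torch.Tensor, stochastic: bool = False) -> torch.Tensor:
    # Computes A @ B.T with both operands dynamically quantized row-wise.
    a_i8, a_scale = quantize_rowwise_int8(a, stochastic)  # (M, K) int8, (M, 1) fp32
    b_i8, b_scale = quantize_rowwise_int8(b, stochastic)  # (N, K) int8, (N, 1) fp32
    out_i32 = torch._int_mm(a_i8, b_i8.t())               # INT8 x INT8 -> INT32 (cuBLAS)
    return out_i32.float() * a_scale * b_scale.t()        # undo the row-wise scales

a = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(2048, 1024, device="cuda", dtype=torch.bfloat16)
err = (int8_mm(a, b) - a.float() @ b.float().t()).abs().mean()
print(f"mean abs error vs FP32 matmul: {err.item():.3f}")
```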