-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add INT8 mixed-precision training #748
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/748
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 6082d30 with merge base 1b317f9 (): FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
nice! I am on vacation for most of this week but will try to take a look. |
(Not marking this ready for review yet because I'm still updating some benchmarks number. I'm seeing some discrepancy in peak memory between ao's built-in Llama and HF's Llama. Otherwise the main code should be ready for review.) |
from .int8 import quantize_int8_rowwise | ||
|
||
if has_triton(): | ||
from .int8_mm import int8_mm_dequant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
have you tried using torch._int_mm
with torch.compile in max-autotune
mode to generate the matmul?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't. I can try.
A question. Can torch._int_mm()
generate INT8 matmul fused with dequant? If yes, what is the proper way to write matmul that torch.compile will codegen it correctly? I see some versions of this in dynamo template but not very sure under which condition it will be codegen-ed.
Some problems I often face:
max-autotune
can take very long, especially for training.- triton codegen can be finicky i.e. it may or may not do what we expect. Feels like sometimes it's clearer to manually write out the triton kernel (though we may miss out some other fusion opportunities)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense. To scale to all preceding/subsequent ops, IMO compiler is the only scalable option. If you are ok with just covering the matmul and not fusing preceding/subsequent ops, manual kernel works great. I was more just wondering if you have have tried using the compiler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was playing around a bit with torch._int_mm()
with max-autotune to codegen the fused matmul. Some observations:
- Seems like currently inductor can only fuse 1 epilogue multiplication. I think the relevant code is here? https://github.com/pytorch/pytorch/blob/90e12cf63d7ead464e9f99007b50b494e82298ca/torch/_inductor/kernel/mm.py#L757
- This might not be an issue, as the other multiplication can be fused into the subsequent op. However, in a small end2end benchmark, my explicit triton kernel is still much faster (27%).
Therefore, I think for now using the explicit triton kernel is still better.
What I'm curious is that is there a mechanism to register a custom op with torch compiler that provides 2 implementations (e.g. 1 hand-written fused Triton impl, and 1 pure PyTorch impl which can potentially be fused across op boundaries) so that the compiler can benchmark and choose the better one in a specific scenario.
If you do the same thing |
def fsdp_post_all_gather( | ||
self, | ||
all_gather_outputs: Tuple[Tensor, ...], | ||
metadata: Any, | ||
param_dtype: torch.dtype, | ||
*, | ||
out: Optional[Tensor] = None, | ||
): | ||
(data,) = all_gather_outputs | ||
(config,) = metadata | ||
if out is not None: | ||
assert isinstance(out, Int8MixedPrecisionTrainingLinearWeight) | ||
assert out.config == config | ||
return | ||
return Int8MixedPrecisionTrainingLinearWeight(data.to(param_dtype), config), all_gather_outputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vkuzo The only thing not working with FSDP now (in both eager and compile mode) is MixedPrecisionPolicy(param_dtype=torch.bfloat16)
with FP32 model. I think this is where it goes wrong. With this code, the loss diverges from baseline.
If I change to the below, loss becomes NaN.
data = data.to(param_dtype)
return Int8MixedPrecisionTrainingLinearWeight(data, config), (data,)
If I don't do anything about param_dtype
, I will get AssertionError: FSDP reduce-scatter expects uniform gradient dtype but got {torch.bfloat16, torch.float32}
.
Do you know what should be the correct way to do this? Also, what should fsdp_post_all_gather()
return? I suppose this is a private API so there is no documentation.
I think one way to circumvent this is to override dtype
of tensor subclass and keeps track of this metadata to propagate across ops (like NF4), and in matmul/linear, return the correct dtype. However, it seems a bit cumbersome for what seems to be simple.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not directly answering your question, but here is how I built / encouraged folks to build things for float8 training:
- focus on promising accuracy + performance first, take the simplest possible path through the PT stack to get good numbers
- improve UX, as needed, after (1) is done
At a high level, that meant using module swaps and sidestepping a lot of missing functionality in torch.compile + tensor subclasses, to move faster on (1). If you want, you are welcome to do the same here. That would mean something like:
- have a Int8TraningLinear, weight is a regular tensor
- tensor subclasses are transient and only used to work around autograd limitations and compose with distributed
- if you follow what float8 is doing, things will work with compile/distributed/autograd as we've fixed most/all of the issues
Note that this isn't aligned with torchao inference, and that alignment will have to happen at some future time, I'm just decoupling that from unblocking early research / getting promising performance + accuracy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, what should fsdp_post_all_gather() return?
This is an API designed to allow FSDP to communicate in lower precision. Since you are currently always wrapping weights with subclasses, something that would sound right to me is:
- first, just make FSDP comms work in higher precision, this means you have to ser/deser your weight wrapper in these APIs
- then, in a separate PR, optinally enable doing the comms in int8, this means you quantize in the fsdp_pre_all_gather API and reconstruct the already quantized tensor in the fsdp_post_all_gather API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think one way to circumvent this is to override dtype of tensor subclass and keeps track of this metadata to propagate across ops (like NF4), and in matmul/linear, return the correct dtype. However, it seems a bit cumbersome for what seems to be simple.
hmm, not sure exactly what this means, but for float8, dtype is set to what we want autograd to see. I understand this might be confusing to a subset of people, but I think being practical and just having things work e2e with a slight regression to what is intuitive has worked well in that case. There isn't a single thing we can use as a "dtype" for arbitrary tensor subclasses anyways which composes with all the PyTorch systems.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea I think "module swap + regular weight tensor + call custom autograd.Function in forward" is simpler to implement too. May need to add @torch._dynamo.allow_in_graph
to the custom autograd.Function for compile to work. I tested this combo and it works for single GPU. Will try if it works for FSDP w/ mixed-precision param_dtype.
The main reason I started this out as tensor subclass is because of familiarity when I wrote other stuff in torchao. Also, for quantized weight before all-gather, tensor subclass for linear weight is needed (similar to WeightWithDynamicFloat8CastTensor
) (though I agree that this shouldn't be a big reason at this stage). Anyway, single GPU performance is already pretty good 😄, only FSDP is left. Though I guess FSDP performance is more important for big corps 😄.
This is an API designed to allow FSDP to communicate in lower precision
Yea I realize if the tensor subclass doesn't implement this method, the following ops will be called
torch.ops.fsdp.all_gather_copy_in.default,
torch.ops.fsdp.split_with_sizes_copy.default,
torch.ops.c10d._allgather_base_.default,
It's not clear which op here should return the wrapped tensor (so that later F.linear()
dispatch is correct). So maybe it's also possible to not use fsdp_pre/post_all_gather()
if I don't need low-precision comms at the moment.
first, just make FSDP comms work in higher precision, this means you have to ser/deser your weight wrapper in these APIs
Yes, this is my intention. But right now it doesn't work with MixedPrecisionPolicy(param_dtype=torch.bfloat16)
as I said previously. There will be complaints about dtype (thus I thought about keeping tracking of a separate dtype metadata like NF4). If there is no MixedPrecisionPolicy, the current state of this PR works with FSDP, but very slow even with torch.compile.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that sounds like a bug, lmk if you have logs from a small toy model to demonstrate what is different, I'd be interested to check that out
What kind of logs do you need? Like profile trace? I don't keep this code anymore, so need some efforts to reproduce it. If you really want to check it out, I can produce it for you.
Also a side observation. Without fsdp_pre/post_all_gather()
, I think we can't propagate tensor subclass across FSDP all-gather. I tried wrapping the outputs of the 3 FSDP ops mentioned earlier (fsdp.all_gather_copy_in.default
, ...) but still could not propagate the tensor subclass. I think this is why NF4 also needs to use fsdp_pre/post_all_gather()
, even though it doesn't exactly do dist comm quantization (just all-gather the pre-quantized weight, and wrap it with tensor subclass again).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What kind of logs do you need? Like profile trace? I don't keep this code anymore, so need some efforts to reproduce it. If you really want to check it out, I can produce it for you.
feel free to ignore since looks like it will be a lot of work. I can check it out from my machine after this lands, although it will be on different hardware.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw, just curious, do you have end2end speedup numbers for FP8? Specifically for Llama3-8B. I see there is one here https://github.com/pytorch/torchtitan/blob/main/docs/performance.md, but it's with tensor parallel and for Llama3-405B.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is the most recent polished numbers we have: https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359 . It's on LLaMa 3 70B, 1.5x speedup on 128 H100s with FSDP but no TP/SP.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The setup there seems like doesn't use many special tricks yet (still just dynamic scaling) + FP8 all-gather (which probably is not too significant for an 8x A100 setup like I used). So there are still room to optimize INT8 training further 😄 (assuming INT8 and FP8 speedups over BF16 are similar) (or maybe using Llama3-70B will have larger speedup simply thanks to larger matmul -> more % time spent in matmul)
Added torchtitan FSDP results to PR description. Looking very good 😄 @andrewor14 I have made all the changes I want to make. You can take a look through everything again to see if I haven't addressed any of your earlier feedback + any other changes you want. M1 build wheels CI failing is unrelated. |
Looks good to me! I'm merging this and cherry-picking this into the 0.5 branch. Thanks for all your hard work figuring out the FSDP issue. If there are remaining issues let's address them in follow-up PRs. |
By the way @gau-nernst I suspect the reason why you weren't able to get good numerics with torchtune is because fine-tuning with C4 there is simply broken. I ran into this myself recently: pytorch/torchtune#1526. Probably won't be able to test that particular workflow until they fix this problem. |
* initial commit * expose some UX. update test * add test. update bench * update test. add doc * fix ngpu * fix FSDP * fix * fix fsdp test * fix * grammar * simplify fsdp test * update benchmark script * update * make claim more conservative * register fused adam * update benchmark script * add more ops * update default * use TorchAOBaseTensor * fix fsdp param_dtype * fix param_dtype * dtype check to prevent unnecessary errors * move checks * add note * fix * simplify script * add module-based UX * fix * use FP8 impl of __torch_dispatch__ * rename _dynamice interface * update test * fix compile on 2.4 * log torch version * make log interval customizable * make naming for explicit * update readme * some change * fix big bug * add docstring. update _get_linear_inserter * add TorchAOBaseTensor back * fix FSDP * update FSDP test. add autocast support * reduce iter * update int8_mm fallback * put leading dims logic to _dynamic_int8_mm
* initial commit * expose some UX. update test * add test. update bench * update test. add doc * fix ngpu * fix FSDP * fix * fix fsdp test * fix * grammar * simplify fsdp test * update benchmark script * update * make claim more conservative * register fused adam * update benchmark script * add more ops * update default * use TorchAOBaseTensor * fix fsdp param_dtype * fix param_dtype * dtype check to prevent unnecessary errors * move checks * add note * fix * simplify script * add module-based UX * fix * use FP8 impl of __torch_dispatch__ * rename _dynamice interface * update test * fix compile on 2.4 * log torch version * make log interval customizable * make naming for explicit * update readme * some change * fix big bug * add docstring. update _get_linear_inserter * add TorchAOBaseTensor back * fix FSDP * update FSDP test. add autocast support * reduce iter * update int8_mm fallback * put leading dims logic to _dynamic_int8_mm
# return new unwrapped object | ||
return out | ||
|
||
def fsdp_pre_all_gather(self, mesh): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By the way, I was thinking of changing the signature to:
def fsdp_pre_all_gather(
self,
mesh: DeviceMesh,
module: nn.Module, # owning module (e.g. `lin` if this is `lin.weight`)
mp_policy: MixedPrecisionPolicy, # policy of the owning FSDP module
):
as (1) having access to the module
can allow using the module as a scratch space for any needed state and (2) having access to the mp_policy
helps with situations like the one here.
My main concern with doing this is how to deal with the backward compatibility issue. Technically, FSDP2 is still in prototype release, so we are allowed to break BC. However, libraries like ao and torchtune may need to do if/else based on PyTorch version. Maybe it is better to do it sooner than later though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we go: pytorch/pytorch#136129
will need more time to test it -- fortunately the subclasses using the FSDP extension are all in ao now it seems 😆
Excerpts from the new README
Terminologies for INT8 training are generally not standardized yet. To be precise, we use these terms with the following meaning:
There are 3 main benefits of using low-precision dtype for training (the extent depends on the actual strategies):
INT8 mixed-precision
On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision.
Basic usage
During training, there are 3 matmuls involved in each
nn.Linear
layer:output = input @ weight.T
grad_input = grad_output @ weight
grad_weight = grad_output.T @ input
You can configure which matmul to be applied with INT8 mixed-precision (shown above). If convergence is an issue, we recommend leaving
grad_weight
in original matmul precision, and alsograd_input
if the issue still persists.Note:
output
andgrad_input
, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for weight. For simplicity, we only support row-wise scaling.FSDP support
Out of the box, this INT8 mixed-precision training is not compatible with FSDP2
MixedPrecisionPolicy(param_dtype=param_dtype)
, whereparam_dtype
!= model dtype. As a workaround, you will need to manually specify the FSDP2'sparam_dtype
inInt8MixedPrecisionTrainingConfig
Preliminary results
(1) INT8 matmul microbenchmarks using
benchmarks/quantized_training/benchmark_int8mm.py
. Measure speedup over BF16 matmul. UsingA @ B.T
i.e.F.linear()
. Using vast.ai instances.A100 (sm80)
4090 (sm89, consumer)
L40S (sm89)
Seems like consumer cards have an unusual speedups for INT8 🤔
(2) End2end speed benchmark using
benchmarks/quantized_training/pretrain_llama2.py
There is a small loss spike for INT8 mixed-precision training. This is different from (3) (see below). Possibly because this script uses the built-in torchao's Llama, which does not do proper weight initialization, while (3) uses Llama from HF.
4090 also sees a much larger end2end speedup compared to A100. This agrees with microbenchmarks in (1).
(3) Convergence benchmark for Llama2-1B pre-training on C4 realnewslike subset, using https://github.com/gau-nernst/quantized-training. bs=32, seq_len=2048 -> 65k tok/batch. Train for 20k steps (1.3B tokens). 1x 4090 -> 70% speedup.
The surprising thing is that INT8 mixed-precision outperforms BF16 in terms of both speed and accuracy. There are loss spikes in BF16 run that don't manifest in INT8 mixed-precision. I don't have an explanation for this. Maybe need to validate this more with torchtune and torchtitan.
(4) Short benchmark with torchtune. Llama3.1-8B-instruct on 1x A100.
The issue with torchtune is that it pads input tokens to longest sequence, which results in a variety of seq len. This will trigger triton autotune at every step. The current workaround is to pad the input tokens manually. If there is popular demand for dynamic seq_len, which is probably common for fine-tuning workload I think, we might need to write a custom triton autotune that will only trigger re-tuning when input dims change significantly.
Comparing BF16 w/ padding and INT8 mixed-precision w/ padding, there is ~20% speedup. However, due to the impact of padding, there is no significant speedup over baseline.
(5) Short benchmark with torchtitan. LLama3-8B pre-training on 8x A100. Default config
llama3_8b.toml
withprofiling.enable_profiling=false
,optimizer.fused=true
, andtraining.compile=True
-> ~20% speedupThe following code is added before applying FSDP
The reduction in peak memory active is strange. But the loss curves match, so seems like there is nothing wrong? Will need to do memory profiling to understand what's going on.
Side note: For all end2end benchmarks, I don't apply INT8 mixed-precision to the LM head. This is because there will be an increase in peak memory. Haven't got the time to debug it, so I just leave it off. Other people's works also suggest that the first and last layers are the most sensitive.
Discussion
(1) In this blogpost, under Quantization configuration, Google AQT states that
It's unclear which backward matmul is left in BF16. I'm guessing it's the one calculating
grad_weight
, since it directly affect weight updates, and SwitchBack also leaves this matmul in high precision, citing precision issues when batch size is large.(2) Dynamic BF16->INT8 cast with stochastic rounding (SR). This can help to improve the accuracy of matmul. However, preliminary benchmarks show that it slows down the gains in speed quite substantially. There are also issues with increased peak memory in backward pass. Given that without SR, the results are already very good, I don't think it's necessary to introduce SR here.
(3) Many of optimization of FP8 training can be applied here too. Namely, delayed scaling (my current INT8 mixed-precision here is basically dynamic scaling), quantize before FSDP all-gather to reduce communication bandwidth. I leave this for future PRs.