Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add INT8 mixed-precision training #748

Merged
merged 54 commits into from
Sep 9, 2024
Merged

Conversation

gau-nernst
Copy link
Collaborator

@gau-nernst gau-nernst commented Aug 26, 2024

Excerpts from the new README

Terminologies for INT8 training are generally not standardized yet. To be precise, we use these terms with the following meaning:

  • Quantized training: model weights are quantized. This is a strict requirement. Does not matter what is the compute precision. Examples of this: Q-GaLore, JetFire.
  • INT8 mixed-precision training: model weights are in original precision, while compute dtype for some or all ops is in INT8. We call it like this because it is similar to FP16/BF16 mixed-precision training. One difference is that in FP16/BF16 mixed-precision training, matmul will return FP16/BF16 outputs, while for INT8 mixed-precision training, the returned dtype is usually not INT8. Examples include Google AQT and SwitchBack.

There are 3 main benefits of using low-precision dtype for training (the extent depends on the actual strategies):

  • Memory: reduce memory footprint by model weights, activations, gradients, and distributed communication bandwidth.
  • Speed: speedup compute-bound ops with low-precision hardware instructions (e.g. INT8 Tensor Cores) and speedup memory-bound ops with quantized inputs/outputs.
  • What you train is what you serve.

INT8 mixed-precision

On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision.

Basic usage

from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
from torchao.quantization import quantize_

model = ...

# apply INT8 matmul to all 3 matmuls
quantize_(model, int8_mixed_precision_training())

# customize which matmul is left in original precision.
config = Int8MixedPrecisionTrainingConfig(
    output=True,
    grad_input=True,
    grad_weight=False,
)
quantize_(model, int8_mixed_precision_training(config))

# train model as usual

During training, there are 3 matmuls involved in each nn.Linear layer:

  • 1 in forward: output = input @ weight.T
  • 2 in backward:
    • grad_input = grad_output @ weight
    • grad_weight = grad_output.T @ input

You can configure which matmul to be applied with INT8 mixed-precision (shown above). If convergence is an issue, we recommend leaving grad_weight in original matmul precision, and also grad_input if the issue still persists.

Note:

  • When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT for INT8 dynamic activations + INT8 weight quantization (A8W8).
  • When we only apply INT8 mixed-precision to output and grad_input, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for weight. For simplicity, we only support row-wise scaling.
  • Apply stochastic rounding to INT8 quantization may improve matmul accuracy. However, from our testing, this seems to be unnecessary, thus we don't implement it at the moment.

FSDP support

Out of the box, this INT8 mixed-precision training is not compatible with FSDP2 MixedPrecisionPolicy(param_dtype=param_dtype), where param_dtype != model dtype. As a workaround, you will need to manually specify the FSDP2's param_dtype in Int8MixedPrecisionTrainingConfig

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig
from torchao.quantization import quantize_

model = ...  # FP32 model

# setup configs
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
int8mp_config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=mp_policy.param_dtype)

# exclude LM head
quantize_(model.layers, int8_mixed_precision_training(int8mp_config))

# shard the model w/ FSDP2
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

# train model as usual

Preliminary results

(1) INT8 matmul microbenchmarks using benchmarks/quantized_training/benchmark_int8mm.py. Measure speedup over BF16 matmul. Using A @ B.T i.e. F.linear(). Using vast.ai instances.

A100 (sm80)

M N K CuBLAS INT8 speedup Triton INT8 dequant speedup
0 1024 1024 1024 1.26316 1.41176
1 2048 2048 2048 1.89655 2.07547
2 4096 4096 4096 1.56211 1.76491
3 32768 4096 4096 1.70497 1.95714
4 4096 4096 32768 1.44795 2.12719
5 32768 14336 4096 1.69107 1.94103
6 32768 4096 14336 1.66619 2.08398
7 14336 4096 32768 1.5258 2.16727

4090 (sm89, consumer)

M N K CuBLAS INT8 speedup Triton INT8 dequant speedup
0 1024 1024 1024 1.4375 1.76923
1 2048 2048 2048 1.55882 2.94444
2 4096 4096 4096 3.14859 3.4087
3 32768 4096 4096 3.46017 3.5822
4 4096 4096 32768 3.84755 3.78122
5 32768 14336 4096 3.23808 3.27171
6 32768 4096 14336 3.58319 3.5256
7 14336 4096 32768 3.64608 3.56454

L40S (sm89)

M N K CuBLAS INT8 speedup Triton INT8 dequant speedup
0 1024 1024 1024 1.2 1.125
1 2048 2048 2048 1.175 2.18605
2 4096 4096 4096 1.83962 2.06714
3 32768 4096 4096 2.06275 2.31772
4 4096 4096 32768 2.20076 2.25022
5 32768 14336 4096 2.02969 2.08319
6 32768 4096 14336 2.18722 2.19138
7 14336 4096 32768 2.22682 2.20711

Seems like consumer cards have an unusual speedups for INT8 🤔

(2) End2end speed benchmark using benchmarks/quantized_training/pretrain_llama2.py

Model & GPU bs x seq_len Config Tok/s Peak mem (GB)
Llama2-7B, A100 8 x 2048 BF16 (baseline) ~4400 59.69
Llama2-7B, A100 8 x 2048 INT8 mixed-precision ~6100 (+39%) 58.28
Llama2-1B, 4090 16 x 2048 BF16 (baseline) ~17,900 18.23
Llama2-1B, 4090 16 x 2048 INT8 mixed-precision ~30,700 (+72%) 18.34
image image

There is a small loss spike for INT8 mixed-precision training. This is different from (3) (see below). Possibly because this script uses the built-in torchao's Llama, which does not do proper weight initialization, while (3) uses Llama from HF.

4090 also sees a much larger end2end speedup compared to A100. This agrees with microbenchmarks in (1).

(3) Convergence benchmark for Llama2-1B pre-training on C4 realnewslike subset, using https://github.com/gau-nernst/quantized-training. bs=32, seq_len=2048 -> 65k tok/batch. Train for 20k steps (1.3B tokens). 1x 4090 -> 70% speedup.

Config Tok/s Peak mem (GB) Val loss
BF16 (baseline) ~17k 19.47 2.97
INT8 mixed-precision ~29k 19.47 2.90
image

The surprising thing is that INT8 mixed-precision outperforms BF16 in terms of both speed and accuracy. There are loss spikes in BF16 run that don't manifest in INT8 mixed-precision. I don't have an explanation for this. Maybe need to validate this more with torchtune and torchtitan.

(4) Short benchmark with torchtune. Llama3.1-8B-instruct on 1x A100.

Config Loss @ 500 tok/s Peak mem alloc (GB)
BF16 (baseline) 1.01 3900 63.18
BF16 (w/ padding) 1.01 3300 63.89
INT8 mixed-precision (w/ padding) 1.02 4000 63.96
image

The issue with torchtune is that it pads input tokens to longest sequence, which results in a variety of seq len. This will trigger triton autotune at every step. The current workaround is to pad the input tokens manually. If there is popular demand for dynamic seq_len, which is probably common for fine-tuning workload I think, we might need to write a custom triton autotune that will only trigger re-tuning when input dims change significantly.

Comparing BF16 w/ padding and INT8 mixed-precision w/ padding, there is ~20% speedup. However, due to the impact of padding, there is no significant speedup over baseline.

(5) Short benchmark with torchtitan. LLama3-8B pre-training on 8x A100. Default config llama3_8b.toml with profiling.enable_profiling=false, optimizer.fused=true, and training.compile=True -> ~20% speedup

The following code is added before applying FSDP

    from torchao.quantization import quantize_
    from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig

    dtype = getattr(torch, job_config.training.mixed_precision_param)
    config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=dtype)
    quantize_(model.layers, int8_mixed_precision_training(config))
Config Loss @ 1k tok/s Peak mem active (GB)
BF16 (baseline) 4.53 3300 39.13
INT8 mixed-precision 4.55 3900 29.63
image

The reduction in peak memory active is strange. But the loss curves match, so seems like there is nothing wrong? Will need to do memory profiling to understand what's going on.

Side note: For all end2end benchmarks, I don't apply INT8 mixed-precision to the LM head. This is because there will be an increase in peak memory. Haven't got the time to debug it, so I just leave it off. Other people's works also suggest that the first and last layers are the most sensitive.

Discussion

(1) In this blogpost, under Quantization configuration, Google AQT states that

In the forward pass, we INT8-quantized all tensor ops in each Transformer layer except for one tensor op in each attention, which was more sensitive. We also quantized the model head — the "logits'' layer.
For each tensor op quantized in the forward pass, we INT8-quantize one of the backprop tensor ops but we leave the other one using bf16 inputs.
For MLPerf experiments where we already had local AQT, we quantized both backprop tensor ops.

It's unclear which backward matmul is left in BF16. I'm guessing it's the one calculating grad_weight, since it directly affect weight updates, and SwitchBack also leaves this matmul in high precision, citing precision issues when batch size is large.

(2) Dynamic BF16->INT8 cast with stochastic rounding (SR). This can help to improve the accuracy of matmul. However, preliminary benchmarks show that it slows down the gains in speed quite substantially. There are also issues with increased peak memory in backward pass. Given that without SR, the results are already very good, I don't think it's necessary to introduce SR here.

(3) Many of optimization of FP8 training can be applied here too. Namely, delayed scaling (my current INT8 mixed-precision here is basically dynamic scaling), quantize before FSDP all-gather to reduce communication bandwidth. I leave this for future PRs.

Copy link

pytorch-bot bot commented Aug 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/748

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6082d30 with merge base 1b317f9 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 26, 2024
@vkuzo
Copy link
Contributor

vkuzo commented Aug 26, 2024

nice! I am on vacation for most of this week but will try to take a look.

@msaroufim msaroufim requested review from andrewor14, msaroufim and vkuzo and removed request for andrewor14 August 27, 2024 00:04
@gau-nernst
Copy link
Collaborator Author

(Not marking this ready for review yet because I'm still updating some benchmarks number. I'm seeing some discrepancy in peak memory between ao's built-in Llama and HF's Llama. Otherwise the main code should be ready for review.)

from .int8 import quantize_int8_rowwise

if has_triton():
from .int8_mm import int8_mm_dequant
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have you tried using torch._int_mm with torch.compile in max-autotune mode to generate the matmul?

Copy link
Collaborator Author

@gau-nernst gau-nernst Aug 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't. I can try.

A question. Can torch._int_mm() generate INT8 matmul fused with dequant? If yes, what is the proper way to write matmul that torch.compile will codegen it correctly? I see some versions of this in dynamo template but not very sure under which condition it will be codegen-ed.

Some problems I often face:

  • max-autotune can take very long, especially for training.
  • triton codegen can be finicky i.e. it may or may not do what we expect. Feels like sometimes it's clearer to manually write out the triton kernel (though we may miss out some other fusion opportunities)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense. To scale to all preceding/subsequent ops, IMO compiler is the only scalable option. If you are ok with just covering the matmul and not fusing preceding/subsequent ops, manual kernel works great. I was more just wondering if you have have tried using the compiler.

Copy link
Collaborator Author

@gau-nernst gau-nernst Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was playing around a bit with torch._int_mm() with max-autotune to codegen the fused matmul. Some observations:

Therefore, I think for now using the explicit triton kernel is still better.

What I'm curious is that is there a mechanism to register a custom op with torch compiler that provides 2 implementations (e.g. 1 hand-written fused Triton impl, and 1 pure PyTorch impl which can potentially be fused across op boundaries) so that the compiler can benchmark and choose the better one in a specific scenario.

@gau-nernst gau-nernst marked this pull request as ready for review August 27, 2024 07:02
@gau-nernst
Copy link
Collaborator Author

gau-nernst commented Aug 28, 2024

I was trying this with torchtune codebase on A100 and the results are disastrous. Definitely something is wrong here. Even in eager mode, the speed is crawling. Will investigate more when I have time 😢. At the moment, seems like this only works well for consumer cards.

image

Edit: the regression seems to be isolated to torchtune. Re-running benchmarks on A100 with benchmarks/quantized_training/pretrain_llama2.py now.

Edit 2: added A100 end2end benchmarks with benchmarks/quantized_training/pretrain_llama2.py in PR description. Seeing 25% speedup. The problem is indeed isolated to torchtune.

@vkuzo
Copy link
Contributor

vkuzo commented Sep 6, 2024

Looking at FP8 and NF4 implementations, I think I might have to implement it via aten.mm instead of F.linear. Appreciate if you can also share knowledge/experience with supporting tensor subclass weight in FSDP.

If you do the same thing Float8Tensor/NF4Tensor/MXTensor are doing, things should work. Let us know if you have more specific questions, happy to help.

Comment on lines 117 to 131
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[Tensor] = None,
):
(data,) = all_gather_outputs
(config,) = metadata
if out is not None:
assert isinstance(out, Int8MixedPrecisionTrainingLinearWeight)
assert out.config == config
return
return Int8MixedPrecisionTrainingLinearWeight(data.to(param_dtype), config), all_gather_outputs
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vkuzo The only thing not working with FSDP now (in both eager and compile mode) is MixedPrecisionPolicy(param_dtype=torch.bfloat16) with FP32 model. I think this is where it goes wrong. With this code, the loss diverges from baseline.

If I change to the below, loss becomes NaN.

data = data.to(param_dtype)
return Int8MixedPrecisionTrainingLinearWeight(data, config), (data,)

If I don't do anything about param_dtype, I will get AssertionError: FSDP reduce-scatter expects uniform gradient dtype but got {torch.bfloat16, torch.float32}.

Do you know what should be the correct way to do this? Also, what should fsdp_post_all_gather() return? I suppose this is a private API so there is no documentation.

I think one way to circumvent this is to override dtype of tensor subclass and keeps track of this metadata to propagate across ops (like NF4), and in matmul/linear, return the correct dtype. However, it seems a bit cumbersome for what seems to be simple.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly answering your question, but here is how I built / encouraged folks to build things for float8 training:

  1. focus on promising accuracy + performance first, take the simplest possible path through the PT stack to get good numbers
  2. improve UX, as needed, after (1) is done

At a high level, that meant using module swaps and sidestepping a lot of missing functionality in torch.compile + tensor subclasses, to move faster on (1). If you want, you are welcome to do the same here. That would mean something like:

  1. have a Int8TraningLinear, weight is a regular tensor
  2. tensor subclasses are transient and only used to work around autograd limitations and compose with distributed
  3. if you follow what float8 is doing, things will work with compile/distributed/autograd as we've fixed most/all of the issues

Note that this isn't aligned with torchao inference, and that alignment will have to happen at some future time, I'm just decoupling that from unblocking early research / getting promising performance + accuracy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what should fsdp_post_all_gather() return?

This is an API designed to allow FSDP to communicate in lower precision. Since you are currently always wrapping weights with subclasses, something that would sound right to me is:

  • first, just make FSDP comms work in higher precision, this means you have to ser/deser your weight wrapper in these APIs
  • then, in a separate PR, optinally enable doing the comms in int8, this means you quantize in the fsdp_pre_all_gather API and reconstruct the already quantized tensor in the fsdp_post_all_gather API

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one way to circumvent this is to override dtype of tensor subclass and keeps track of this metadata to propagate across ops (like NF4), and in matmul/linear, return the correct dtype. However, it seems a bit cumbersome for what seems to be simple.

hmm, not sure exactly what this means, but for float8, dtype is set to what we want autograd to see. I understand this might be confusing to a subset of people, but I think being practical and just having things work e2e with a slight regression to what is intuitive has worked well in that case. There isn't a single thing we can use as a "dtype" for arbitrary tensor subclasses anyways which composes with all the PyTorch systems.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I think "module swap + regular weight tensor + call custom autograd.Function in forward" is simpler to implement too. May need to add @torch._dynamo.allow_in_graph to the custom autograd.Function for compile to work. I tested this combo and it works for single GPU. Will try if it works for FSDP w/ mixed-precision param_dtype.

The main reason I started this out as tensor subclass is because of familiarity when I wrote other stuff in torchao. Also, for quantized weight before all-gather, tensor subclass for linear weight is needed (similar to WeightWithDynamicFloat8CastTensor) (though I agree that this shouldn't be a big reason at this stage). Anyway, single GPU performance is already pretty good 😄, only FSDP is left. Though I guess FSDP performance is more important for big corps 😄.

This is an API designed to allow FSDP to communicate in lower precision

Yea I realize if the tensor subclass doesn't implement this method, the following ops will be called

torch.ops.fsdp.all_gather_copy_in.default,
torch.ops.fsdp.split_with_sizes_copy.default,
torch.ops.c10d._allgather_base_.default,

It's not clear which op here should return the wrapped tensor (so that later F.linear() dispatch is correct). So maybe it's also possible to not use fsdp_pre/post_all_gather() if I don't need low-precision comms at the moment.

first, just make FSDP comms work in higher precision, this means you have to ser/deser your weight wrapper in these APIs

Yes, this is my intention. But right now it doesn't work with MixedPrecisionPolicy(param_dtype=torch.bfloat16) as I said previously. There will be complaints about dtype (thus I thought about keeping tracking of a separate dtype metadata like NF4). If there is no MixedPrecisionPolicy, the current state of this PR works with FSDP, but very slow even with torch.compile.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that sounds like a bug, lmk if you have logs from a small toy model to demonstrate what is different, I'd be interested to check that out

What kind of logs do you need? Like profile trace? I don't keep this code anymore, so need some efforts to reproduce it. If you really want to check it out, I can produce it for you.

Also a side observation. Without fsdp_pre/post_all_gather(), I think we can't propagate tensor subclass across FSDP all-gather. I tried wrapping the outputs of the 3 FSDP ops mentioned earlier (fsdp.all_gather_copy_in.default, ...) but still could not propagate the tensor subclass. I think this is why NF4 also needs to use fsdp_pre/post_all_gather(), even though it doesn't exactly do dist comm quantization (just all-gather the pre-quantized weight, and wrap it with tensor subclass again).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What kind of logs do you need? Like profile trace? I don't keep this code anymore, so need some efforts to reproduce it. If you really want to check it out, I can produce it for you.

feel free to ignore since looks like it will be a lot of work. I can check it out from my machine after this lands, although it will be on different hardware.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, just curious, do you have end2end speedup numbers for FP8? Specifically for Llama3-8B. I see there is one here https://github.com/pytorch/torchtitan/blob/main/docs/performance.md, but it's with tensor parallel and for Llama3-405B.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the most recent polished numbers we have: https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359 . It's on LLaMa 3 70B, 1.5x speedup on 128 H100s with FSDP but no TP/SP.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The setup there seems like doesn't use many special tricks yet (still just dynamic scaling) + FP8 all-gather (which probably is not too significant for an 8x A100 setup like I used). So there are still room to optimize INT8 training further 😄 (assuming INT8 and FP8 speedups over BF16 are similar) (or maybe using Llama3-70B will have larger speedup simply thanks to larger matmul -> more % time spent in matmul)

@gau-nernst gau-nernst marked this pull request as ready for review September 7, 2024 10:25
@gau-nernst
Copy link
Collaborator Author

Added torchtitan FSDP results to PR description. Looking very good 😄

@andrewor14 I have made all the changes I want to make. You can take a look through everything again to see if I haven't addressed any of your earlier feedback + any other changes you want.

M1 build wheels CI failing is unrelated.

@andrewor14
Copy link
Contributor

Looks good to me! I'm merging this and cherry-picking this into the 0.5 branch. Thanks for all your hard work figuring out the FSDP issue. If there are remaining issues let's address them in follow-up PRs.

@andrewor14 andrewor14 merged commit 9acc9a4 into pytorch:main Sep 9, 2024
18 of 20 checks passed
@andrewor14
Copy link
Contributor

By the way @gau-nernst I suspect the reason why you weren't able to get good numerics with torchtune is because fine-tuning with C4 there is simply broken. I ran into this myself recently: pytorch/torchtune#1526. Probably won't be able to test that particular workflow until they fix this problem.

andrewor14 pushed a commit that referenced this pull request Sep 9, 2024
* initial commit

* expose some UX. update test

* add test. update bench

* update test. add doc

* fix ngpu

* fix FSDP

* fix

* fix fsdp test

* fix

* grammar

* simplify fsdp test

* update benchmark script

* update

* make claim more conservative

* register fused adam

* update benchmark script

* add more ops

* update default

* use TorchAOBaseTensor

* fix fsdp param_dtype

* fix param_dtype

* dtype check to prevent unnecessary errors

* move checks

* add note

* fix

* simplify script

* add module-based UX

* fix

* use FP8 impl of __torch_dispatch__

* rename _dynamice interface

* update test

* fix compile on 2.4

* log torch version

* make log interval customizable

* make naming for explicit

* update readme

* some change

* fix big bug

* add docstring. update _get_linear_inserter

* add TorchAOBaseTensor back

* fix FSDP

* update FSDP test. add autocast support

* reduce iter

* update int8_mm fallback

* put leading dims logic to _dynamic_int8_mm
jainapurva pushed a commit that referenced this pull request Sep 9, 2024
* initial commit

* expose some UX. update test

* add test. update bench

* update test. add doc

* fix ngpu

* fix FSDP

* fix

* fix fsdp test

* fix

* grammar

* simplify fsdp test

* update benchmark script

* update

* make claim more conservative

* register fused adam

* update benchmark script

* add more ops

* update default

* use TorchAOBaseTensor

* fix fsdp param_dtype

* fix param_dtype

* dtype check to prevent unnecessary errors

* move checks

* add note

* fix

* simplify script

* add module-based UX

* fix

* use FP8 impl of __torch_dispatch__

* rename _dynamice interface

* update test

* fix compile on 2.4

* log torch version

* make log interval customizable

* make naming for explicit

* update readme

* some change

* fix big bug

* add docstring. update _get_linear_inserter

* add TorchAOBaseTensor back

* fix FSDP

* update FSDP test. add autocast support

* reduce iter

* update int8_mm fallback

* put leading dims logic to _dynamic_int8_mm
@gau-nernst gau-nernst deleted the int8_mp branch September 10, 2024 00:11
# return new unwrapped object
return out

def fsdp_pre_all_gather(self, mesh):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, I was thinking of changing the signature to:

def fsdp_pre_all_gather(
    self,
    mesh: DeviceMesh,
    module: nn.Module,  # owning module (e.g. `lin` if this is `lin.weight`)
    mp_policy: MixedPrecisionPolicy,  # policy of the owning FSDP module
):

as (1) having access to the module can allow using the module as a scratch space for any needed state and (2) having access to the mp_policy helps with situations like the one here.

My main concern with doing this is how to deal with the backward compatibility issue. Technically, FSDP2 is still in prototype release, so we are allowed to break BC. However, libraries like ao and torchtune may need to do if/else based on PyTorch version. Maybe it is better to do it sooner than later though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we go: pytorch/pytorch#136129
will need more time to test it -- fortunately the subclasses using the FSDP extension are all in ao now it seems 😆

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants