PTQ for generate_v2
#1866
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1866
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Cancelled Jobs as of commit fc501ee with merge base 33b8143:
NEW FAILURE - The following job has failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
self._device = utils.get_device(device=cfg.device)
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
self.device = utils.get_device(device=cfg.device)
It's a public recipe, so there's no need for this to be a "private" variable.
cc @pbontrager
# Quantize the model if specified
if cfg.get("quantization_method") is not None:
    from torchao.quantization.quant_api import quantize_
Lazily import torchao API
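For reference, a minimal sketch of the lazy-import pattern being suggested here; it assumes the recipe's cfg, config, and model from the surrounding diff and is not the PR's exact code:

# Sketch only: cfg, config, and model come from the recipe context shown above.
if cfg.get("quantization_method") is not None:
    # Import torchao only when quantization is requested, so the default
    # generation path does not require torchao to be installed.
    from torchao.quantization.quant_api import quantize_

    quantization_method = config.instantiate(cfg.quantization_method)
    quantize_(model, quantization_method)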
quantization_method = config.instantiate(cfg.quantization_method)
compile_model(model)
Compiling the model is necessary for quantization to be really worth it
I'm curious whether compiling the whole model results in greater speedups than compiling just the next-token-prediction fn like gpt-fast does.
We should compile after quantize_ for the speedup, actually.
Interesting! I was following the pattern from AO's README where the model is compiled first:
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
Why should the model be compiled after quantization?
@jerryzh168 Anecdotally, I don't see much difference in tok/sec (after first token) between putting compile first or second. Can you share some more details about which one is correct?
Oh, right now quantize_ needs to be followed by compile: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#full-affine-quantization-flow-example
But autoquant will do compile first before calling autoquant.
mode='max-autotune' will take a long time to compile. Is it worth it? We don't do it for training.
It's interesting that in AO's README it says to put compile first. Do we also do it for QLoRA?
I don't see much difference in tok/sec (after first token) between putting compile first or second.
I haven't tried calling quantize_ after compile actually; maybe it would have the same effect as well, need to confirm.
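To make the two orderings discussed above concrete, here is a rough sketch based on the torchao README; the stand-in models, dtype, and compile mode are illustrative, not benchmarked settings:

import torch
import torchao
from torchao.quantization.quant_api import quantize_, int4_weight_only

# Stand-ins for the recipe's model (int4 weight-only kernels expect bf16 on CUDA).
model_a = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(device="cuda", dtype=torch.bfloat16)
model_b = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(device="cuda", dtype=torch.bfloat16)

# quantize_ flow (the ordering suggested above): quantize weights in place, then compile.
quantize_(model_a, int4_weight_only())
model_a = torch.compile(model_a, mode="max-autotune")

# autoquant flow from torchao's README: compile first, then wrap with autoquant.
model_b = torchao.autoquant(torch.compile(model_b, mode="max-autotune"))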
# 6. Prefill step
generated_tokens = []
t0 = time.perf_counter()
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
t1 = time.perf_counter()
Now that we might have a warmup run, we log this differently so the user can see how good quantization / compilation is.
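A small, self-contained sketch of the kind of split reporting this enables; the helper and logger names are made up for illustration and are not the recipe's code:

import logging

logger = logging.getLogger(__name__)

def log_generation_stats(t_start: float, t_first_token: float, t_end: float, num_decoded: int) -> None:
    # Report prefill latency (dominated by compile on the warmup run) separately
    # from steady-state decode throughput, so the quant/compile benefit is visible.
    logger.info(f"Time to first token: {t_first_token - t_start:.2f} s")
    logger.info(f"Decode throughput: {num_decoded / max(t_end - t_first_token, 1e-9):.2f} tokens/sec")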
@@ -9,6 +9,10 @@
# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b
# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
Leave this commented out until the user wants to do something with it.
generate_v2
prompt = torch.tensor(
    model_inputs["tokens"], device=self._device
).unsqueeze(0)
prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :]
I wanted this to fit on one line lol
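For anyone reading along, the two spellings are equivalent; a quick standalone check:

import torch

tokens = torch.tensor([1, 2, 3])
# Both add a leading batch dimension, giving shape (1, 3).
assert torch.equal(tokens.unsqueeze(0), tokens[None, :])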
@@ -18,6 +19,13 @@
CACHE_ARTIFACTS_SCRIPT_PATH = root + "/tests/cache_artifacts.sh"


def pytest_sessionfinish():
Compile tries to log a bunch of stuff using the atexit decorator. However, pytest closes these logs before they finish, so it throws an I/O error.
This disables logging exceptions. Not sure if it's the right way to do it.
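A minimal sketch of what the added hook might look like, assuming it lives in the tests' conftest.py; the body is my reading of the description above, not necessarily the exact diff:

import logging

def pytest_sessionfinish():
    # torch.compile schedules logging via atexit handlers that run after pytest
    # has closed its capture streams; turning off logging's internal exception
    # reporting prevents the spurious I/O error at shutdown.
    logging.raiseExceptions = False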
# Generation arguments
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
max_new_tokens: 200
max_new_tokens: 500
Allow longer generation to really see the benefit of quant + compile.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1866 +/- ##
===========================================
- Coverage 70.44% 25.92% -44.53%
===========================================
Files 308 308
Lines 16270 16292 +22
===========================================
- Hits 11462 4224 -7238
- Misses 4808 12068 +7260
☔ View full report in Codecov by Sentry.
@felipemello1 @ebsmothers Will this not pass on PyTorch 2.5 b/c of the issue with CUDNN? This test passes locally on PyTorch v2.5.1. Do we know when the patch will be released?
@@ -9,6 +9,10 @@
# Model arguments
model:
  _component_: torchtune.models.llama2.llama2_7b
# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
# quantization_method:
#   _component_: torchao.quantization.quant_api.int4_weight_only  # int4_weight_only is a good balance of speed and memory
dumb q: so the torchtune.training.quantization API is just for QAT.. or we're not using it anymore?
I see you mentioned this in the PR description - if we're going to be using the torchao APIs instead, it'd be good to follow up with an issue.
This looks overall sensible, but I have a few outstanding questions:
We probably don't need to answer all of these here, but I think it'd help bring a lot of our quantization offerings in line if we can at least follow up on them.
I think the question is actually whether we want to support PTQ APIs outside of torchao. If we do, we would want to opt for an approach like Hugging Face's, wherein a config for a specific backend can be initialized. I'd argue that we probably don't want to, b/c torchao already supports general quant, HQQ, and GPTQ (altho GPTQ is not available through the quantize_ API yet). Idk if this is too short-sighted though.
Not sure I understand the question. It's always slow during the warmup run.
Exactly.
Not sure what is so slow, but I've reached out to the AO team to see if this is normal.
An excellent question. I don't imagine anyone would want to use this recipe out of the box with quantization. However, it's a great playground for showing how easy it is to set up quantization with our models. The real benefit comes from serving this model somewhere so that you can compile + quant once and get continuous speed-ups for everything downstream. Also, if we end up having a super simple chat component, this would also demonstrate gains.
Are you seeing the slowdown for
# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
# quantization_method:
#   _component_: torchao.quantization.quant_api.int4_weight_only  # int4_weight_only is a good balance of speed and memory
#   use_hqq: False  # Turn on for more accurate results
@HDCharles is this true?
Yeah, sorry, this was anecdotal.
I tried both
Slower on the first run is expected, I feel, since compile actually happens on the first run when it sees the real inputs. Typically when we benchmark, there are some warmup runs so compile can actually run, and we benchmark the following runs.
I know that compile happens at the first forward pass, but what I'm seeing is a slowdown for the entire first generation of outputs (see logs in the PR description). Is this expected?
# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
# quantization_method:
#   _component_: torchao.quantization.quant_api.int4_weight_only  # int4_weight_only is a good balance of speed and memory
#   use_hqq: False  # Turn on to use Half-Quadratic Quantization
What does it mean? Can you add whether it makes it faster / more accurate / lower memory?
Sorry, what I meant is that this should be made clear for the user in the comment :P
Context
What is the purpose of this PR?
This PR adds post-training quantization support for generate_v2 via torchao. It is tested only for text models, specifically Llama2.
Why did you change the way quantization APIs are called?
Good catch - notably, instead of creating a Quantizer class and having that quantize the model, I opted to use the quantize_ API from torchao and instantiate a quantization method instead. I did this for two reasons:
Does this work for vision models?
Technically, it runs, but we haven't fixed the torch.compile graph breaks in the Llama3.2 V model so it doesn't speed anything up. Therefore, I will not be including this in the default config for llama3.2V.
Why is it actually slower for the entire first run?
My assumption is that compile is the culprit here. Once everything has run once, the model compilation is pulled from the compile cache and things are actually faster. Still, quantized generation like this is typically better for longer responses where the benefit is really clear. cc @andrewor14 if my intuition is correct here.
This DOES NOT work for PTQ of a QAT model. This will be added in a follow-up.
Changelog
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
Recipe without PTQ:
Recipe with PTQ (first run):
Recipe with PTQ (second run):
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example
To-do
Fix failing GPU test. It's passing locally, so I'm not sure how to make it work on the remote runners: