-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding kv_cache quantization #532
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/532
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 6c0c9b7 with merge base 0cc1c4d (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Were the memory decreases also minimal at larger sequence lengths? I'd increase it substantially to 100K, 200K, 500K, 1M and see how results change then |
i'll test that |
torchao/_models/llama/model.py
Outdated
del v_val | ||
|
||
# return k_out, v_out | ||
return self.k_cache*self.k_cache_scale, self.v_cache*self.v_cache_scale |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm guessing the bfloat16 kv cache is materialized in global memory, thus no memory saving. need to inspect it further.
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: the peak memory improvement is extremely small, tried a few things to fix this but didn't have any luck. Accuracy is very poor (text is unintelligible) tried to leave most recent token not quantized (since we have full fidelity information for whatever the current token is). That didn't solve the issue and resulted in a significant memory increase, may need to try affine quantization but currently more concerned with the lack of memory improvement. (see benchmark_results.txt for the results see kv_quant: True vs kv_quant: False for comparison.) i also took a memory trace you can get with (if you're a meta employee) jf download GCqU9BqGNUybzv8CABWUzUtOiPZ5bsIXAAAz --file "mem_trace_kvq.html" Test Plan: sh benchmarks.sh Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
discussed followup work offline
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Adding support for kv_cache quantization, we are using simple symmetric quantization, though using the full precision of the k and v values of the current token. we see tok/s reduction of 3-5 tok/s depending on context length. image and a reduction in peak memory image We expect this reduction to scale to large context lengths, in the model memory trace we can see the point where we replace the bf16 cache with the int8 cache which visually saves about half of the used memory Screenshot 2024-08-02 at 2 45 14 AM at longer context lengths both quantized and non-quantized kv_cache models start outputing weird stuff but otherwise accuracy of the kv_cache quant looks reasonable though e.g. for 2048 context length: <|begin_of_text|>Hello, my name is Richard Brown and I have been a professional musician for over 25 years. I have played in a number of bands, doing a wide variety of genres (soul/funk, rock, jazz, blues, latin, world). I have played on over a hundred albums so far. I have played with many different singers, as well as instrumentalists (guitarists, sax players, brass players, etc.). I love to play and try to learn as much as I can from others. I have become an all-round musician - playing keyboards, drums, programming, arranging; as well as writing songs myself. I have my own studio, and I can do sessions online. I also have my own website, where you can find out more about me and my music. I hope that you will find the music that you are looking for here. Otherwise there are some fixes in generate.py to get things working for large context lengths without overflowing beyond the model limit. test plan: sh benchmarks.sh (specifically the last 6 rows of benchmark_results.txt)
Adding support for kv_cache quantization, we are using simple symmetric quantization, though using the full precision of the k and v values of the current token.
we see tok/s reduction of 3-5 tok/s depending on context length.
and a reduction in peak memory
We expect this reduction to scale to large context lengths, in the model memory trace we can see the point where we replace the bf16 cache with the int8 cache which visually saves about half of the used memory
at longer context lengths both quantized and non-quantized kv_cache models start outputing weird stuff but otherwise accuracy of the kv_cache quant looks reasonable though e.g. for 2048 context length:
<|begin_of_text|>Hello, my name is Richard Brown and I have been a professional musician for over 25 years. I have played in a number of bands, doing a wide variety of genres (soul/funk, rock, jazz, blues, latin, world). I have played on over a hundred albums so far.
I have played with many different singers, as well as instrumentalists (guitarists, sax players, brass players, etc.). I love to play and try to learn as much as I can from others. I have become an all-round musician - playing keyboards, drums, programming, arranging; as well as writing songs myself. I have my own studio, and I can do sessions online.
I also have my own website, where you can find out more about me and my music.
I hope that you will find the music that you are looking for here.
Otherwise there are some fixes in generate.py to get things working for large context lengths without overflowing beyond the model limit.
test plan:
sh benchmarks.sh
(specifically the last 6 rows of benchmark_results.txt)