Simplify the implementation of quantization-related methods #19954
Conversation
Codecov Report
Attention: Patch coverage is …

@@            Coverage Diff             @@
##           master   #19954      +/-   ##
==========================================
+ Coverage   74.29%   79.25%   +4.96%
==========================================
  Files         500      500
  Lines       46867    46862       -5
  Branches     8651     8628      -23
==========================================
+ Hits        34819    37142    +2323
+ Misses      10364     7980    -2384
- Partials     1684     1740      +56
Thanks for the PR! @mattdangerw what do you think?
Force-pushed from 5f8235a to 82f2681
Force-pushed from 82f2681 to 15e50fd
Overall LGTM.
Does this change the API contract between Keras & KerasNLP again? Since the releases are not coupled, let's give ourselves some wiggle room. A simple API surface from Keras is good, and defensive usage from KerasNLP is good.
We might even want to do things like this from KerasNLP, so we don't crash on earlier Keras versions:
if getattr(self, "quantization_mode", None) == "int8":
    # more quantization code
    ...
The potential version cross product always gets confusing, but flexibility on the Keras version will probably save us bug reports post-release from people who get the versioning wrong.
We could also consider a hard version dependency from KerasNLP to Keras, but those can come with their own headaches.
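For illustration, a minimal sketch of the defensive pattern suggested above. Only the getattr check on quantization_mode comes from the comment; the layer, its weight, and the projection are made up for the example, and this is not the actual KerasNLP ReversibleEmbedding code.

import numpy as np
from keras import layers
from keras import ops


class ReverseProjection(layers.Layer):
    # Hypothetical layer used only to illustrate the defensive check.

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], 8),
            initializer="glorot_uniform",
            name="kernel",
        )

    def call(self, inputs):
        # getattr keeps this working on Keras versions that predate
        # quantization support, where the attribute does not exist.
        if getattr(self, "quantization_mode", None) == "int8":
            # int8-specific handling (e.g. dequantizing the kernel)
            # would go here.
            pass
        return ops.matmul(inputs, self.kernel)


layer = ReverseProjection()
print(layer(np.ones((2, 4), dtype="float32")).shape)  # (2, 8)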
Force-pushed from 57c7ab4 to 6229538
This is an important point that I missed previously. Thanks for pointing it out. I have made the change so that … Additionally, I have added …
Considering quantization is a newly introduced feature, it should be acceptable to require users to update their Keras version to call quantize. WDYT?
I have added a compatibility test to ensure that we don't break the API contract between Keras and KerasNLP. The idea of the test is that I copied the implementation of …
Yeah, I think that's fine as well. We just want to avoid breaking users who aren't using the feature at all, if we can without too much code complexity.
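For illustration only (this is not the actual compatibility test from this PR; the test name and tolerances are made up): a test in this spirit could verify that a layer keeps working whether or not quantize is ever called.

import numpy as np
import keras
from keras import layers


def test_non_quantized_path_is_unaffected():
    layer = layers.Dense(4)
    x = np.random.random((2, 3)).astype("float32")
    y_ref = keras.ops.convert_to_numpy(layer(x))

    # The defensive check must evaluate to False rather than raise.
    assert getattr(layer, "quantization_mode", None) != "int8"

    # Only exercise quantization if this Keras version supports it.
    if hasattr(layer, "quantize"):
        layer.quantize("int8")
        y_q = keras.ops.convert_to_numpy(layer(x))
        # int8 outputs should stay close to the float32 reference.
        np.testing.assert_allclose(y_q, y_ref, atol=1e-1)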
Overall looks good to me if we remove the sleep call and maybe just ditch the compat test. But adding a sleep call really sounds like it will cause lots of pain down the line -- let's avoid it.
I'll leave it to @fchollet for final review!
Apart from the simplification, I have found that there is a high peak in GPU memory usage when calling quantize. The solution I propose is to use numpy ops for quantize.
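As a rough illustration of the idea (not the code from this PR; the helper name and defaults are made up): abs-max int8 quantization of a weight matrix can be computed entirely with numpy when running eagerly, so the float32 intermediates never occupy accelerator memory and only the int8 result and its scale need to go back to the device.

import numpy as np


def abs_max_quantize_np(weights, axis=0, value_range=(-127, 127), epsilon=1e-7):
    # Numpy-only abs-max quantization (illustrative; the real
    # keras.quantizers.abs_max_quantize may differ in details).
    weights = np.asarray(weights, dtype="float32")
    abs_max = np.max(np.abs(weights), axis=axis, keepdims=True)
    scale = np.asarray(value_range[1] / (abs_max + epsilon), dtype="float32")
    quantized = np.clip(np.round(weights * scale), *value_range).astype("int8")
    return quantized, scale


kernel = np.random.uniform(-1.0, 1.0, size=(1024, 1024)).astype("float32")
q, scale = abs_max_quantize_np(kernel)
print(np.max(np.abs(q.astype("float32") / scale - kernel)))  # tiny rounding error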
Force-pushed from 88b511c to 1e74d09
@james77777778 awesome, that does seem worth it. Is it noticeably slower because of the extra conversions? Let's make sure to test jax too. I think for jax you need to specify …
Hey @mattdangerw
Actually, it is faster in the following tiny example. I haven't checked the time cost on LLMs, but my rough guess is that it should not be noticeably slower because we call …
I have verified that this PR reduces the peak GPU memory requirement in both JAX & TF. (We don't need to specify …)

import time

from keras import backend
from keras import layers
from keras import models


def get_memory_profile():
    # Report the peak device memory seen so far for the current backend.
    if backend.backend() == "tensorflow":
        import tensorflow as tf

        peak = tf.config.experimental.get_memory_info("GPU:0")["peak"]
        print(f"Peak memory: {peak / 1024 / 1024:.2f} MB")
    elif backend.backend() == "jax":
        import jax

        peak = jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]
        print(f"Peak memory: {peak / 1024 / 1024:.2f} MB")
    else:
        raise NotImplementedError()


model = models.Sequential([layers.Input(shape=(1024,)), layers.Dense(1024)])
get_memory_profile()
st = time.time()
model.quantize("int8")
ed = time.time()
get_memory_profile()
print(f"`quantize`: {ed - st:.3f}s")
Force-pushed from db7b2ac to a6fece4
LGTM! Just one comment.
Great to know about jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]
@@ -64,7 +66,24 @@ def abs_max_quantize(
     value_range=(-127, 127),
     dtype="int8",
     epsilon=backend.epsilon(),
+    to_numpy=False,
Why isn't to_numpy the default if the outputs are the same and it performs better?
abs_max_quantize is used not only in quantize, but also in AbsMaxQuantizer, which is called in _int8_call. We cannot use numpy for that call.
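To illustrate that constraint (assuming the JAX backend; the function name int8_call_like is made up for the example): inside a jitted call the inputs are abstract tracers, so a numpy round-trip is not possible there, which is why the numpy path has to stay opt-in and cannot be the default.

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def int8_call_like(x, kernel):
    # Inside a jitted/traced function, x and kernel are tracers rather
    # than concrete arrays, so something like np.asarray(x) would fail
    # here. A numpy-based quantizer path is therefore only safe in eager
    # entry points such as Model.quantize(), not in the layer's compiled
    # int8 forward pass.
    return jnp.matmul(x, kernel)


x = np.ones((2, 4), dtype="float32")
kernel = np.ones((4, 8), dtype="float32")
print(int8_call_like(x, kernel).shape)  # (2, 8)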
As suggested by @mattdangerw in keras-team/keras-hub#1670 (comment), I have tried my best to simplify the logic of quantization-related methods, especially quantize, by using a wrapper. Please let me know if we can further simplify the implementation.
CAUTION: We might need to modify the code in KerasNLP (ReversibleEmbedding) if this PR is merged and released.
EDITED: The compatibility has been guaranteed by KerasNLP.
EDITED2: Apart from the simplification, I have found that there is a high peak in GPU memory usage when calling quantize. The root cause might be keras.quantizers.abs_max_quantize. The solution I propose is to use numpy ops for quantize. This should be feasible because we only call it in eager mode.
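The exact wrapper used by the PR is not shown in this thread; as a rough illustration of the pattern only, a decorator could centralize the bookkeeping shared by every layer's quantize() so that each implementation contains only the layer-specific weight math. The names quantize_wrapper, _finalize_quantization, and ToyLayer below are hypothetical.

import functools


def quantize_wrapper(fn):
    # Hypothetical decorator (not the actual wrapper from this PR):
    # centralize the checks shared by every layer's quantize().
    @functools.wraps(fn)
    def wrapper(self, mode, **kwargs):
        if mode not in ("int8", "float8"):
            raise ValueError(f"Unsupported quantization mode: {mode}")
        if getattr(self, "quantization_mode", None) is not None:
            raise ValueError("Layer is already quantized.")
        fn(self, mode, **kwargs)           # layer-specific weight conversion
        self._finalize_quantization(mode)  # shared bookkeeping, e.g. dtype policy
    return wrapper


class ToyLayer:
    quantization_mode = None

    @quantize_wrapper
    def quantize(self, mode):
        print(f"converting weights to {mode}")

    def _finalize_quantization(self, mode):
        self.quantization_mode = mode


layer = ToyLayer()
layer.quantize("int8")  # prints: converting weights to int8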