
Simplify the implementation of quantization-related methods #19954

Merged: 8 commits merged into keras-team:master on Jul 23, 2024

Conversation

@james77777778 (Contributor) commented on Jul 4, 2024

As suggested by @mattdangerw in keras-team/keras-hub#1670 (comment), I have tried my best to simplify the logic of the quantization-related methods, especially quantize, by using a wrapper.

Please let me know if we can further simplify the implementation.
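
For intuition, here is a minimal sketch of the wrapper idea (illustrative only: the decorator body and the set of checked modes are assumptions, not the implementation merged in this PR):

import functools


def quantize_wrapper(fn):
    # Hypothetical decorator that centralizes the checks shared by all
    # quantization-related methods, so each layer only implements its
    # mode-specific logic.
    @functools.wraps(fn)
    def wrapper(self, mode, **kwargs):
        if mode not in ("int8", "float8"):
            raise ValueError(f"Unsupported quantization mode: {mode}")
        if getattr(self, "quantization_mode", None) is not None:
            raise ValueError("The layer has already been quantized.")
        return fn(self, mode, **kwargs)

    return wrapper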

CAUTION:
We might need to modify the code in KerasNLP (ReversibleEmbedding) if this PR is merged and released.

EDITED:
Compatibility is now guaranteed on the KerasNLP side.

EDITED2:
Apart from the simplification, I have found that there is a high peak in GPU memory usage when calling quantize. The root cause might be keras.quantizers.abs_max_quantize.

The solution I propose is to use numpy ops for quantize. This should be feasible because we only call it in eager mode.

from keras import backend
from keras import layers
from keras import models


def get_memory_profile():
    # Print the peak device memory usage reported by the active backend.
    if backend.backend() == "tensorflow":
        import tensorflow as tf

        peak = tf.config.experimental.get_memory_info("GPU:0")["peak"]
        print(f"Peak memory: {peak / 1024 / 1024:.2f} MB")
    elif backend.backend() == "jax":
        import jax

        peak = jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]
        print(f"Peak memory: {peak / 1024 / 1024:.2f} MB")
    else:
        raise NotImplementedError()


model = models.Sequential([layers.Input(shape=(1024,)), layers.Dense(1024)])
get_memory_profile()

model.quantize("int8")
get_memory_profile()

Branch  Backend     Peak Mem.
master  TensorFlow  12.00MB -> 23.95MB
        JAX         12.00MB -> 16.01MB
PR      TensorFlow  12.00MB -> 12.00MB
        JAX         12.00MB -> 12.00MB

@codecov-commenter commented on Jul 4, 2024

Codecov Report

Attention: Patch coverage is 96.00000% with 3 lines in your changes missing coverage. Please review.

Project coverage is 79.25%. Comparing base (0ed820f) to head (a6fece4).
Report is 241 commits behind head on master.

Files with missing lines     Patch %   Lines
keras/src/layers/layer.py    93.61%    2 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19954      +/-   ##
==========================================
+ Coverage   74.29%   79.25%   +4.96%     
==========================================
  Files         500      500              
  Lines       46867    46862       -5     
  Branches     8651     8628      -23     
==========================================
+ Hits        34819    37142    +2323     
+ Misses      10364     7980    -2384     
- Partials     1684     1740      +56     
Flag              Coverage Δ
keras             79.11% <96.00%> (+4.93%) ⬆️
keras-jax         62.43% <93.33%> (?)
keras-numpy       57.51% <86.66%> (+0.05%) ⬆️
keras-tensorflow  63.77% <93.33%> (-0.03%) ⬇️
keras-torch       62.50% <96.00%> (+0.04%) ⬆️


@fchollet (Collaborator) left a comment:

Thanks for the PR! @mattdangerw what do you think?

(Review thread on keras/src/layers/layer.py: outdated, resolved)
@mattdangerw (Member) left a comment:

Overall LGTM.

Does this change the API contract between Keras & KerasNLP again? Since the releases are not coupled, let's give ourselves some wiggle room. A simple API surface from Keras is good, and defensive usage from KerasNLP is good.

We might even want to do something like this from KerasNLP, so we don't crash on earlier Keras versions:

if getattr(self, "quantization_mode", None) == "int8":
    ...  # more quantization code

The potential version cross product always gets confusing, but flexibility on the Keras version will probably save us bug reports post-release where people get the versioning wrong.

We could also consider a hard version dep from KerasNLP to Keras, but those can come with their own headaches.

(Two review threads on keras/src/layers/layer.py: outdated, resolved)
@james77777778 (Contributor, Author) commented on Jul 10, 2024

> Does this change the API contract between Keras & KerasNLP again?

This is an important point that I missed previously. Thanks for pointing it out.

I have made the change so that quantize_wrapper is free from side effects. Therefore, the current KerasNLP should be compatible with both this and earlier versions of Keras.

Additionally, I have added def __delattr__(self, name) in Layer. This should be cleaner and make more sense, as __setattr__ and __delattr__ now act as counterparts. This should not affect KerasNLP either.
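
(To illustrate the symmetry, a toy example of __setattr__ and __delattr__ acting as counterparts; the Tracker class below is a stand-in for illustration, not Keras's actual tracking machinery:)

class Tracker:
    # Minimal stand-in for an attribute tracker (illustration only).
    def __init__(self):
        self.tracked = []

    def track(self, value):
        self.tracked.append(value)
        return value

    def untrack(self, value):
        if value in self.tracked:
            self.tracked.remove(value)


class TrackedLayer:
    def __init__(self):
        super().__setattr__("_tracker", Tracker())

    def __setattr__(self, name, value):
        # Track the assigned object, then set the attribute as usual.
        value = self._tracker.track(value)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        # Mirror __setattr__: untrack the object before removing it.
        self._tracker.untrack(getattr(self, name))
        super().__delattr__(name)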

> We might even want to do something like this from KerasNLP, so we don't crash on earlier Keras versions:
>
> if getattr(self, "quantization_mode", None) == "int8":
>     ...  # more quantization code

> The potential version cross product always gets confusing, but flexibility on the Keras version will probably save us bug reports post-release where people get the versioning wrong.
>
> We could also consider a hard version dep from KerasNLP to Keras, but those can come with their own headaches.

Considering that quantization is a newly introduced feature, it should be acceptable to require users to update their Keras version to call quantize.
A simple fix could be adding an if block to the overridden quantize that checks the Keras version, for example:
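
(A rough sketch of such a guard; the base class, helper, and minimum version below are placeholders for illustration, not the fix that landed in keras-team/keras-hub#1690:)

import keras


def _keras_at_least(required=(3, 4, 0)):
    # Placeholder check: the real threshold would be whichever release
    # ships the new quantization API.
    installed = tuple(
        int(p) for p in keras.version().split(".")[:3] if p.isdigit()
    )
    return installed >= required


class ReversibleEmbedding(keras.layers.Embedding):
    def quantize(self, mode, **kwargs):
        if not _keras_at_least():
            raise ValueError(
                "`quantize` requires a newer Keras release. Please upgrade Keras."
            )
        return super().quantize(mode, **kwargs)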

WDYT?

EDITED:
Missed that quantization_mode lives in build...
I will try to propose a fix for that in KerasNLP. However, we don't have a version-based CI. Should we add one or just test it locally?
For the compatibility fix of KerasNLP, you can check the details in this PR: keras-team/keras-hub#1690

@james77777778 (Contributor, Author) commented:
I have added a compatibility test to ensure that we don't break the API contract between Keras and KerasNLP.

The idea of the test is that I copied the implementation of ReversibleEmbedding and run the quantization-related tests against it.

@mattdangerw (Member) commented:

> Considering that quantization is a newly introduced feature, it should be acceptable to require users to update their Keras version to call quantize.

Yeah, I think that's fine as well. We just want to avoid breaking users that aren't using the feature at all, if we can do so without too much code complexity.

@mattdangerw (Member) left a comment:

Overall looks good to me if we remove the sleep call and maybe just ditch the compat test. But adding a sleep call really sounds like it will cause lots of pain down the line -- let's avoid it.

I'll leave to @fchollet for final review!

(Review thread on keras/src/layers/compatibility_test.py: outdated, resolved)
@google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels on Jul 19, 2024
@google-ml-butler bot removed the ready to pull label on Jul 19, 2024
@james77777778 (Contributor, Author) commented on Jul 19, 2024

Apart from the simplification, I have found that there is a high peak in GPU memory usage when calling quantize. The root cause might be keras.quantizers.abs_max_quantize.

The solution I propose is to use numpy ops for quantize. This should be feasible because we only call quantize in eager mode.
What do you think? @mattdangerw @fchollet

EDITED:
Please refer to #19954 (comment)

@mattdangerw (Member) commented:

@james77777778 Awesome, that does seem worth it. Is it noticeably slower, because of the extra conversions? Let's make sure to test JAX too. I think for JAX you need to specify XLA_PYTHON_CLIENT_PREALLOCATE=false and XLA_PYTHON_CLIENT_ALLOCATOR=platform to get it to stop just grabbing most of the available GPU memory.
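
(For reference, one way to set those flags from Python before anything initializes the GPU; the two XLA variables come from the comment above, and using KERAS_BACKEND to pick JAX is just part of the example:)

import os

# Must be set before JAX (or the JAX backend of Keras) is imported,
# otherwise the allocator settings are ignored.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["KERAS_BACKEND"] = "jax"

import keras  # noqa: E402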

@james77777778 (Contributor, Author) commented on Jul 20, 2024

Hey @mattdangerw

> Is it noticeably slower?

Actually, it is faster in the following tiny example. I haven't checked the time cost on LLMs, but my rough guess is that it should not be noticeably slower, because we call quantize in an eager manner, which doesn't benefit from XLA.

> Let's make sure to test JAX too.

I have verified that this PR reduces the peak GPU memory requirement in both JAX & TF.

(We don't need to specify XLA_PYTHON_CLIENT_PREALLOCATE=false and XLA_PYTHON_CLIENT_ALLOCATOR=platform for the following script)

import time

from keras import backend
from keras import layers
from keras import models


def get_memory_profile():
    if backend.backend() == "tensorflow":
        import tensorflow as tf

        peak = tf.config.experimental.get_memory_info("GPU:0")["peak"]
        print(f"Peak memory: {peak / 1024 / 1024:.2f} MB")
    elif backend.backend() == "jax":
        import jax

        peak = jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]
        print(f"Peak memory: {peak / 1024 / 1024:.2f} MB")
    else:
        raise NotImplementedError()


model = models.Sequential([layers.Input(shape=(1024,)), layers.Dense(1024)])
get_memory_profile()

st = time.time()
model.quantize("int8")
ed = time.time()
get_memory_profile()
print(f"`quantize`: {ed-st:.3f}s")

Branch  Backend     Peak Mem.            Time Cost
master  TensorFlow  12.00MB -> 23.95MB   0.170s
        JAX         12.00MB -> 16.01MB   0.399s
PR      TensorFlow  12.00MB -> 12.00MB   0.113s
        JAX         12.00MB -> 12.00MB   0.096s

Tested on AMD R7 7700 + RTX4070

@mattdangerw (Member) left a comment:

LGTM! Just one comment.

Great to know about jax.local_devices()[0].memory_stats()["peak_bytes_in_use"]

@@ -64,7 +66,24 @@ def abs_max_quantize(
     value_range=(-127, 127),
     dtype="int8",
     epsilon=backend.epsilon(),
+    to_numpy=False,
(Member) commented:

Why isn't to_numpy the default if the outputs are the same and it performs better?

@james77777778 (Contributor, Author) replied:
abs_max_quantize is used not only in quantize but also in AbsMaxQuantizer, which is called in _int8_call.
We cannot use numpy for that call.
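
(To make the distinction concrete, a rough sketch of how a to_numpy flag could branch between numpy and backend ops; this is an assumption for illustration, not the code merged in this PR:)

import numpy as np

from keras import backend
from keras import ops


def abs_max_quantize_sketch(
    inputs,
    axis,
    value_range=(-127, 127),
    dtype="int8",
    epsilon=backend.epsilon(),
    to_numpy=False,
):
    # Illustrative sketch: quantize by the per-axis abs-max scale.
    if to_numpy:
        # Eager-only path: numpy keeps intermediates off the accelerator,
        # avoiding the peak-memory spike observed during `quantize`.
        x = ops.convert_to_numpy(inputs)
        scale = value_range[1] / (np.abs(x).max(axis=axis, keepdims=True) + epsilon)
        q = np.clip(np.round(x * scale), *value_range).astype(dtype)
        return ops.convert_to_tensor(q), ops.convert_to_tensor(scale)
    # Traced path (e.g. inside `_int8_call`): must stay in backend ops.
    scale = ops.divide(
        value_range[1],
        ops.add(ops.max(ops.abs(inputs), axis=axis, keepdims=True), epsilon),
    )
    q = ops.clip(ops.round(ops.multiply(inputs, scale)), *value_range)
    return ops.cast(q, dtype), scale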

@google-ml-butler bot added the kokoro:force-run and ready to pull labels on Jul 22, 2024
@fchollet merged commit 3ac43b1 into keras-team:master on Jul 23, 2024. 11 checks passed.
@google-ml-butler bot removed the awaiting review and ready to pull labels on Jul 23, 2024
@james77777778 james77777778 deleted the improve-quantize-2 branch September 4, 2024 00:53