[Draft] Avoid loading model weights before recipe application if any #2230
base: main
Conversation
""" | ||
Takes a loaded Pytorch model and applies any structural changes such as quantization | ||
to the model, then reloads the model. | ||
|
||
:param model: PyTorch model to apply structure to | ||
:param recipe_path: path to recipe to apply to the model | ||
:param model_path: path to model, used for reloading the state dict | ||
:param reload_weights: flag to reload the weights after applying the recipe. | ||
Dafault is True. |
Suggested change:

```diff
- Dafault is True.
+ Default is True.
```
Looking good!
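For context, a hedged usage sketch of the helper this docstring documents. The function name `apply_recipe_structure_to_model` and its import path are assumptions inferred from the docstring's parameters, and all paths are placeholders; this is an illustration, not the PR's own test code.

```python
# Assumed import path; the helper name is inferred from the docstring above.
from sparseml.pytorch.model_load.helpers import apply_recipe_structure_to_model
from sparseml.transformers import SparseAutoModelForCausalLM

model = SparseAutoModelForCausalLM.from_pretrained("path/to/model")  # placeholder

# Apply the recipe's structural changes (e.g. quantization wrappers), then
# reload the state dict so the trained weights land in the restructured modules.
apply_recipe_structure_to_model(
    model=model,
    recipe_path="path/to/recipe.yaml",  # placeholder
    model_path="path/to/model",         # placeholder
    reload_weights=True,  # flag added in this PR; default is True
)
```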
```python
@@ -130,12 +135,27 @@ def skip(*args, **kwargs):
    compressor.overwrite_weights(model_path=model_path, model=model)

recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)

# this must be done before recipe is applied
```
Curious, why? How does this modify the state of the model?

Also @rahul-tuli, the correct implementation of this PR should make this part redundant!
Previously, when `SparseAutoModelForCausalLM.from_pretrained(...)` was called, the weights were loaded in twice: once during `model = super(AutoModelForCausalLM, cls).from_pretrained(...)` and then again after recipe application, which is undesirable.

This PR updates the flow to use `from_config(...)` over `from_pretrained(...)`, which initializes a model with freshly initialized weight data; after recipe application, the actual trained weights are loaded back in. More info on `from_config`: https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#transformers.AutoModel.from_config
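As a rough sketch of the `from_config(...)` flow described above, using standard `transformers` APIs; the model path is a placeholder, and the state-dict reload is simplified relative to the PR's actual helpers:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton from its config only; no checkpoint weights are
# read at this point, just freshly initialized parameters.
config = AutoConfig.from_pretrained("path/to/model")  # placeholder path
model = AutoModelForCausalLM.from_config(config)

# ... apply the recipe's structural changes (e.g. quantization) here ...

# Only now load the trained weights into the restructured model, so the
# checkpoint is read once instead of twice.
state_dict = torch.load("path/to/model/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
```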
The initial effort was to accomplish this with `accelerate.init_empty_weights`, but we ran into https://discuss.huggingface.co/t/error-the-model-weights-are-not-tied-please-use-the-tie-weights-method-before-using-the-infer-auto-device-function-even-after-adding-model-tie-weights/46325 with quantized models.
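For reference, the abandoned approach looked roughly like the sketch below (standard `accelerate`/`transformers` APIs; the path is a placeholder):

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("path/to/model")  # placeholder path

# Materialize the model on the "meta" device so no real weight memory is
# allocated before recipe application.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Loading real weights into this meta model later is where quantized
# checkpoints hit the "model weights are not tied" error linked above.
```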
Tests: Tested loading dense, sparse, and quantized checkpoints, which all load just fine.

Test script:
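The original test script was not captured in this page; below is a minimal sketch of what it likely exercised. The checkpoint paths are placeholders, and `SparseAutoModelForCausalLM` is assumed importable from `sparseml.transformers` as in this repository.

```python
from sparseml.transformers import SparseAutoModelForCausalLM

# Placeholder checkpoint paths for the dense, sparse, and quantized cases.
checkpoints = [
    "path/to/dense-checkpoint",
    "path/to/sparse-checkpoint",
    "path/to/quantized-checkpoint",
]

for path in checkpoints:
    # With this PR, weights are loaded only once, after recipe application.
    model = SparseAutoModelForCausalLM.from_pretrained(path)
    print(f"Loaded {path}: {type(model).__name__}")
```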