Add GPTQ support for Gemma #3200
Conversation
Thank you for the PR. Can you confirm the original model still works?
Sure! The test code is as follows:

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0, max_tokens=256)
llm = LLM(model="google/gemma-2b")
outputs = llm.generate(["def min(arr):\n"], sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(prompt + generated_text)
```

The output result is:

```python
def min(arr):
    min = arr[0]
    for i in range(1, len(arr)):
        if arr[i] < min:
            min = arr[i]
    return min

def max(arr):
    max = arr[0]
    for i in range(1, len(arr)):
        if arr[i] > max:
            max = arr[i]
    return max

def main():
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(min(arr))
    print(max(arr))

if __name__ == '__main__':
    main()
```
vllm/model_executor/models/gemma.py (Outdated)
```python
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
    loaded_weight += 1.0
if name == "lm_head.weight" and name not in params_dict:
```
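For anyone reading along, a minimal PyTorch sketch (an illustration only, not vLLM's actual kernel) of why shifting the weight by 1.0 at load time lets a plain RMSNorm reproduce Gemma's `(1 + weight)` scaling:

```python
import torch

def plain_rms_norm(x, weight, eps=1e-6):
    # The style of RMSNorm the runtime applies: normalize, then scale by `weight`.
    x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x_normed * weight

def gemma_rms_norm(x, weight, eps=1e-6):
    # Gemma's reference norm scales by (1 + weight) instead of `weight`.
    x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x_normed * (1.0 + weight)

hidden = torch.randn(2, 8)
checkpoint_weight = torch.randn(8)  # what the Gemma checkpoint stores

# Adding 1.0 once at load time makes the plain norm match Gemma's reference norm.
assert torch.allclose(gemma_rms_norm(hidden, checkpoint_weight),
                      plain_rms_norm(hidden, checkpoint_weight + 1.0))
```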
Quick question: why is this check needed?
Gemma's input and output weights are shared and there is no separate `lm_head` parameter. There is no problem when loading the original model, but when loading the quantized model, `lm_head` is not included in `params_dict`. I skip it so the model loads normally.
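To illustrate the point, here is a toy sketch (hypothetical tensor names, not the real Gemma checkpoint layout): with tied embeddings the model registers only the embedding weight, so an `lm_head.weight` entry in a checkpoint has no matching parameter and has to be skipped.

```python
import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        # No lm_head module: logits reuse the embedding matrix.

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states @ self.embed_tokens.weight.t()

model = TinyTiedLM()
params_dict = dict(model.named_parameters())

# A hypothetical checkpoint that still carries an lm_head entry.
checkpoint = {
    "embed_tokens.weight": torch.randn(16, 8),
    "lm_head.weight": torch.randn(16, 8),
}

for name, loaded_weight in checkpoint.items():
    if name == "lm_head.weight" and name not in params_dict:
        continue  # nothing to load it into; the embedding doubles as the output head
    params_dict[name].data.copy_(loaded_weight)
```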
But the GPTQ output is still correct when this check is removed. 🤔
After reinstalling from the latest main branches of transformers and AutoGPTQ, I found that removing them works fine. Thanks for testing!
```python
    param = params_dict[name]
    weight_loader = param.weight_loader
    weight_loader(param, loaded_weight, shard_id)
    break
else:
    # Skip loading extra bias for GPTQ models.
    if name.endswith(".bias") and name not in params_dict:
```
Could you merge this check with the one above and place it outside the for loop?
I copied them from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L377-L379 and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L385-L387. Modifying the `for`-`else` structure may require excessive changes; I recommend reusing the existing code.
@TechxGenus I think we can do this optimization in the Gemma model file. Just put this check before `for (param_name, shard_name, shard_id) in stacked_params_mapping:` without changing the `for`-`else` code.
> @TechxGenus I think we can do this optimization in the Gemma model file. Just put this check before `for (param_name, shard_name, shard_id) in stacked_params_mapping:` without changing the `for`-`else` code.

The quantized versions of gemma-2b and gemma-7b work fine, but it will cause problems for future models with `attention_bias` set to `true`.
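For readers following the thread, here is a rough sketch of the restructure under discussion, using the names from the llama.py loader linked above (an illustration of the trade-off, not the final Gemma code):

```python
from typing import Dict, Iterable, List, Tuple

import torch


def load_weights_sketch(
    params_dict: Dict[str, torch.nn.Parameter],
    stacked_params_mapping: List[Tuple[str, str, int]],
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> None:
    """Loader loop with the .bias check hoisted above the stacked-params loop."""
    for name, loaded_weight in weights:
        # Hoisted check, applied once per checkpoint tensor. The caveat raised
        # above: `name` here is still the checkpoint name (e.g. "...q_proj.bias"),
        # which never appears in params_dict because only the fused
        # "...qkv_proj.bias" is registered, so a future Gemma variant with
        # attention_bias=True would have its real biases silently skipped.
        if name.endswith(".bias") and name not in params_dict:
            continue
        for (param_name, shard_name, shard_id) in stacked_params_mapping:
            if shard_name not in name:
                continue
            name = name.replace(shard_name, param_name)
            param = params_dict[name]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            param.weight_loader(param, loaded_weight)
```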
vllm/model_executor/models/gemma.py (Outdated)
```python
if name == "lm_head.weight" and name not in params_dict:
    continue
```
QQ: Why is this needed?
Gemma's input and output weights are shared and there is no separate `lm_head` parameter. There is no problem when loading the original model, but when loading the quantized model, `lm_head` is not included in `params_dict`. I skip it so the model loads normally.
After reinstalling from the latest main branches of transformers and AutoGPTQ, I found that removing them works fine. Thanks for testing!
Thanks for your contribution!
The latest version of AutoGPTQ already supports Gemma. A few minor modifications add support in vLLM.
My test code is as follows:
The output result is:
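As an illustration of that workflow, here is a hedged sketch (placeholder model path, output directory, and calibration text; not necessarily the exact script used in this PR): quantize Gemma with AutoGPTQ, then load the result in vLLM with `quantization="gptq"`.

```python
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_id = "google/gemma-2b"          # placeholder; any Gemma checkpoint
quantized_dir = "gemma-2b-gptq-4bit"  # placeholder output directory

# 1. Quantize with AutoGPTQ (a recent main branch, per the discussion above).
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)

# A real run would use a proper calibration set; one sample keeps the sketch short.
examples = [tokenizer("vLLM is a fast inference engine for LLMs.", return_tensors="pt")]
model.quantize(examples)
model.save_quantized(quantized_dir)
tokenizer.save_pretrained(quantized_dir)

# 2. Serve the quantized checkpoint with vLLM.
llm = LLM(model=quantized_dir, quantization="gptq")
sampling_params = SamplingParams(temperature=0, max_tokens=256)
print(llm.generate(["def min(arr):\n"], sampling_params)[0].outputs[0].text)
```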