
Add GPTQ support for Gemma #3200

Merged · 2 commits · Mar 7, 2024

Conversation

TechxGenus (Contributor)

The latest version of AutoGPTQ already supports Gemma. Some minor modifications are enough to add support in vLLM.

My test code is as follows:

from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0, max_tokens=256)
llm = LLM(model="TechxGenus/gemma-2b-GPTQ")

outputs = llm.generate(["def min(arr):\n"], sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(prompt + generated_text)

The output result is:

def min(arr):
    min = arr[0]
    for i in range(1, len(arr)):
        if arr[i] < min:
            min = arr[i]
    return min


def max(arr):
    max = arr[0]
    for i in range(1, len(arr)):
        if arr[i] > max:
            max = arr[i]
    return max


def main():
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(min(arr))
    print(max(arr))


if __name__ == "__main__":
    main()
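
For reference, a GPTQ checkpoint like the one loaded above can be produced with AutoGPTQ roughly as follows. This is a minimal sketch using AutoGPTQ's documented quantization flow; the calibration data, settings, and output directory are illustrative placeholders, not necessarily what was used for TechxGenus/gemma-2b-GPTQ.

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

model_id = "google/gemma-2b"  # base model; output directory below is a placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)

# A single calibration sample keeps the sketch short; real GPTQ calibration
# should use a few hundred samples from a representative corpus.
examples = [tokenizer("def min(arr):\n    return sorted(arr)[0]\n")]

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)
model.quantize(examples)
model.save_quantized("gemma-2b-GPTQ")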

simon-mo (Collaborator) commented Mar 5, 2024

Thank you for the PR. Can you confirm the original model still works?

TechxGenus (Contributor, Author)

Sure! The test code is as follows:

from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0, max_tokens=256)
llm = LLM(model="google/gemma-2b")

outputs = llm.generate(["def min(arr):\n"], sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(prompt + generated_text)

The output result is:

def min(arr):
    min = arr[0]
    for i in range(1, len(arr)):
        if arr[i] < min:
            min = arr[i]
    return min

def max(arr):
    max = arr[0]
    for i in range(1, len(arr)):
        if arr[i] > max:
            max = arr[i]
    return max

def main():
    arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(min(arr))
    print(max(arr))

if __name__ == '__main__':
    main()

Review comment on the Gemma weight-loading code:

# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
    loaded_weight += 1.0
if name == "lm_head.weight" and name not in params_dict:
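
For context on the first two lines of the hunk above, a minimal sketch of the two norm conventions (illustrative only, not vLLM's actual kernel):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Llama-style RMSNorm: scale the normalized activations by `weight`.
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return normed * weight

# GemmaRMSNorm computes normed * (1 + weight) instead, so adding 1.0 to the
# checkpoint weight at load time lets the Llama-style kernel be reused:
#   rms_norm(x, gemma_weight + 1.0)  ==  normed * (1 + gemma_weight)
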
Collaborator

Quick question: why is this check needed?

TechxGenus (Contributor, Author)

Gemma's input and output weights are shared, so there is no separate lm_head parameter.
There is no problem when loading the original model, but when loading the quantized model, lm_head is not included in params_dict, so I skip it to let the model load normally.
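
A simplified sketch of that situation (illustrative, not the exact vLLM loader): the model registers only the tied embedding, so an lm_head.weight entry in a quantized checkpoint has no matching parameter and is skipped.

# Simplified sketch of loading a checkpoint into a model with tied embeddings.
def load_weights(model, weights):
    params_dict = dict(model.named_parameters())
    # params_dict has "model.embed_tokens.weight" but no "lm_head.weight",
    # because the output projection reuses the embedding matrix.
    for name, loaded_weight in weights:
        if name == "lm_head.weight" and name not in params_dict:
            continue  # present in some quantized checkpoints, tied here
        param = params_dict[name]
        param.data.copy_(loaded_weight)  # stand-in for vLLM's weight_loader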

Collaborator

But the GPTQ output is still correct when this check is removed. 🤔

TechxGenus (Contributor, Author)

After reinstalling from the latest main branches of transformers and AutoGPTQ, I found that removing these checks works fine. Thanks for testing!

Review comment on the for-else over stacked_params_mapping in the weight-loading loop:

        param = params_dict[name]
        weight_loader = param.weight_loader
        weight_loader(param, loaded_weight, shard_id)
        break
    else:
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:

Collaborator

Could you merge this check with the one above and place it outside the for loop?

TechxGenus (Contributor, Author)

I copied them from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L377-L379 and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L385-L387.
Changing the for-else structure would require more extensive modifications; I would rather reuse the existing code.

Collaborator

@TechxGenus I think we can do this optimization in the Gemma model file: just move this check to just above the for (param_name, shard_name, shard_id) in stacked_params_mapping: loop, without changing the for-else code.

TechxGenus (Contributor, Author)

> @TechxGenus I think we can do this optimization in the Gemma model file: just move this check to just above the for (param_name, shard_name, shard_id) in stacked_params_mapping: loop, without changing the for-else code.

The quantized versions of gemma-2b and gemma-7b work fine with that change, but it would cause problems for future models with attention_bias set to true.
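
A hypothetical example of the concern (not code from this PR): with attention_bias set to true, the per-projection bias names only exist in params_dict under the fused qkv name, so a bias check hoisted above the stacked_params_mapping loop would skip real biases.

# Hypothetical names for a model with attention_bias=True.
name = "model.layers.0.self_attn.q_proj.bias"                        # checkpoint entry
params_dict = {"model.layers.0.self_attn.qkv_proj.bias": object()}  # fused param only

# Hoisted check (run before the stacked_params_mapping loop):
print(name.endswith(".bias") and name not in params_dict)  # True -> real bias skipped

# Inside the for-else, the mapping rewrites q_proj -> qkv_proj first, so the
# bias reaches the fused parameter and only genuinely extra GPTQ biases are skipped.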

Comment on lines 343 to 344:

if name == "lm_head.weight" and name not in params_dict:
    continue

Collaborator

QQ: Why is this needed?

TechxGenus (Contributor, Author)

Gemma's input and output weights are shared, so there is no separate lm_head parameter.
There is no problem when loading the original model, but when loading the quantized model, lm_head is not included in params_dict, so I skip it to let the model load normally.

TechxGenus (Contributor, Author)

After reinstalling from the latest main branches of transformers and AutoGPTQ, I found that removing these checks works fine. Thanks for testing!

esmeetu (Collaborator) left a comment

Thanks for your contribution!

esmeetu merged commit d3c04b6 into vllm-project:main on Mar 7, 2024. 22 checks passed.