Add exllama GPTQ CUDA kernel support #553

Closed
wants to merge 27 commits

Conversation

fxmarty (Contributor) commented Jul 5, 2023

Examples:

text-generation-launcher --model-id Narsil/starcoder-gptq --port 8080 --num-shard 2 --quantize gptq
GPTQ_GROUPSIZE=128 GPTQ_BITS=4 text-generation-launcher --model-id TheBloke/WizardLM-7B-uncensored-GPTQ --port 8080 --num-shard 2 --quantize gptq
pytest integration-tests/models/test_flash_llama_gptq.py -s
pytest integration-tests/models/test_flash_starcoder_gptq.py -s
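To smoke-test one of these servers once it is up (assuming it listens on localhost:8080 as in the commands above; the prompt and max_new_tokens value are arbitrary), something along these lines works against TGI's /generate route:

```python
import requests

# Minimal sanity check against a running text-generation-inference server.
response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "def fibonacci(n):",
        "parameters": {"max_new_tokens": 64},
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["generated_text"])
```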

This PR adds to TGI the mixed-precision int4/fp16 kernels from the excellent exllama repo, which in my benchmarks are much faster than the implementations available in autogptq & gptq-for-llama.

At batch size 1 on starcoder (GPTQ 4-bit, no act-order), we get a 2.1x speedup on prefill over GPTQ-triton and a 1.8x speedup on decode over GPTQ-triton. I'll have a look at the peak memory.

I verified locally that the logits match.

Note that the exllama implementation cannot be used with act-order when TP rank >= 2 for row tensor-parallel linear layers, because exllama reorders the weights ahead of runtime, which requires reordering the activations as well (and those are split across several GPUs for row parallel with TP rank >= 2). In this specific case we default to the triton implementation (which is much slower, because the reordering is done on the scales/zero points and each weight row needs its own specific scale/zero point).

The exllama implementation only supports n_bits = 4, so for other bit widths we fall back to the triton kernel.
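For clarity, the kernel selection described above roughly boils down to the following sketch (the function and argument names here are illustrative placeholders, not the exact helpers introduced in this PR):

```python
def select_gptq_kernel(bits: int, act_order: bool, row_parallel: bool, world_size: int) -> str:
    """Illustrative dispatch between the exllama CUDA kernel and the triton kernel.

    exllama only handles 4-bit weights, and its ahead-of-time weight reordering
    clashes with act-order when the activations are sharded across GPUs
    (row-parallel linear with TP >= 2), so those cases fall back to triton.
    """
    if bits != 4:
        return "triton"
    if act_order and row_parallel and world_size >= 2:
        return "triton"
    return "exllama"
```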

Results on starcoder are as follows (TP rank = 2, A100, before vllm):

| Parameter | Value |
|---|---|
| Model | Narsil/starcoder-gptq |
| Sequence Length | 512 |
| Decode Length | 512 |
| N Runs | 10 |
| Warmups | 1 |
| Temperature | None |
| Top K | None |
| Top P | None |
| Typical P | None |
| Repetition Penalty | None |
| Watermark | false |
| Do Sample | false |
GPTQ (current):
| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Prefill | 1 | 139.00 ms | 138.31 ms | 139.93 ms | 139.06 ms | 139.93 ms | 139.93 ms |
| Prefill | 4 | 522.17 ms | 521.71 ms | 522.57 ms | 522.16 ms | 522.57 ms | 522.57 ms |
| Prefill | 8 | 1016.23 ms | 1015.78 ms | 1016.87 ms | 1016.10 ms | 1016.87 ms | 1016.87 ms |
| Decode (token) | 1 | 36.19 ms | 36.07 ms | 36.54 ms | 36.11 ms | 36.54 ms | 36.54 ms |
| Decode (token) | 4 | 36.65 ms | 36.57 ms | 36.74 ms | 36.64 ms | 36.74 ms | 36.74 ms |
| Decode (token) | 8 | 36.72 ms | 36.55 ms | 36.94 ms | 36.72 ms | 36.94 ms | 36.94 ms |
| Decode (total) | 1 | 18491.31 ms | 18433.56 ms | 18671.24 ms | 18451.02 ms | 18671.24 ms | 18671.24 ms |
| Decode (total) | 4 | 18728.41 ms | 18689.53 ms | 18776.16 ms | 18724.24 ms | 18776.16 ms | 18776.16 ms |
| Decode (total) | 8 | 18762.84 ms | 18678.29 ms | 18875.27 ms | 18763.06 ms | 18875.27 ms | 18875.27 ms |

| Step | Batch Size | Average | Lowest | Highest |
|---|---|---|---|---|
| Prefill | 1 | 7.19 tokens/sec | 7.15 tokens/sec | 7.23 tokens/sec |
| Prefill | 4 | 7.66 tokens/sec | 7.65 tokens/sec | 7.67 tokens/sec |
| Prefill | 8 | 7.87 tokens/sec | 7.87 tokens/sec | 7.88 tokens/sec |
| Decode | 1 | 27.64 tokens/sec | 27.37 tokens/sec | 27.72 tokens/sec |
| Decode | 4 | 109.14 tokens/sec | 108.86 tokens/sec | 109.37 tokens/sec |
| Decode | 8 | 217.88 tokens/sec | 216.58 tokens/sec | 218.86 tokens/sec |
GPTQ-CUDA (exllama):
| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Prefill | 1 | 65.49 ms | 64.76 ms | 65.81 ms | 65.58 ms | 65.81 ms | 65.81 ms |
| Prefill | 4 | 190.36 ms | 189.25 ms | 194.22 ms | 190.25 ms | 194.22 ms | 194.22 ms |
| Prefill | 8 | 350.45 ms | 349.56 ms | 353.83 ms | 350.07 ms | 353.83 ms | 353.83 ms |
| Decode (token) | 1 | 19.69 ms | 18.90 ms | 21.20 ms | 18.99 ms | 21.20 ms | 21.20 ms |
| Decode (token) | 4 | 30.51 ms | 30.42 ms | 30.64 ms | 30.49 ms | 30.58 ms | 30.58 ms |
| Decode (token) | 8 | 34.73 ms | 34.56 ms | 34.80 ms | 34.76 ms | 34.80 ms | 34.80 ms |
| Decode (total) | 1 | 10061.35 ms | 9659.25 ms | 10835.27 ms | 9705.76 ms | 10835.27 ms | 10835.27 ms |
| Decode (total) | 4 | 15592.28 ms | 15547.03 ms | 15659.89 ms | 15582.29 ms | 15626.84 ms | 15626.84 ms |
| Decode (total) | 8 | 17749.09 ms | 17661.60 ms | 17781.76 ms | 17760.34 ms | 17781.76 ms | 17781.76 ms |

| Step | Batch Size | Average | Lowest | Highest |
|---|---|---|---|---|
| Prefill | 1 | 15.27 tokens/sec | 15.20 tokens/sec | 15.44 tokens/sec |
| Prefill | 4 | 21.01 tokens/sec | 20.59 tokens/sec | 21.14 tokens/sec |
| Prefill | 8 | 22.83 tokens/sec | 22.61 tokens/sec | 22.89 tokens/sec |
| Decode | 1 | 50.91 tokens/sec | 47.16 tokens/sec | 52.90 tokens/sec |
| Decode | 4 | 131.09 tokens/sec | 130.52 tokens/sec | 131.47 tokens/sec |
| Decode | 8 | 230.32 tokens/sec | 229.90 tokens/sec | 231.46 tokens/sec |

Before submitting

  • Doc ==> There's no doc currently, right?
  • Tests ==> Done for starcoder and llama

@fxmarty fxmarty requested a review from Narsil July 5, 2023 16:37
Narsil (Collaborator) left a comment

Neat numbers!

I feel like having both gptq and gptq-cuda is not necessary here.

IIUC, both can run on the same weights (as you didn't change the conversion script).
Therefore, we could simply use exllama kernel whenever available (when g_idx is increasing).

That should simplify the codebase a lot.

Also, nothing should be modified in the model files; everything should be agnostic to this, especially since the weights are exactly the same on disk.

This could also be explained in the gptq script (act-order: if True, more precise models but slower inference because of the different kernels; if False, lower precision but faster inference).

I still fail to understand why we cannot reorder on load to use exllama for act-order (since we can reslice at will in the original tensors, we could probably de-entangle g_idx again).

It's a lot more work certainly.
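For reference, a toy illustration (dense weights with made-up shapes, not the packed GPTQ layout) of what reordering on load / de-entangling g_idx would mean, and why the activations must follow the same permutation:

```python
import torch

in_features, out_features, groupsize = 16, 8, 4
W = torch.randn(in_features, out_features)
# Toy act-order group assignment: each input channel mapped to a group in a
# scrambled order, as desc_act quantization would produce.
g_idx = torch.randperm(in_features) // groupsize

perm = torch.argsort(g_idx)   # permutation that makes each group contiguous again
W_reordered = W[perm]         # reorder the weight's input channels once, on load

x = torch.randn(2, in_features)
# The matmul only stays correct if the activation channels are permuted too,
# which is the part that clashes with row-parallel TP (x is split across GPUs).
assert torch.allclose(x @ W, x[:, perm] @ W_reordered, atol=1e-5)
```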

Comment on lines 481 to 504
# Buffers need to be persistent to avoid any bug.
self.buffers = {}
if config.quantize == "gptq-cuda":
    max_dq_buffer_size = 0
    for name, submodule in self.named_modules():
        if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
            max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)

    intermediate_size = config.n_inner
    max_seq_len = 2048  # TODO: we should be able to set it

    self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=weights.device)
    self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=weights.device)

    prepare_buffers(weights.device, self.buffers["temp_state"], self.buffers["temp_dq"])

    # TODO: ability to set them
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    torch.cuda.empty_cache()

Collaborator:

I think this should go directly in the loading part (within weights).
That way it's truly agnostic to models.

fxmarty (Contributor, Author):

I moved it to the Model init. This requires having access to model.config, which is currently not defined, though. There is model.transformer.config, or model.gpt_neox.config, or model.model.config depending on the architecture. Is it intended that the config is not registered at the top level? @OlivierDehaene @Narsil

The thing is that the weights = Weights(...) call is in each model definition, and we need to have loaded all the weights to determine the shapes of the buffers. Also, the buffers need to be persistent, while I think this weights object is not.

Collaborator:

The buffers are intended to be shared, no?

So why not just have a single location for these buffers, use the pointer in every layer, and increase the size every time max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8) is larger?

The issue with this (and any post-loading treatment) is that you're now dealing with updating every single model file any time one of those lines is hit. This is what we had before and it was painful to maintain.

These seem to be used as globals, so let's just use them as globals. (They are temporary buffers, IIUC, preallocated to avoid reallocating them all the time.)
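A rough sketch of that globals approach (module-level scratch buffers grown lazily; the names are illustrative, and prepare_buffers is assumed to be the same exllama binding called in the diff above):

```python
import torch

from exllama_kernels import prepare_buffers  # assumed module name for the binding used in the diff

# Shared scratch buffers for every Ex4bitLinear layer, kept at module level so
# no per-model post-loading pass is needed; they only ever grow.
_TEMP_STATE = None
_TEMP_DQ = None


def ensure_exllama_buffers(device, max_seq_len, intermediate_size, max_dq_buffer_size):
    """Allocate (or grow) the shared temp buffers and hand them to the extension."""
    global _TEMP_STATE, _TEMP_DQ
    if _TEMP_STATE is None or _TEMP_STATE.numel() < max_seq_len * intermediate_size:
        _TEMP_STATE = torch.zeros(
            (max_seq_len, intermediate_size), dtype=torch.float16, device=device
        )
    if _TEMP_DQ is None or _TEMP_DQ.numel() < max_dq_buffer_size:
        _TEMP_DQ = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
    prepare_buffers(device, _TEMP_STATE, _TEMP_DQ)
```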

server/text_generation_server/utils/layers.py (outdated review thread, resolved)
server/text_generation_server/utils/weights.py (outdated review thread, resolved)
server/text_generation_server/utils/weights.py (outdated review thread, resolved)
@fxmarty fxmarty requested a review from Narsil July 12, 2023 18:44
fxmarty (Contributor, Author) commented Jul 12, 2023

For some reason the logprobs differ slightly between runs; there's a source of randomness I haven't identified yet.

Edit: it comes from the atomicAdd in the kernel - this is fine.

I'll add llama support in this PR too.

Narsil (Collaborator) commented Jul 14, 2023

atomicAdd + randomness

This is very suspicious, really? Isn't the purpose of atomicAdd to remove randomness by forcing access order? :)
It may well be acceptable though :)

fxmarty (Contributor, Author) commented Jul 17, 2023

Narsil (Collaborator) commented Jul 18, 2023

Ahhh, that level of randomness! :) I see, yeah, totally legit source of "randomness".
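(Underlying cause, for the record: atomicAdd guarantees each update is applied atomically, not in which order, and floating-point addition is not associative, so the accumulation order changes the low bits from run to run. A minimal illustration in plain Python:)

```python
# Floating-point addition is not associative: the grouping (i.e. accumulation
# order) changes the result, which is exactly what a non-deterministic
# atomicAdd ordering exposes.
a, b, c = 1e20, -1e20, 3.14
print((a + b) + c)  # 3.14
print(a + (b + c))  # 0.0
```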

fxmarty (Contributor, Author) commented Jul 19, 2023

Some tests, up to date with main, on Llama 2 70B:

[benchmark screenshot]

Narsil previously approved these changes Jul 20, 2023
Narsil added a commit that referenced this pull request Jul 21, 2023
Just trying to get the integration tests to pass.


Narsil (Collaborator) commented Jul 21, 2023

Closing as superseded by #666.
