
[Kernel] Initial Activation Quantization Support #4525

Merged — 49 commits merged into vllm-project:main on May 23, 2024

Conversation

Contributor

@dsikka dsikka commented May 1, 2024

Summary

  • Initial support for Activation Quantization (specifically static per-tensor for W8A8); a minimal reference sketch of this scheme follows the summary below
  • Adds CompressedTensorsConfig and CompressedTensorsLinearMethod to support models quantized through sparseml and saved through compressed-tensors
  • Adds a new optional layer_name parameter to create_weights. The layer_name can be used to match the appropriate quantization scheme from the CompressedTensorsConfig for a given layer
  • Adds a static per-tensor quantization kernel (inspired by and refactored from #1508, "Support W8A8 inference in vllm")
  • Uses the nvidia-cutlass Python interface to invoke a fused GEMM + dequant kernel.

From Neural Magic, Co-authored by @varun-sundar-rabindranath @robertgshaw2-neuralmagic
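
For readers unfamiliar with the scheme, here is a minimal PyTorch reference of what static per-tensor W8A8 computes. It is only an illustration: the function names are made up for this sketch, and the PR implements the quantization step as a CUDA kernel and fuses the GEMM + dequant through CUTLASS.

```python
import torch

def static_scaled_int8_quant_ref(x: torch.Tensor, act_scale: float) -> torch.Tensor:
    # Static per-tensor quantization: a single precomputed scale for the whole tensor.
    return torch.clamp(torch.round(x / act_scale), -128, 127).to(torch.int8)

def w8a8_linear_ref(x: torch.Tensor, w_int8: torch.Tensor,
                    act_scale: float, weight_scale: float) -> torch.Tensor:
    x_int8 = static_scaled_int8_quant_ref(x, act_scale)
    # int8 x int8 accumulated in int32, then dequantized by the product of scales;
    # the PR fuses this matmul + dequant into a single CUTLASS kernel.
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
    return acc.to(x.dtype) * (act_scale * weight_scale)
```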

dsikka and others added 8 commits April 30, 2024 18:50
…for static W8A8 per tensor (#195)

- Depending on how we end up parsing `ignore` and `targets` (layer_name
vs layer_type), we may not need layer_name to be added to the
linear_method. We will experiment with using a compressed-tensors function in a
follow-up PR

- Initial implementation for Compressed Config support + Activation
Quantization for static per tensor w8a8
- Includes fused kernels added by @varun-sundar-rabindranath

```python
from vllm import LLM, SamplingParams
import torch

prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is"
]
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

# quantization="sparseml" selects the new CompressedTensorsLinearMethod path
llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test",
          enforce_eager=True,
          dtype=torch.float32,
          quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- Verification of the different inputs expected for `targets` and
`ignore` --> use functions to parse the layer names, which can be shared
by sparseml and vllm and would live in compressed-tensors
(https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86)
- Updates to further optimize fake quant

---------

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@dsikka dsikka marked this pull request as ready for review May 1, 2024 14:18
dsikka and others added 6 commits May 1, 2024 14:20
vllm CI fixes

---------

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
lazy cutlass_gemm_dq import

---------

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Collaborator

@comaniac comaniac left a comment


IMHO the layer_name approach is simple and effective, but it also creates more complexity for model implementers. Ideally we should match the scheme automatically after the model is initialized (but before weight loading). In that case we need to keep all parameters as meta tensors (placeholders) until the weights are actually loaded. That way we can change the data type and don't have to worry about memory footprint.
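
A minimal sketch of the meta-tensor idea described here (not something this PR implements; the layer size and dtype below are only illustrative):

```python
import torch
import torch.nn as nn

# Parameters created on the "meta" device are placeholders with shape/dtype but no storage.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096, bias=False)

# Once the quantization scheme for this layer is known (e.g. W8A8), materialize the
# parameter with the dtype the scheme requires and copy the checkpoint weight into it.
real_weight = torch.empty_like(layer.weight, device="cuda", dtype=torch.int8)
layer.weight = nn.Parameter(real_weight, requires_grad=False)
```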

csrc/pybind.cpp (outdated review thread, resolved)
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
Collaborator


The name compressed_tensors is a bit unclear to me. I suppose this is the official method name from SparseML? If so, can we just use sparseml here?

Collaborator


compressed-tensors is the name of the package responsible for saving quantized and sparse models

So the flow is:

  • use SparseML to apply quantization / sparsity
  • save model to safetensors with a compressed-tensors config
  • load + run in vllm

@@ -403,6 +440,13 @@ def weight_loader(self,
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)

# If a param_shard_splitter is defined by the LinearMethod, use it.
Collaborator


This does the same thing as scale_shard_splitter we had for fp8 ... we can rename to match fp8

but yes this will be addressed by the refactor
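
For readers unfamiliar with the pattern being discussed, here is a hypothetical sketch of what a shard splitter does during weight loading. The name, signature, and bookkeeping below are illustrative only, not the PR's actual interface:

```python
import torch

def scale_shard_splitter(param_data: torch.Tensor,
                         loaded_weight: torch.Tensor,
                         shard_id: int,
                         logical_widths: list[int]):
    # Locate the slice of the merged parameter (e.g. a fused QKV scale tensor)
    # that this logical shard's checkpoint value should fill.
    offset = sum(logical_widths[:shard_id])
    size = logical_widths[shard_id]
    return param_data.narrow(0, offset, size), loaded_weight
```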

vllm/worker/model_runner.py (outdated review thread, resolved)

@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented May 2, 2024

> IMHO the layer_name approach is simple and effective, but it also creates more complexity for model implementers. Ideally we should match the scheme automatically after the model is initialized (but before weight loading). In that case we need to keep all parameters as meta tensors (placeholders) until the weights are actually loaded. That way we can change the data type and don't have to worry about memory footprint.

Per our slack discussion:

Plan is to refactor weight_loading logic generically (separate from this PR) with a flow that looks like this:

```python
model = init_model(...)  # parameters are in meta tensors
for key, val in scheme:
    mod = find_module_by_name(model, key)
    config_module(mod, val)
...
weight_loading(model, ckpt)
```

This is similar to how we do things in SparseML / HF. This would also fix the current lack of memory savings for fp8, since parameters could be created directly in their quantized dtype instead of being allocated at full precision first.

Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic left a comment


test

@Yard1 Yard1 reopened this May 2, 2024
@bnellnm
Contributor

bnellnm commented May 14, 2024

@dsikka can you add some tests for the new functionality? Can any of the tests from #1508 be reused/adapted?
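
For context, a sketch of the kind of test being asked for: compare the new static per-tensor int8 quant kernel against a plain PyTorch reference. The ops import path and binding name below are assumptions for illustration, not necessarily the names this PR adds.

```python
import pytest
import torch

def ref_static_int8_quant(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Plain PyTorch reference for static per-tensor int8 quantization.
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

@pytest.mark.parametrize("hidden_size", [1024, 4096])
def test_static_scaled_int8_quant(hidden_size):
    x = torch.randn(16, hidden_size, device="cuda", dtype=torch.float32)
    scale = 0.02
    out = torch.empty_like(x, dtype=torch.int8)

    # Hypothetical binding; the actual custom-op entry point may differ.
    from vllm._C import ops
    ops.static_scaled_int8_quant(out, x, scale)

    # Allow off-by-one on values that land exactly on a rounding boundary.
    torch.testing.assert_close(out, ref_static_int8_quant(x, scale), atol=1, rtol=0)
```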

import torch
from torch.nn import Parameter

# TODO (varun) : Unify ops and custom ops
Collaborator


This should not be left as a TODO; it should be done before the PR is merged -- it is a very small amount of work

Contributor


Yes, definitely. This slipped past my radar. It's fixed now. Thanks for catching it.

void static_scaled_int8_quant(torch::Tensor& out,   // [..., hidden_size]
                              torch::Tensor& input, // [..., hidden_size]
                              float scale) {
  assert(input.is_contiguous());
Collaborator


Both of these asserts should be TORCH_CHECKs so the interpreter doesn't crash if this gets triggered
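
A sketch of what the suggested change looks like (the kernel launch is elided; only the checks change):

```cpp
// TORCH_CHECK raises a catchable c10::Error with a message instead of aborting
// the whole process the way a failed assert() would.
void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
                              torch::Tensor& input,  // [..., hidden_size]
                              float scale) {
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  // ... launch the quantization kernel as before ...
}
```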

Contributor


Done 👍

Varun Sundar Rabindranath added 2 commits May 22, 2024 20:21
static constexpr float dt_max =
    static_cast<float>(std::numeric_limits<int8_t>::max());
// round
float dst = round(x);
Contributor


Note: this rounding doesn't match the rounding method used by compressed-tensors / Torch / NumPy (round-half-to-even). To fix this I have a patch at neuralmagic#263; it will be merged before this lands.
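
For illustration, the difference in question (this is not the actual patch from neuralmagic#263): C's round() breaks ties away from zero, while Torch/NumPy break ties to even, which nearbyint() reproduces under the default rounding mode.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float x = 2.5f;
  std::printf("round(2.5f)     = %.1f\n", std::round(x));      // 3.0, ties away from zero
  std::printf("nearbyint(2.5f) = %.1f\n", std::nearbyint(x));  // 2.0, ties to even (FE_TONEAREST)
  return 0;
}
```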

Contributor


Fixed!

@pcmoritz
Collaborator

Thanks for removing the layer name until the weight refactor is ready @dsikka

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic enabled auto-merge (squash) May 23, 2024 19:31
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit a124232 into vllm-project:main May 23, 2024
63 checks passed
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 31, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jun 8, 2024
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request Jun 8, 2024
joerunde pushed a commit to joerunde/vllm that referenced this pull request Jun 17, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jul 14, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024