Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor #195

Merged
merged 30 commits into from
Apr 30, 2024

Conversation

dsikka
Copy link

@dsikka dsikka commented Apr 18, 2024

Adding layer_name

  • Depending on how we end up parsing ignore and targets (layer_name vs layer_type) we may not need layer_name to be added to the linear_method. Will experiment using a compressed-tensors function in a follow-up PR

Summary

  • Initial implementation for Compressed Config support + Activation Quantization for static per tensor w8a8
  • Includes fused kernels added by @varun-sundar-rabindranath

Testing/Sample Script:

from vllm import LLM, SamplingParams
import torch

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is"
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

# Create an LLM.
llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test", enforce_eager=True, dtype=torch.float32, quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Next Steps:

@dsikka dsikka marked this pull request as ready for review April 24, 2024 16:57
@dsikka dsikka changed the title [WIP] sparseml compression config support [WIP] Initial CompressedTensors config support Apr 24, 2024
@dsikka dsikka force-pushed the compression_config branch from 36f302b to 6868f97 Compare April 24, 2024 19:25
@dsikka dsikka changed the title [WIP] Initial CompressedTensors config support Initial CompressedTensors config + Activation Quantization support Apr 24, 2024
@dsikka dsikka changed the title Initial CompressedTensors config + Activation Quantization support Initial CompressedTensors config + Activation Quantization support - DO NOT MERGE Apr 24, 2024
@dsikka dsikka changed the base branch from upstream-main to ds-quant April 24, 2024 21:42
@dsikka dsikka changed the title Initial CompressedTensors config + Activation Quantization support - DO NOT MERGE Initial CompressedTensors config + Activation Quantization support Apr 24, 2024
@dsikka dsikka changed the title Initial CompressedTensors config + Activation Quantization support Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor Apr 25, 2024
Description:
 Remove logging triggers a device-to-host copy.

---------

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nits. Generally looks good. Biggest things are:

  • Clean up the StaticW8A8Scheme
  • Remove unnecessary cuda stuff

@varun-sundar-rabindranath can you do a quick audit of the kernels and let us know which can be removed?

csrc/pybind.cpp Outdated Show resolved Hide resolved
csrc/attention/dtype_int8.cuh Outdated Show resolved Hide resolved
csrc/attention/dtype_float32.cuh Outdated Show resolved Hide resolved
csrc/reduction_utils.cuh Outdated Show resolved Hide resolved
vllm/config.py Outdated Show resolved Hide resolved
@@ -16,6 +18,7 @@
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
"sparseml": CompressedTensorsConfig
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why sparseml and not compressed-tensors?

Copy link
Author

@dsikka dsikka Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To comply with the vllm input handling. The quantization method listed in the sparsmel model config is sparseml so we can change this if we change the value listed in the config.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you coordinate with Sara on this?

I think it should be compressed-tensors in the HF config

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with the update, but it is currently called "sparseml" in the compressed-tensors repo so it will need to be updated there too

csrc/quantization/smoothquant/quant_utils.cuh Outdated Show resolved Hide resolved
dsikka and others added 7 commits April 29, 2024 14:49
Description:
- rename `csrc/quantization/smoothquant/fused_kernels.cu` ->
`csrc/quantization/compressed_tensors/int8_quant_kernels.cu`
 - Remove `csrc/attention/dtype_int8.cuh`
- Remove unused quant_per_token kernel. Rename `ops.quant` to
`ops.quant_per_tensor`
 - Remove unused `quant_utils.cuh`
 - Remove unused `blockReduceMax` code from reduction_utils.cuh

---------

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@dsikka dsikka merged this pull request into ds-quant Apr 30, 2024
@dsikka dsikka deleted the compression_config branch April 30, 2024 17:10
dsikka added a commit that referenced this pull request Apr 30, 2024
…for static W8A8 per tensor (#195)

- Depending on how we end up parsing `ignore` and `targets` (layer_name
vs layer_type) we may not need layer_name to be added to the
linear_method. Will experiment using a compressed-tensors function in a
follow-up PR

- Initial implementation for Compressed Config support + Activation
Quantization for static per tensor w8a8
- Includes fused kernels added by @varun-sundar-rabindranath

```python
from vllm import LLM, SamplingParams
import torch

prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is"
]
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test", enforce_eager=True, dtype=torch.float32, quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- Verification of the different inputs expected for `targets` and
`ignore` --> use functions to parse the layer names which can be shared
by sparseml and vllm; would live in compressed tensors
(https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86)
- Updates to further optimize fake qunat

---------

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants