Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor #195
Conversation
CompressedTensors config support

Use cutlass kernels.

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

Commits: 36f302b to 6868f97
Description: Remove logging; it triggers a device-to-host copy.

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
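For context on why the logging was removed: formatting a CUDA tensor value for a log line materializes it on the host, which blocks the GPU stream. A minimal PyTorch sketch (illustrative, not code from this PR):

```python
import torch

x = torch.randn(1024, device="cuda")

# Asynchronous: the reduction is queued on the GPU; nothing blocks here.
scale = x.abs().max()

# Rendering the value in a log/print statement forces the device-to-host
# copy (and an implicit synchronization) that the commit above removes:
print(f"activation scale = {scale.item()}")  # .item() = D2H copy + sync
```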
Minor nits. Generally looks good. Biggest things are:
- Clean up the `StaticW8A8Scheme`
- Remove unnecessary CUDA stuff

@varun-sundar-rabindranath can you do a quick audit of the kernels and let us know which can be removed?
```diff
@@ -16,6 +18,7 @@
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
     "marlin": MarlinConfig,
+    "sparseml": CompressedTensorsConfig
```
why `sparseml` and not `compressed-tensors`?
To comply with the vLLM input handling. The quantization method listed in the sparseml model config is `sparseml`, so we can change this if we change the value listed in the config.
Can you coordinate with Sara on this? I think it should be `compressed-tensors` in the HF config.
I'm fine with the update, but it is currently called `"sparseml"` in the compressed-tensors repo, so it will need to be updated there too.
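For illustration, a hypothetical excerpt of the model config entry under discussion; `quant_method` is the string vLLM matches against its registry, and every other detail here is an assumption, not taken from this PR:

```python
# Hypothetical config.json excerpt, shown as a Python dict.
model_config = {
    "quantization_config": {
        # Currently "sparseml"; the suggestion above is to rename it
        # to "compressed-tensors" in both repos.
        "quant_method": "sparseml",
    }
}

quant_method = model_config["quantization_config"]["quant_method"]
print(quant_method)
```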
Description:
- Rename `csrc/quantization/smoothquant/fused_kernels.cu` -> `csrc/quantization/compressed_tensors/int8_quant_kernels.cu`
- Remove `csrc/attention/dtype_int8.cuh`
- Remove unused quant_per_token kernel; rename `ops.quant` to `ops.quant_per_tensor` (reference semantics sketched below)
- Remove unused `quant_utils.cuh`
- Remove unused `blockReduceMax` code from `reduction_utils.cuh`

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
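For reference, what the renamed `ops.quant_per_tensor` computes is plain static per-tensor int8 quantization. A hedged PyTorch sketch of the semantics (not the CUDA kernel itself):

```python
import torch

def quant_per_tensor_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Reference semantics for static per-tensor int8 quantization:
    q = clamp(round(x / scale), -128, 127). "Static" means `scale` is
    fixed at calibration time rather than computed per batch."""
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

x = torch.randn(4, 8)
scale = 0.02  # illustrative calibration scale
q = quant_per_tensor_ref(x, scale)
x_hat = q.to(torch.float32) * scale   # dequantize
print("max round-trip error:", (x - x_hat).abs().max().item())
```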
Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor (#195)

Adding `layer_name`: depending on how we end up parsing `ignore` and `targets` (layer_name vs layer_type), we may not need `layer_name` to be added to the linear_method. Will experiment using a compressed-tensors function in a follow-up PR.

Summary:
- Initial implementation for CompressedTensors config support + Activation Quantization for static per-tensor W8A8
- Includes fused kernels added by @varun-sundar-rabindranath

Testing/Sample Script:

```python
from vllm import LLM, SamplingParams
import torch

prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test",
          enforce_eager=True,
          dtype=torch.float32,
          quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Next Steps:
- Verification of the different inputs expected for `targets` and `ignore` --> use functions to parse the layer names which can be shared by sparseml and vllm; would live in compressed-tensors (https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86); a hypothetical sketch of such a helper follows below
- Updates to further optimize fake quant

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>