
[Misc] Disambiguate quantized types via a new ScalarType #6396

Merged

Conversation

@LucasWilkinson LucasWilkinson commented Jul 12, 2024

Summary / Background

Currently, when invoking kernels that use quantized types (namely the marlin family of kernels), a num_bits parameter is used to determine whether the weights are stored as 8-bit or 4-bit integers. Given that there is now an fp8 marlin, this is ambiguous (two 8-bit types are supported); the current workaround is to copy and paste much of the marlin code and expose a separate Python entrypoint (i.e. fp8_marlin_gemm). The ambiguity is further compounded by the fact that the marlin kernels implicitly assume "gptq" types, i.e. unsigned integers with a symmetric zero point, and fold this zero-point upconvert into the marlin code at compile time (this may not always be the case if we add runtime zero-point support to marlin, in which case the weights will be stored as normal unsigned integers with zero points passed in separately). This ambiguity is only set to increase as we investigate other lower-precision types such as FP4 or FP6.

This PR seeks to resolve this by adding a ScalarType class that exists in both C++ and Python. This class is able to represent the type information for any floating-point or integer type (within reason), and can also represent types that have a 'baked in' (compile-time) zero point. In this PR we refer to compile-time zero points as a "bias", in an attempt to avoid confusion with runtime static zero points (zero points that are static w.r.t. the model definition but are not known at C++ compile time). This also mirrors the way the term bias is used for exponent biases in IEEE 754. I am open to opinions/suggestions around this naming.
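To make the "bias" semantics concrete, here is a toy Python sketch; the class and field names are purely illustrative and are not the API added by this PR:

from dataclasses import dataclass

@dataclass(frozen=True)
class ToyScalarType:
    # Illustrative stand-in for the proposed ScalarType (integer case only).
    size_bits: int
    is_signed: bool
    bias: int = 0  # compile-time ("baked in") zero point

    def logical_value(self, stored: int) -> int:
        # Stored codes equal the logical value plus the bias, so the
        # logical value is recovered by subtracting the bias.
        return stored - self.bias

# GPTQ-style 4-bit symmetric weights: unsigned 4-bit storage with a bias of 8.
uint4b8 = ToyScalarType(size_bits=4, is_signed=False, bias=8)
assert uint4b8.logical_value(0) == -8
assert uint4b8.logical_value(15) == 7

With uint4b8, the stored codes 0..15 therefore represent the logical values -8..7, which is exactly the GPTQ symmetric convention.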

Naming:

naming generally follows: https://github.com/jax-ml/ml_dtypes

for floating point types (leading float) the scheme is:
float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]

where the flags are:

  • no flags: the type follows IEEE 754 conventions
  • f: finite values only (no infinities)
  • n: NaNs are supported (non-standard encoding)

for integer types the scheme is:
[u]int<size_bits>[b<bias>]

  • if the bias is not present, it is zero

Some Examples:

GPTQ 4-bit symmetric uses uint4b8, i.e. an unsigned 4-bit integer with a bias of 8 (i.e. a symmetric zero point)
FP8 in vLLM uses float8_e4m3fn (PyTorch also supports float8_e5m2)
FP6 uses float6_e3m2f, i.e. a 3-bit exponent, 2-bit mantissa, finite values only, no NaNs

I'm open to suggestions around this naming scheme. Currently, "no NaNs + infs" (no trailing letters) would be ambiguous with IEEE 754 (also no trailing letters), but there are no use cases for this at the moment (it's kind of an odd combo, so I don't really see this case happening).

Goals:

  • enable writing of more generalized quantization utilities
    • Example: in this PR, quantize_weights has been updated to use ScalarType generically, so it now supports uint4b8, uint8b128, uint4 and int4 (and potentially more, but those are the only ones that have been tested); previously it implicitly assumed uint4b8 or uint8b128 (see the sketch after this list)
  • enable more seamless expansion of quantized gemm type support while maintaining a common interface
    • Example: marlin was expanded to use fp8, but code had to be duplicated since num_bits == 8 was ambiguous between uint8b128 and float8_e4m3fn (i.e. fp8). With ScalarType and some additional dispatching logic we should be able to merge these implementations using C++ templating (since ScalarType can be constexpr we can even template-specialize on it once the custom extensions migrate to C++20, example of template dispatch); the ultimate goal here would be to merge marlin_gemm, fp8_marlin_gemm and gptq_marlin_gemm into a single marlin_gemm with a type parameter (future PR)
    • Example: I am currently working on a successor to marlin (marlinV2, working name) that is optimized for Hopper and based on CUTLASS. This version should support an expanded set of types compared to the current marlin, which supports uint4b8, uint8b128 and float8_e4m3fn for weights; the new version will add uint4 and int4 (as well as an expanded set of activation types). MarlinV2 will also benefit from the more generalized quantization utilities in this PR for unit testing / benchmarking
    • Example: this could potentially replace Fp8KVCacheDataType
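As a rough illustration of the first goal, a generic weight-quantization helper can derive everything it needs from the type's properties instead of branching on num_bits; the sketch below uses hypothetical stand-in parameters and is not the quantize_weights utility from this PR:

import torch

def quantize_weights_sketch(w: torch.Tensor, size_bits: int, is_signed: bool,
                            bias: int = 0, group_size: int = 128):
    # Toy per-group symmetric quantization driven by the type's properties
    # (size_bits / is_signed / bias stand in for fields a ScalarType-style
    # object would expose). Assumes w.numel() is divisible by group_size and
    # that the logical range is roughly symmetric (signed ints, or unsigned
    # ints with a mid-range bias such as uint4b8 or uint8b128).

    # Range of the *stored* integer codes for this format.
    lo = -(1 << (size_bits - 1)) if is_signed else 0
    hi = ((1 << (size_bits - 1)) - 1) if is_signed else ((1 << size_bits) - 1)

    # Largest logical magnitude representable on both sides of zero once the
    # bias is removed (e.g. 7 for uint4b8, 127 for uint8b128).
    max_logical = min(hi - bias, bias - lo)
    assert max_logical > 0, "sketch only handles symmetric logical ranges"

    w_grouped = w.float().reshape(-1, group_size)
    scales = w_grouped.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / max_logical
    q = torch.clamp(torch.round(w_grouped / scales) + bias, lo, hi)
    return q.reshape_as(w).to(torch.int32), scales

The point is only that the storage range and the zero point both come from the type description, so adding e.g. int4 support does not require a new code path.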

Challenges:

Since this is intended to be a platform-agnostic utility, I created a new _core_C Torch extension that will be built on almost all platforms. This required some reordering in the CMakeLists.txt because the "cpu" target performs an early return within cmake.

However, I encountered issues compiling on Neuron (buildkite reports "Your installed Caffe2 version uses CUDA but I cannot find the CUDA libraries ..." during cmake, inside find_package(Torch REQUIRED), run here) and on TPU (buildkite reports no cmake). For these targets, I resorted to a partial pure-Python fallback implementation of ScalarType. This means that on these targets not all ScalarType methods will work and the class cannot be passed to C++. Currently, I believe these limitations are not a significant concern, since ScalarType is primarily intended for use with custom quantized GEMMs, which are not available on these targets. As for the latter limitation, if ScalarType is being passed as an argument to C++, that implies we are compiling custom extensions on these targets (something we are not currently doing), in which case the _core_C build issues should in theory already be solved.
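The fallback follows the usual try-import pattern; the sketch below only illustrates the idea, and the constructor signature shown is hypothetical:

# Sketch of the fallback pattern, not the actual vllm/_core_ext.py contents.
try:
    import torch

    # Succeeds only when the compiled _core_C extension has been built and
    # loaded for this platform, registering ScalarType as a torch custom class.
    ScalarType = torch.classes._core_C.ScalarType
    CORE_EXT_AVAILABLE = True
except Exception:
    CORE_EXT_AVAILABLE = False

    class ScalarType:  # partial pure-Python stand-in (hypothetical fields)
        # Carries the type metadata for Python-side logic, but cannot be
        # passed to C++ ops since no torch custom class is registered.
        def __init__(self, exponent: int, mantissa: int, bias: int,
                     signed: bool, size_bits: int):
            self.exponent = exponent
            self.mantissa = mantissa
            self.bias = bias
            self.signed = signed
            self.size_bits = size_bits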

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only trigger the fastcheck CI, which consists of only a small and essential subset of tests to quickly catch simple errors.

A full CI run is still required to merge this PR, so please make sure that you run full CI before merging or if you need more test signals.

To run full CI, you can do one of these:

  • Add ready label to the PR
  • Comment /ready on the PR
  • Enable auto-merge.

🚀

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/scalar-type-cherrypick branch 3 times, most recently from aaddfd2 to 5cf3968 Compare July 17, 2024 18:04
@LucasWilkinson
Contributor Author

/ready

@github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jul 17, 2024
@LucasWilkinson LucasWilkinson changed the title [Misc][WIP] Disambiguate quantized types via a new ScalarType [Misc] Disambiguate quantized types via a new ScalarType Jul 17, 2024
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/scalar-type-cherrypick branch 2 times, most recently from 06f5d9e to 85c1afb Compare July 18, 2024 04:42
@LucasWilkinson LucasWilkinson marked this pull request as ready for review July 18, 2024 18:01
bnellnm commented Jul 18, 2024

Have you checked how the ScalarTypes affect dynamo? I'd like to make sure that the registered schemas still work properly and that dynamo can still trace models whose ops take ScalarTypes.

@LucasWilkinson
Contributor Author

@mgoin proposed a new naming scheme to more closely match PyTorch (reiterating it here for record keeping), i.e. for floating point types, do

float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]

where the flags continue with the current style set out by https://github.com/jax-ml/ml_dtypes

and for integer types do:

[u]int<size_bits>[b<bias>]

overall I think this is a good suggestion, outside of bfloat16 being a bit confusing, i.e.:

fE5M2   -> float8_e5m2
fE4M3fn -> float8_e4m3fn
fE5M10  -> float16_e5m10
fE8M7   -> float16_e8m7            <= this is bfloat16

as for the C++ constexprs, I think we'd want to keep a version of the current ones to align with the "rust" style found in the PyTorch C++ code:
https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70

but we could also add aliases like:
floating point types:
kFloat<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]
integer types:
k(UI | I)nt<size_bits>[b<bias>]

to align with the "fixed width dtypes" style here: https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L46-L57

LucasWilkinson commented Jul 18, 2024

Have you checked how the ScalarTypes affect dynamo? I'd like to make sure that the registered schemas still work properly and that dynamo can still trace models whose ops take ScalarTypes.

@bnellnm how would I test that? I'm not very familiar with dynamo.

bnellnm commented Jul 18, 2024

Have you checked how the ScalarTypes affect dynamo? I'd like to make sure that the registered schemas still work properly and that dynamo can still trace models whose ops take ScalarTypes.

how would I test that? I'm not very familiar with dynamo

I think the easiest way would be to add a @torch.compile(backend='eager') annotation to LlamaModel.forward (or the forward of a model that uses a ScalarType-based op) and run a simple test with enforce_eager=True, e.g.

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
eager = True
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", enforce_eager=eager,
          gpu_memory_utilization=0.50)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

This runs without error in current main. I haven't tried a ton of other models but I think they should mostly work w/o errors too.
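For reference, instead of editing the model source, the same annotation can be applied by monkey-patching from the test script before constructing the LLM; a minimal sketch, assuming the usual import path:

import torch
from vllm.model_executor.models.llama import LlamaModel

# Wrap forward with the eager backend so dynamo traces the graph (including
# any custom quantized-gemm ops) without generating compiled code.
LlamaModel.forward = torch.compile(backend="eager")(LlamaModel.forward)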

@LucasWilkinson
Contributor Author

@bnellnm thanks! do you know of any quantized models that work on main? I tried both:

"neuralmagic/llama-2-7b-chat-marlin"
"neuralmagic/Llama-2-7b-chat-quantized.w4a16"

and both fail when I add @torch.compile(backend='eager')

@LucasWilkinson
Contributor Author

also tried:

"neuralmagic/Llama-2-7b-chat-quantized.w8a16"

it also fails; I don't think marlin supports dynamo as-is:

[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function _C.gptq_marlin_gemm(*(FakeTensor(..., device='cuda:0', size=(4096, 4096), dtype=torch.float16), Parameter(FakeTensor(..., device='cuda:0', size=(256, 49152), dtype=torch.int32)), Parameter(FakeTensor(..., device='cuda:0', size=(1, 12288), dtype=torch.float16)), Parameter(FakeTensor(..., device='cuda:0', size=(0,), dtype=torch.int32)), Parameter(FakeTensor(..., device='cuda:0', size=(0,), dtype=torch.int32)), FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.int32), 8, 4096, 12288, 4096, True), **{}):
[rank0]: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

bnellnm commented Jul 18, 2024

also tried:

"neuralmagic/Llama-2-7b-chat-quantized.w8a16"

it also fails; I don't think marlin supports dynamo as-is:

[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function _C.gptq_marlin_gemm(*(FakeTensor(..., device='cuda:0', size=(4096, 4096), dtype=torch.float16), Parameter(FakeTensor(..., device='cuda:0', size=(256, 49152), dtype=torch.int32)), Parameter(FakeTensor(..., device='cuda:0', size=(1, 12288), dtype=torch.float16)), Parameter(FakeTensor(..., device='cuda:0', size=(0,), dtype=torch.int32)), Parameter(FakeTensor(..., device='cuda:0', size=(0,), dtype=torch.int32)), FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.int32), 8, 4096, 12288, 4096, True), **{}):
[rank0]: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

That fails for me also. It's on my list of things to fix. Are those the only models using the ScalarType atm?

I've been testing using the following quantized models:

nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change
neuralmagic/Meta-Llama-3-8B-Instruct-FP8
nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples

@LucasWilkinson
Contributor Author

Are those the only models using the ScalarType atm?

just any model using marlin or marlin 2:4 (i.e. weight-only quantized models)

@LucasWilkinson
Contributor Author

@bnellnm thanks for the help! With the updated schema (note: this is the schema as of this PR):

  ops.def(
      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor "
      "g_idx, Tensor perm, Tensor! workspace, "
      "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, int "
      "size_n, int size_k, bool is_k_full) -> Tensor");

it ran through torch dynamo without issue 👍

Review threads (outdated, resolved): csrc/core/scalar_type.hpp (×2), vllm/_core_ext.py, vllm/_core_ext.pyi
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/scalar-type-cherrypick branch from f06b40c to 2b9ef42 Compare July 22, 2024 14:19

@bnellnm bnellnm left a comment

LGTM


@comaniac comaniac left a comment

The PR makes sense to me.
@WoosukKwon @simon-mo it would be better if you could also take a look as this PR also changes the build logic a bit.

Review threads (outdated, resolved): csrc/core/scalar_type.hpp (×2)
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/scalar-type-cherrypick branch from 775049e to 1d90d74 Compare July 31, 2024 04:44

@tlrmchlsmth tlrmchlsmth left a comment

lgtm

Comment on lines +72 to +73
int64_t const bias; // stored values equal value + bias,
// used for quantized type
Collaborator

lament: I wish we had a better unambiguous name for this, but don't have any good suggestions. For sure bias is better than zero point though.

Contributor Author

haha, yeah, I feel you. I've warmed up to bias a bit since it kinda mirrors exponent bias in IEEE 754, but I agree it's still a bit ambiguous

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/scalar-type-cherrypick branch from a926e67 to 36ee4f4 Compare August 2, 2024 18:28

@mgoin mgoin left a comment

Super solid work, I like the refactored marlin utilities! No real comments aside from being unsure about the naming of zp/bias/something else, but bias is definitely fine for now.

Comment on lines +378 to +380
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
Collaborator

Thank you for these!

@simon-mo simon-mo merged commit a8d604c into vllm-project:main Aug 2, 2024
61 of 63 checks passed
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024