[Misc] Disambiguate quantized types via a new ScalarType #6396
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Full CI run is still required to merge this PR so please make sure that you run full CI before merging or if you need more test signals. To run full CI, you can do one of these:
Force-pushed from aaddfd2 to 5cf3968
/ready
Force-pushed from 06f5d9e to 85c1afb
Have you checked how the …
@mgoin proposed a new naming scheme to more closely match PyTorch (reiterating it here for record keeping), i.e. for floating point types, do
where the flags continue with the current style set out by https://github.com/jax-ml/ml_dtypes, and for integer types do:
overall I think this is a good suggestion, outside of
as for the C++ constexpr's I think we'd want to keep a version of the current ones to align with the "rust" style found in the PyTorch C++ code, but we could also add aliases like: to align with the "fixed width dtypes" style here: https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L46-L57
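For reference, these are the PyTorch dtype names the proposed scheme lines up with; a quick Python check (the float8 dtypes exist in recent PyTorch releases):

```python
import torch

# Fixed-width style names vs. e/m(+flags) style names already present in PyTorch.
print(torch.float16, torch.bfloat16)
print(torch.float8_e4m3fn, torch.float8_e5m2)
```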
@bnellnm how would I test that? Not very familiar with dynamo
I think the easiest way would be to add a
This runs without error in current main. I haven't tried a ton of other models but I think they should mostly work w/o errors too.
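As a rough, generic illustration of that kind of check (not the exact test referenced above; the module here is a placeholder), compiling with `fullgraph=True` makes dynamo error out on anything it cannot trace:

```python
import torch

class TinyModel(torch.nn.Module):
    """Stand-in module; an unsupported custom op in forward() would surface
    as a dynamo error when compiled with fullgraph=True."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))

compiled = torch.compile(TinyModel(), backend="eager", fullgraph=True)
out = compiled(torch.randn(2, 16))  # raises if dynamo hits an unsupported op
```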
@bnellnm thanks! do you know of any quantized models that work on main? I tried both:
and both fail when I add
also tried:
it also fails, I don't think marlin supports dynamo as-is:
That fails for me also. It's on my list of things to fix. Are those the only models using the …? I've been testing using the following quantized models:
just any model using marlin or marlin2:4 (i.e. weight only quantized models)
@bnellnm thanks for the help, with updated schema (note this is the schema for this PR):
it ran through torch dynamo without issue 👍 |
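The updated schema itself isn't reproduced above, but for context, the op schema is what dynamo consults when it hits a custom op; a minimal, hypothetical registration (namespace and op name invented for illustration, not vLLM's real schema) looks like:

```python
import torch

# Hypothetical namespace/op purely for illustration.
lib = torch.library.Library("demo_quant", "DEF")
lib.define("toy_gemm(Tensor a, Tensor b) -> Tensor")

def toy_gemm_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Trivial stand-in body; a real kernel would also take type metadata.
    return a @ b

lib.impl("toy_gemm", toy_gemm_impl, "CompositeExplicitAutograd")

out = torch.ops.demo_quant.toy_gemm(torch.randn(2, 3), torch.randn(3, 4))
```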
Force-pushed from f06b40c to 2b9ef42
LGTM
The PR makes sense to me.
@WoosukKwon @simon-mo it would be better if you could also take a look as this PR also changes the build logic a bit.
Force-pushed from 775049e to 1d90d74
lgtm
```cpp
int64_t const bias;  // stored values equal value + bias,
                     // used for quantized type
```
lament: I wish we had a better unambiguous name for this, but don't have any good suggestions. For sure bias is better than zero point though.
haha, ya I feel you, I've warmed up to bias a bit since it kinda mirrors exponent bias in IEEE-754, but I agree it's still a bit ambiguous
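For a concrete sense of that parallel: IEEE-754 float32 stores its exponent with a bias of 127, much like `uint4b8` stores 4-bit values with a bias of 8. A quick check:

```python
import struct

# Biased exponent of float32 1.5: the stored field is the actual exponent + 127.
bits = struct.unpack(">I", struct.pack(">f", 1.5))[0]
stored_exponent = (bits >> 23) & 0xFF
assert stored_exponent - 127 == 0   # 1.5 == 1.1b * 2**0

# uint4b8: a logical value of -3 is stored as -3 + 8 = 5, inside [0, 15].
assert 0 <= (-3 + 8) <= 15
```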
Force-pushed from e31dd1f to a926e67
Force-pushed from a926e67 to 36ee4f4
Super solid work, I like the refactored marlin utilities! No real comments aside from being unsure about the naming of zp/bias/something else, but bias is definitely fine for now.
```cpp
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
```
Thank you for these!
Summary / Background
Currently when invoking kernels that use quantized types (namely the marlin family of kernels), a `num_bits` parameter is used to determine whether the weights are stored as 8-bit or 4-bit integers. Given that there is now an fp8 marlin, this is now ambiguous (two 8-bit types are supported); the current way around this ambiguity is to copy and paste much of the marlin code and have a separate Python entrypoint (i.e. `fp8_marlin_gemm`). This ambiguity is further compounded by the fact that the marlin kernels implicitly assume "gptq" types, i.e. unsigned integers with a symmetric zero point, and fold this zero-point upconvert into the marlin code at compile time (this may not always be the case if we add runtime zero-point support to marlin, in which case the weights will be stored as normal unsigned integers with zero points passed in separately). This ambiguity is only set to increase as we investigate other lower-precision types such as FP4 or FP6.

This PR seeks to resolve this by adding a `ScalarType` class that exists in both C++ and Python. This class is able to represent the type information for any floating-point or integer type (within reason), and can also represent types that have a 'baked in' (compile-time) zero point. In this PR we refer to the compile-time zero point as a "bias" in an attempt to not cause confusion with runtime static zero-points (zero-points that are static w.r.t. the model definition but are not known at C++ compile time). This also mirrors the way the term bias is used for exponent biases in IEEE 754. I am open to opinions/suggestions around this naming.

Naming:
naming generally follows: https://github.com/jax-ml/ml_dtypes

for floating point types (leading `f`) the scheme is: `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
flags:

for integer types the scheme is: `[u]int<size_bits>[b<bias>]`
Some Examples:

- GPTQ 4-bit symmetric uses `uint4b8`: i.e. unsigned 4-bit integer with a bias of 8 (i.e. a symmetric zero-point)
- FP8 in vLLM uses `float8_e4m3fn` (PyTorch also supports `float8_e5m2`)
- FP6 uses `float6_e3m2f`: i.e. 3-bit exponent, 2-bit mantissa, finite values only, no NaNs

I'm open to suggestions around this naming scheme; currently "no NaNs + infs" (no trailing letters) would be ambiguous with IEEE-754 (no trailing letters), but there are currently no use cases for this (kind of an odd combo so I don't really see this case happening).
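A rough Python sketch of how this naming maps onto type metadata (an illustrative stand-in, not the PR's actual `ScalarType` API):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ScalarTypeSketch:
    exponent: int   # exponent bits (0 for integer types)
    mantissa: int   # mantissa bits (magnitude bits for integer types)
    signed: bool
    bias: int = 0   # compile-time "bias": stored values equal value + bias

    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

# "uint4b8": unsigned 4-bit integer with a bias of 8 (GPTQ-style weights)
uint4b8 = ScalarTypeSketch(exponent=0, mantissa=4, signed=False, bias=8)
# "float8_e4m3fn": 4 exponent bits, 3 mantissa bits, plus a sign bit
float8_e4m3fn = ScalarTypeSketch(exponent=4, mantissa=3, signed=True)

assert uint4b8.size_bits == 4 and float8_e4m3fn.size_bits == 8
```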
Goals:

- `quantize_weights` has been updated to generically use `ScalarType`, so it can now support `uint4b8`, `uint8b128`, `uint4`, `int4` (and potentially more, but those are the only ones that have been tested); before, it implicitly assumed `uint4b8` and `uint8b128` (see the sketch after this list).
- `num_bits == 8` was ambiguous between `uint8b128` and `float8_e4m3fn` (i.e. fp8); with `ScalarType` and some additional dispatching logic we should be able to merge these implementations using C++ templating (since `ScalarType` can be constexpr we can even template-specialize on it once the custom extensions migrate to C++20, example of template dispatch). The ultimate goal here would be to merge `marlin_gemm`, `fp8_marlin_gemm`, and `gptq_marlin_gemm` into a single `marlin_gemm` with a type parameter (future PR).
- `uint4b8`, `uint8b128`, `float8_e4m3fn` for weights; the new version will add `uint4` and `int4` (as well as an expanded set of activation types). MarlinV2 will also benefit from the more generalized quantization utilities in this PR for unit testing / benchmarking.
- `Fp8KVCacheDataType`
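A minimal sketch of what a `ScalarType`-generic quantizer can look like (illustrative only: per-tensor scale, symmetric round-to-nearest; not the actual `quantize_weights` implementation):

```python
import torch

def quantize_weights_sketch(w: torch.Tensor, size_bits: int, signed: bool, bias: int = 0):
    """Quantize w to an integer type described by (size_bits, signed, bias)."""
    if signed:
        qmin, qmax = -(2 ** (size_bits - 1)), 2 ** (size_bits - 1) - 1
    else:
        qmin, qmax = -bias, 2 ** size_bits - 1 - bias
    scale = w.abs().max() / max(abs(qmin), qmax)
    q = torch.clamp(torch.round(w / scale), qmin, qmax)
    return (q + bias).to(torch.int32), scale  # stored values equal value + bias

# uint4b8 (GPTQ-style): 4 bits, unsigned storage with a bias of 8
q, s = quantize_weights_sketch(torch.randn(128, 128), size_bits=4, signed=False, bias=8)
assert int(q.min()) >= 0 and int(q.max()) <= 15
```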
Challenges:
Since this is intended to be a platform-agnostic utility, I created a new `_core_C` Torch extension that will be built on almost all platforms. This required some reordering in the `CMakeLists.txt` because the "cpu" target performs an early return within cmake.

However, I encountered issues compiling on Neuron (buildkite reports "Your installed Caffe2 version uses CUDA but I cannot find the CUDA libraries ..." during cmake, inside `find_package(Torch REQUIRED)`, run here) and on TPU (buildkite reports no `cmake`). For these targets, I resorted to a fallback partial Python implementation of `ScalarType`. This means that on these targets not all `ScalarType` methods will work and the class cannot be passed to C++. Currently, I believe these limitations are not a significant concern since `ScalarType` is primarily intended for use with custom quantized gemms, which are not available on these targets. As for the latter, if `ScalarType` is being used as an argument passed to C++, then this implies we are compiling custom extensions on these targets (something we are not currently doing), and in theory the `_core_C` build issues should be solved.
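The fallback pattern described above, sketched with hypothetical module and field names (not vLLM's actual layout): prefer the compiled extension, and fall back to a partial pure-Python stand-in when it is unavailable.

```python
try:
    # Compiled torch extension (module name assumed for illustration).
    from _core_C_sketch import ScalarType  # type: ignore
except ImportError:
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        # Partial pure-Python stand-in: enough for bookkeeping on platforms
        # where the extension cannot be built, but cannot be passed to C++ ops.
        exponent: int
        mantissa: int
        signed: bool
        bias: int = 0
```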