From e950acf01253a52c2d30ab69dae82909d9f7847e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 2 Aug 2024 16:51:58 -0400 Subject: [PATCH] [Misc] Disambiguate quantized types via a new ScalarType (#6396) --- CMakeLists.txt | 52 ++- Dockerfile.openvino | 3 + benchmarks/kernels/benchmark_marlin.py | 50 +-- cmake/cpu_extension.cmake | 1 - csrc/{ => core}/registration.h | 0 csrc/core/scalar_type.hpp | 382 ++++++++++++++++++ csrc/core/torch_bindings.cpp | 16 + csrc/cpu/torch_bindings.cpp | 2 +- csrc/moe/torch_bindings.cpp | 2 +- csrc/ops.h | 8 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 66 ++- .../marlin/sparse/marlin_24_cuda_kernel.cu | 17 +- csrc/torch_bindings.cpp | 2 +- setup.py | 9 +- tests/kernels/test_int8_quant.py | 2 - tests/kernels/test_marlin_gemm.py | 75 ++-- tests/test_scalartype.py | 36 ++ vllm/_core_ext.py | 177 ++++++++ vllm/_custom_ops.py | 29 +- .../layers/quantization/awq_marlin.py | 49 ++- .../schemes/compressed_tensors_w4a16_24.py | 18 +- .../schemes/compressed_tensors_wNa16.py | 29 +- .../layers/quantization/gptq_marlin.py | 43 +- .../layers/quantization/gptq_marlin_24.py | 29 +- .../layers/quantization/utils/marlin_utils.py | 120 +++--- .../quantization/utils/marlin_utils_test.py | 29 +- .../utils/marlin_utils_test_24.py | 30 +- .../layers/quantization/utils/quant_utils.py | 148 +++---- vllm/scalar_type.py | 35 ++ 29 files changed, 1107 insertions(+), 352 deletions(-) rename csrc/{ => core}/registration.h (100%) create mode 100644 csrc/core/scalar_type.hpp create mode 100644 csrc/core/torch_bindings.cpp create mode 100644 tests/test_scalartype.py create mode 100644 vllm/_core_ext.py create mode 100644 vllm/scalar_type.py diff --git a/CMakeLists.txt b/CMakeLists.txt index dbe688186f17f..922613ec5ddaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,6 +66,39 @@ endif() # find_package(Torch REQUIRED) +# +# Add the `default` target which detects which extensions should be +# built based on platform/architecture. This is the same logic that +# setup.py uses to select which extensions should be built and should +# be kept in sync. +# +# The `default` target makes direct use of cmake easier since knowledge +# of which extensions are supported has been factored in, e.g. +# +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. +# cmake --build . --target default +# +add_custom_target(default) +message(STATUS "Enabling core extension.") + +# Define _core_C extension +# built for (almost) every target platform, (excludes TPU and Neuron) + +set(VLLM_EXT_SRC + "csrc/core/torch_bindings.cpp") + +define_gpu_extension_target( + _core_C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + USE_SABI 3 + WITH_SOABI) + +add_dependencies(default _core_C) + # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -74,7 +107,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND if (VLLM_TARGET_DEVICE STREQUAL "cpu") include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) else() - message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + return() endif() return() endif() @@ -132,7 +165,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") endif() # -# Define extension targets +# Define other extension targets # # @@ -228,21 +261,6 @@ define_gpu_extension_target( -# -# Add the `default` target which detects which extensions should be -# built based on platform/architecture. 
This is the same logic that -# setup.py uses to select which extensions should be built and should -# be kept in sync. -# -# The `default` target makes direct use of cmake easier since knowledge -# of which extensions are supported has been factored in, e.g. -# -# mkdir build && cd build -# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. -# cmake --build . --target default -# -add_custom_target(default) - if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 7c62dd845aa99..c84dea419e58a 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -13,6 +13,9 @@ COPY requirements-common.txt /workspace/vllm/ COPY requirements-openvino.txt /workspace/vllm/ COPY vllm/ /workspace/vllm/vllm +COPY csrc/core /workspace/vllm/csrc/core +COPY cmake/utils.cmake /workspace/vllm/cmake/ +COPY CMakeLists.txt /workspace/vllm/ COPY setup.py /workspace/vllm/ # install build requirements diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 684985b81f690..536c133bb3341 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -7,16 +7,17 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS) + MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, marlin_quantize) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, quantize_weights, sort_weights) + gptq_pack, gptq_quantize_weights, sort_weights) +from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] @@ -27,13 +28,14 @@ def bench_run(results: List[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, num_bits: int, group_size: int, - size_m: int, size_k: int, size_n: int): + act_order: bool, is_k_full: bool, quant_type: ScalarType, + group_size: int, size_m: int, size_k: int, size_n: int): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, b={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, - group_size, size_m, size_k, size_n)) + sub_label = ("{}, act={} k_full={}, q={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, + str(quant_type), group_size, size_m, + size_k, size_n)) print(f"Testing: {sub_label}") @@ -50,18 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_g_idx, marlin_sort_indices, marlin_rand_perm, - ) = marlin_quantize(b, num_bits, group_size, act_order) + ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) marlin_zp = 
torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant (w_ref, q_w, s, g_idx, - rand_perm) = quantize_weights(b, num_bits, group_size, act_order) - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" # so that group ids are increasing @@ -75,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) globals = { # Gen params - "num_bits": num_bits, + "quant_type": quant_type, "group_size": group_size, "size_m": size_m, "size_n": size_n, @@ -128,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -138,19 +141,19 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp32", ).blocked_autorange(min_run_time=min_run_time)) - if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): results.append( benchmark.Timer( stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -160,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -196,9 +199,10 @@ def main(args): ) > 0 and is_k_full not in args.limit_k_full: continue - for num_bits in MARLIN_SUPPORTED_NUM_BITS: - if len(args.limit_num_bits - ) > 0 and num_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types( + False): + if len(args.limit_num_bits) > 0 and \ + quant_type.size_bits not in args.limit_num_bits: continue for group_size in MARLIN_SUPPORTED_GROUP_SIZES: @@ -215,8 +219,8 @@ def main(args): for size_m in args.batch_sizes: bench_run(results, model, 
act_order, is_k_full, - num_bits, group_size, size_m, size_k, - size_n) + quant_type, group_size, size_m, + size_k, size_n) compare = benchmark.Compare(results) compare.print() diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 118f9b28e0ae3..3ba3a2b6a93cd 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -113,6 +113,5 @@ define_gpu_extension_target( WITH_SOABI ) -add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) diff --git a/csrc/registration.h b/csrc/core/registration.h similarity index 100% rename from csrc/registration.h rename to csrc/core/registration.h diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp new file mode 100644 index 0000000000000..9f78402eee2a7 --- /dev/null +++ b/csrc/core/scalar_type.hpp @@ -0,0 +1,382 @@ +#pragma once + +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// ScalarTypeTorch is a subclass of ScalarType that is compatible with +// TORCH_LIBRARY, making it accessible from Python as well meaning this class +// can be used as a argument for custom operators, helping to simplify these +// interfaces. +// +// The type definitions on the Python side can be found in: vllm/_core_ext.pyi +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : int64_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa, + int64_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + bias(bias), + signed_(signed_), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) { + return ScalarType(true, 0, size_bits - 1, bias); + } + + static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) { + return ScalarType(false, 0, size_bits, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(int64_t exponent, + int64_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(true, exponent, mantissa, 0, finite_values_only, + nan_repr); + } + + int64_t const exponent; // size of the exponent field (0 for integer types) + int64_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + int64_t const bias; // stored values equal value + bias, + // used for quantized type + bool const signed_; // flag if the type supports negative numbers (i.e. 
has a + // sign bit) + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + int64_t size_bits() const { return mantissa + exponent + is_signed(); } + bool is_signed() const { return signed_; } + bool is_integer() const { return exponent == 0; } + bool is_floating_point() const { return exponent > 0; } + bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; } + bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from +// torch::CustomClassHolder), we use multiple inheritance here since we cannot +// have ScalarType inherit from torch::CustomClassHolder and have a constexpr +// constructor at the same time (torch::CustomClassHolder does not have a +// constexpr destructor) +class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { + public: + ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias, + bool _signed) + : ScalarType(exponent, mantissa, bias, _signed){}; + + ScalarTypeTorch(ScalarType type) : ScalarType(type){}; + + using Base = ScalarType; + using Self = ScalarTypeTorch; + using SelfPtr = c10::intrusive_ptr; + + static SelfPtr int_(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::int_(size_bits, bias.value_or(0))); + } + + static SelfPtr uint(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::uint(size_bits, bias.value_or(0))); + } + + static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { + return c10::make_intrusive( + ScalarType::float_IEEE754(exponent, mantissa)); + } + + static SelfPtr float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, int64_t nan_repr) { + return c10::make_intrusive(ScalarType::float_( + exponent, mantissa, finite_values_only, NanRepr(nan_repr))); + } + + template + static void bind_readonly_property(torch::class_& cls, + std::string const& name, T Base::*field) { + auto getter_func = [field = std::move(field)](SelfPtr const& self) { + if constexpr (std::is_member_function_pointer_v) { + return (self.get()->*field)(); + } else { + return self.get()->*field; + } + }; + + cls.def_property(name, getter_func); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + MemberFunc Cls::*member) { + cls.def(name, [member = std::move(member)](SelfPtr const& self) { + return (self.get()->*member)(); + }); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + Func func) { + cls.def(name, func); + } + + template + static void bind_static_function(torch::class_& cls, + const std::string& 
name, Func func) { + cls.def_static(name, func); + } + + static void bind_class(torch::Library& lib) { + auto cls = lib.class_("ScalarType") + .def(torch::init()); + + // Bind Properties + bind_readonly_property(cls, "mantissa", &Base::mantissa); + bind_readonly_property(cls, "exponent", &Base::exponent); + bind_readonly_property(cls, "bias", &Base::bias); + bind_readonly_property(cls, "signed", &Base::is_signed); + bind_readonly_property(cls, "size_bits", &Base::size_bits); + + // Bind member functions + bind_function(cls, "is_signed", &Base::is_signed); + bind_function(cls, "is_integer", &Base::is_integer); + bind_function(cls, "is_floating_point", &Base::is_floating_point); + bind_function(cls, "is_ieee_754", &Base::is_ieee_754); + bind_function(cls, "has_nans", &Base::has_nans); + bind_function(cls, "has_infs", &Base::has_infs); + bind_function(cls, "has_bias", &Base::has_bias); + + bind_function(cls, "max", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->max()); + }); + bind_function(cls, "min", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->min()); + }); + + bind_function(cls, "__str__", &Base::str); + bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) { + return *self == *other; + }); + bind_function(cls, "__repr__", [](SelfPtr const& self) { + return "ScalarType." + self.get()->str(); + }); + + // Bind static functions (convenience constructors) + bind_static_function(cls, "int_", &ScalarTypeTorch::int_); + bind_static_function(cls, "uint", &ScalarTypeTorch::uint); + bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754); + bind_static_function(cls, "float_", &ScalarTypeTorch::float_); + } +}; + +using ScalarTypeTorchPtr = c10::intrusive_ptr; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names 
+static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +}; // namespace vllm diff --git a/csrc/core/torch_bindings.cpp b/csrc/core/torch_bindings.cpp new file mode 100644 index 0000000000000..f60254189a2f7 --- /dev/null +++ b/csrc/core/torch_bindings.cpp @@ -0,0 +1,16 @@ +#include + +#include "scalar_type.hpp" +#include "registration.h" + +// Note the CORE exstension will be built for (almost) all hardware targets so +// new additions must account for this. (currently not built for TPU and Neuron) + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) { + // ScalarType, a custom class for representing data types that supports + // quantized types, declared here so it can be used when creating interfaces + // for custom ops. + vllm::ScalarTypeTorch::bind_class(lib); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 7d549e271a30d..cf7d977da7c1c 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -1,6 +1,6 @@ #include "cache.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 243752b9a9e8c..86e42af44df15 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,4 +1,4 @@ -#include "registration.h" +#include "core/registration.h" #include "moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { diff --git a/csrc/ops.h b/csrc/ops.h index f274a7e647b95..3bd4a9eda5ee3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -3,6 +3,8 @@ #include #include +#include "core/scalar_type.hpp" + void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, @@ -84,14 +86,16 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k); torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce); diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 26cc248e6ac5d..edf19365c8098 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -21,6 +21,7 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" +#include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -71,14 +72,15 @@ __global__ void Marlin( bool use_fp32_reduce // whether to use fp32 global reduce ) {} -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, 
int64_t size_n, int64_t size_k, - bool is_k_full) { + bool is_k_full, bool has_zp) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -1963,18 +1965,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp, - void* s, void* zp, void* g_idx, void* perm, void* a_tmp, - int prob_m, int prob_n, int prob_k, void* workspace, - int num_bits, bool has_act_order, bool is_k_full, - bool has_zp, int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par, bool use_fp32_reduce) { - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, bool has_zp, int num_groups, int group_size, + int dev, cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool use_fp32_reduce) { + if (has_zp) { + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128, + "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type.str()); + } + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + // TODO: remove alias when we start supporting other 8bit types + int num_bits = q_type.size_bits(); int tot_m = prob_m; int tot_m_blocks = div_ceil(tot_m, 16); int pad = 16 * tot_m_blocks - tot_m; @@ -2126,19 +2139,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp, } } -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; + if (has_zp) { + TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", + b_q_type->str()); + } else { + TORCH_CHECK( + *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + b_q_type->str()); + } + + int pack_factor = 32 / b_q_type->size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -2265,21 +2287,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { - marlin::marlin_mm_f16i4( + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else if (a.scalar_type() == at::ScalarType::BFloat16) { - marlin::marlin_mm_f16i4( + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else { diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 3c50f1786bc68..93445a386593b 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -27,6 +27,7 @@ #include #include "common/base.h" +#include "core/scalar_type.hpp" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -86,7 +87,8 @@ __global__ void Marlin_24( torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k) { // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "num_bits must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + int pack_factor = 32 / b_q_type->size_bits(); // Verify M TORCH_CHECK(size_m == a.size(0), @@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, marlin_24::marlin_cuda_2_4( a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), - num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_m, sms, max_par); + b_q_type->size_bits(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); return c; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bf8cefa8d4713..7c0d617fc8b3b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -1,7 +1,7 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include diff --git a/setup.py b/setup.py index 91307e8a94062..b146299f8269d 100644 --- a/setup.py +++ b/setup.py @@ -271,6 +271,10 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() +def _build_core_ext() -> bool: + return not _is_neuron() and not _is_tpu() + + def get_hipcc_rocm_version(): # Run the hipcc --version command result = subprocess.run(['hipcc', '--version'], @@ -433,6 +437,9 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] +if _build_core_ext(): + ext_modules.append(CMakeExtension(name="vllm._core_C")) + if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) @@ -477,7 +484,7 @@ def _read_requirements(filename: str) -> List[str]: extras_require={ "tensorizer": ["tensorizer>=2.9.0"], }, - cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, + cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {}, package_data=package_data, entry_points={ "console_scripts": [ diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 03acbf7968ff1..0b7ed26a39e1e 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,8 +1,6 @@ import pytest import torch -# ruff: noqa: F401 -import vllm._C from tests.kernels.quant_utils import ref_dynamic_per_token_quant from vllm._custom_ops import scaled_int8_quant diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index a9e34ac8a7aa8..2f58ffda21408 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -9,14 +9,14 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.qqq import ( MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS, - marlin_make_empty_g_idx, marlin_permute_scales) + MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, + marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( pack_fp8_to_int32) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -27,8 
+27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp, - sort_weights) + awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -65,12 +64,13 @@ def rand_data(shape, dtype=torch.float16): reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, - mnk_factors): +def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, + act_order, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -95,11 +95,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, b_weight = rand_data((size_k, size_n)) # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, - group_size, act_order) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + b_weight, quant_type, group_size, act_order) # Pack to GPTQ format - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -108,8 +108,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - weight_perm = get_weight_perm(num_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -117,7 +118,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, sort_indices, size_k, size_n, - num_bits, + quant_type.size_bits, ) torch.cuda.synchronize() @@ -128,10 +129,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, +def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors @@ -150,22 +152,25 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, b_weight = rand_data((size_k, size_n)) # Quantize - w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits, - group_size) + w_ref, q_w, s, zp = quantize_weights(b_weight, + quant_type, + group_size, + 
zero_points=True) # Pack to AWQ format - q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) # Pack to Marlin format - weight_perm = get_weight_perm(num_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.awq_marlin_repack( q_w_awq, size_k, size_n, - num_bits, + quant_type.size_bits, ) torch.cuda.synchronize() @@ -176,7 +181,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @@ -185,7 +191,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, def test_gptq_marlin_gemm( k_chunk, n_chunk, - num_bits, + quant_type, group_size, mnk_factors, act_order, @@ -211,7 +217,7 @@ def test_gptq_marlin_gemm( b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, num_bits, group_size, act_order) + b_weight, quant_type, group_size, act_order) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) @@ -226,7 +232,7 @@ def test_gptq_marlin_gemm( g_idx, sort_indices, workspace.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], @@ -248,10 +254,10 @@ def test_gptq_marlin_gemm( reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) -@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, +def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors @@ -266,7 +272,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, b_weight = rand_data((size_k, size_n)) (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size) workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) @@ -279,7 +285,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, marlin_24_meta, marlin_24_s, workspace_24.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], @@ -371,14 +377,15 @@ def test_fp8_marlin_gemm( reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", 
MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) def test_awq_marlin_gemm( k_chunk, n_chunk, - num_bits, + quant_type, group_size, mnk_factors, use_fp32_reduce, @@ -396,7 +403,7 @@ def test_awq_marlin_gemm( b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, num_bits, group_size) + b_weight, quant_type, group_size) g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) @@ -414,7 +421,7 @@ def test_awq_marlin_gemm( g_idx, sort_indices, workspace.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py new file mode 100644 index 0000000000000..1201aaa92ea89 --- /dev/null +++ b/tests/test_scalartype.py @@ -0,0 +1,36 @@ +import pytest +import torch + +from vllm.scalar_type import scalar_types + + +@pytest.mark.parametrize("type_tuple", ( + (-8, 7, scalar_types.int4), + (0, 15, scalar_types.uint4), + (-8, 7, scalar_types.uint4b8), + (-128, 127, scalar_types.uint8b128), + (-28., 28., scalar_types.float6_e3m2f), + (torch.int8, scalar_types.int8), + (torch.uint8, scalar_types.uint8), + (torch.float8_e5m2, scalar_types.float8_e5m2), + (torch.float8_e4m3fn, scalar_types.float8_e4m3fn), + (torch.bfloat16, scalar_types.float16_e8m7), + (torch.float16, scalar_types.float16_e5m10), +), + ids=lambda x: str(x)) +def test_scalar_type_min_max(type_tuple): + print(type_tuple) + if len(type_tuple) == 3: + min, max, t = type_tuple + else: + torch_type, t = type_tuple + if torch_type.is_floating_point: + min = torch.finfo(torch_type).min + max = torch.finfo(torch_type).max + else: + min = torch.iinfo(torch_type).min + max = torch.iinfo(torch_type).max + + print(t, min, max, t.min(), t.max()) + assert min == t.min() + assert max == t.max() diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py new file mode 100644 index 0000000000000..e3b9fbb938915 --- /dev/null +++ b/vllm/_core_ext.py @@ -0,0 +1,177 @@ +import importlib.util +from enum import Enum +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) +core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +if TYPE_CHECKING or not core_C_available: + # On platforms were we cannot use/build the C++ core extension (i.e. namely + # neuron and tpu), we define the mock ScalarType class here that partially + # mimics the C++ ScalarType class. + # + # We also use this provide type signatures to the Python LSP for the methods + # in the C++ ScalarType class. So these type signatures should be kept + # in sync with csrc/core/scalar_type.hpp + + from dataclasses import dataclass + + @dataclass(frozen=True) + class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). 
The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + _finite_values_only: bool = False + """ + Private: if NANs are supported, used `has_infs()` instead. + """ + + nan_repr: int = NanRepr.IEEE_754.value + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + @property + def size_bits(self): + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + raise NotImplementedError + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + raise NotImplementedError + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + ... + + def is_floating_point(self): + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self): + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self): + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self): + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self): + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + raise NotImplementedError + + def __repr__(self) -> str: + raise NotImplementedError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + return cls(size_bits - 1, size_bits, bias if bias else 0, True) + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + return cls(size_bits, size_bits, bias if bias else 0, False) + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + return cls(exponent, mantissa, 0, True) + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: int): + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + return cls(exponent, mantissa, 0, True, finite_values_only, + nan_repr) + +elif core_C_available: + try: + import vllm._core_C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm._core_C with %r", e) + + ScalarType = torch.classes._core_C.ScalarType diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6cd77f75cae8d..ad7e5bd199339 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -4,6 +4,7 @@ import torch +from vllm._core_ext import ScalarType from vllm.logger import init_logger logger = init_logger(__name__) @@ -220,10 +221,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # marlin_24 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, num_bits: int, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, num_bits, size_m, + workspace, b_q_type, size_m, size_n, size_k) @@ -279,14 +280,22 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) -def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, b_zeros: torch.Tensor, - g_idx: torch.Tensor, perm: torch.Tensor, - workspace: torch.Tensor, num_bits: int, size_m: int, - size_n: int, size_k: int, is_k_full: bool, has_zp: bool, - use_fp32_reduce: bool) -> torch.Tensor: +def gptq_marlin_gemm(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, num_bits, + g_idx, perm, workspace, b_q_type, size_m, size_n, size_k, is_k_full, has_zp, use_fp32_reduce) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 5ffbb8e854e87..2cc080608c7a9 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,11 +10,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, awq_to_marlin_zero_points, - check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_awq_marlin_supported, - verify_marlin_supports_shape) + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -22,20 +22,31 @@ class AWQMarlinConfig(QuantizationConfig): """Config class for AWQ Marlin""" + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + def __init__(self, weight_bits: int, group_size: int, has_zp: bool, lm_head_quantized: bool) -> None: - 
self.weight_bits = weight_bits - self.pack_factor = 32 // self.weight_bits # packed into 32bits + self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.has_zp = has_zp self.lm_head_quantized = lm_head_quantized - verify_awq_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - has_zp=self.has_zp) + if weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {weight_bits}. " + f"Supported num_bits = {self.TYPE_MAP.keys()}") + + self.quant_type = self.TYPE_MAP[weight_bits] + + verify_marlin_supported(self.quant_type, + group_size=self.group_size, + has_zp=self.has_zp) def __repr__(self) -> str: - return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, " + return (f"AWQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"has_zp={self.has_zp}, " f"lm_head_quantized={self.lm_head_quantized})") @@ -110,11 +121,13 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): if (num_bits is None or group_size is None or has_zp is None): return False - return check_awq_marlin_supported( - num_bits=num_bits, - group_size=group_size, - has_zp=has_zp, - min_capability=cls.get_min_capability()) + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], + group_size=group_size, + has_zp=has_zp, + min_capability=cls.get_min_capability()) class AWQMarlinLinearMethod(LinearMethodBase): @@ -226,7 +239,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qweight, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. 
@@ -242,7 +255,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qzeros, size_k=layer.num_groups, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qzeros", marlin_zp) # Not-used @@ -263,7 +276,7 @@ def apply( g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.quant_config.weight_bits, + quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index b8ffb22d7a89d..c1adfdb2980b6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsW4A16Sparse24"] -W4A16SPARSE24_SUPPORTED_BITS = [4] +W4A16SPARSE24_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, +} +W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): @@ -22,9 +26,15 @@ def __init__(self, group_size: Optional[int] = None): self.strategy = strategy self.group_size = group_size - self.num_bits = num_bits self.tile_size = 16 + if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. 
" + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + + self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] + if self.strategy == "group" and self.group_size is None: raise ValueError( "group_size must be given when using strategy group") @@ -43,7 +53,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - pack_factor = 32 // self.num_bits + pack_factor = 32 // self.quant_type.size_bits output_size_per_partition = sum(output_partition_sizes) qweight = Parameter( @@ -138,7 +148,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, size_n = scales.shape[1] output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, - workspace, self.num_bits, size_m, + workspace, self.quant_type, size_m, size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index a41962ccd66d8..b8880f7ac136f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -8,12 +8,17 @@ CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported, + marlin_permute_scales, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] -WNA16_SUPPORTED_BITS = [4, 8] +WNA16_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, +} +WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsWNA16(CompressedTensorsScheme): @@ -22,8 +27,8 @@ def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None): - self.num_bits = num_bits - self.pack_factor = 32 // self.num_bits + + self.pack_factor = 32 // num_bits self.strategy = strategy self.group_size: int @@ -37,10 +42,16 @@ def __init__(self, else: self.group_size = group_size + if num_bits not in WNA16_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + # Verify supported on platform. - verify_gptq_marlin_supported(num_bits=self.num_bits, - group_size=self.group_size, - is_sym=True) + verify_marlin_supported(quant_type=self.quant_type, + group_size=self.group_size) @classmethod def get_min_capability(cls) -> int: @@ -150,7 +161,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.num_bits) + num_bits=self.quant_type.size_bits) replace_tensor(layer, "weight_packed", marlin_qweight) # Permute scales from compressed-tensors format to marlin format. 
@@ -172,7 +183,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.num_bits, + wtype=self.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=True, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index bdcc9c3b4f0c5..4a11b14971076 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,11 +10,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full, + apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_gptq_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -22,6 +23,12 @@ class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool, lm_head_quantized: bool) -> None: if desc_act and group_size == -1: @@ -29,20 +36,23 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, # (since we have only one group per output channel) desc_act = False - self.weight_bits = weight_bits - self.pack_factor = 32 // self.weight_bits # packed into int32 + self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.desc_act = desc_act - self.is_sym = is_sym self.lm_head_quantized = lm_head_quantized + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + # Verify supported on platform. 
- verify_gptq_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - is_sym=self.is_sym) + verify_marlin_supported(quant_type=self.quant_type, + group_size=self.group_size) def __repr__(self) -> str: - return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"desc_act={self.desc_act}, " f"lm_head_quantized={self.lm_head_quantized})") @@ -122,11 +132,12 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): or desc_act is None): return False - return check_gptq_marlin_supported( - num_bits=num_bits, - group_size=group_size, - is_sym=sym, - min_capability=cls.get_min_capability()) + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], + group_size=group_size, + min_capability=cls.get_min_capability()) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -293,7 +304,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -319,7 +330,7 @@ def apply( g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.quant_config.weight_bits, + wtype=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index e708c4da95af3..cafd100a2f40c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ + scalar_types.uint4b8, scalar_types.uint8b128 +] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] -GPTQ_MARLIN_24_SUPPORTED_SYM = [True] class GPTQMarlin24Config(QuantizationConfig): @@ -31,14 +33,19 @@ def __init__( weight_bits: int, group_size: int, ) -> None: - self.weight_bits = weight_bits + quant_type = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, + }.get(weight_bits) + self.group_size = group_size # Verify - if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + if quant_type is None or \ + quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: raise ValueError( - f"Marlin_24 does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " + f"Marlin_24 does not support quant_type = {quant_type}. 
" + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " "are supported.") if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( @@ -46,8 +53,10 @@ def __init__( f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " "are supported.") + self.quant_type = quant_type + # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // self.weight_bits + self.pack_factor = 32 // self.quant_type.size_bits # Tile size used by marlin kernels. self.tile_size = 16 @@ -66,8 +75,8 @@ def __init__( self.perm_len = 1024 def __repr__(self) -> str: - return "Marlin24Config(weight_bits={}, group_size={})".format( - self.weight_bits, self.group_size) + return "Marlin24Config(quant_type={}, group_size={})".format( + self.quant_type, self.group_size) @classmethod def get_name(cls) -> str: @@ -279,7 +288,7 @@ def apply( output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, workspace, - self.quant_config.weight_bits, + self.quant_config.quant_type, size_m, size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index b789ca20cadb3..6e84d36219361 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -5,6 +5,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types from .quant_utils import pack_cols, unpack_cols @@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 -MARLIN_SUPPORTED_NUM_BITS = [4, 8] MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # In case there is a performance issue with Marlin, the variable below can be @@ -22,76 +22,70 @@ USE_FP32_REDUCE_DEFAULT = True -def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, - min_capability: Optional[int], - has_zp: bool) -> Tuple[bool, Optional[str]]: - if min_capability is not None: +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. +# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types(has_zp: bool, + min_capability: Optional[int] = None): + if min_capability is None: major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor - if device_capability < min_capability: - return (False, "Marlin does not support device_capability = {}" - ", the min_capability required is {}".format( - device_capability, min_capability)) - - if num_bits not in MARLIN_SUPPORTED_NUM_BITS: - return (False, "Marlin does not support weight_bits = {}. " - "Only weight_bits = {} are supported.".format( - num_bits, MARLIN_SUPPORTED_NUM_BITS)) - - if group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return (False, "Marlin does not support group_size = {}. 
Only " - "group_sizes = {} are supported.".format( - group_size, MARLIN_SUPPORTED_GROUP_SIZES)) - - if not has_zp and not is_sym: - return (False, - "Marlin without zero_points must have symmetric quantization") + min_capability = major * 10 + minor - return True, None + if min_capability < 80: + return [] + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] -def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool, - min_capability: int) -> bool: - cond, _ = _check_marlin_supported(num_bits, - group_size, - is_sym, - min_capability, - has_zp=False) - return cond +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: -def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool, - min_capability: int) -> bool: - cond, _ = _check_marlin_supported(num_bits, - group_size, - False, - min_capability, - has_zp=has_zp) - return cond + if min_capability is None: + major, minor = current_platform.get_device_capability() + min_capability = major * 10 + minor + supported_types = query_marlin_supported_quant_types( + has_zp, min_capability) -def verify_gptq_marlin_supported(num_bits: int, group_size: int, - is_sym: bool) -> None: - cond, err_msg = _check_marlin_supported(num_bits, - group_size, - is_sym, - min_capability=None, - has_zp=False) - if not cond: - assert err_msg is not None - raise ValueError("GPTQ" + err_msg) + if quant_type not in supported_types: + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"min_capability = {min_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + min_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + min_capability) + return cond -def verify_awq_marlin_supported(num_bits: int, group_size: int, - has_zp: bool) -> None: - cond, err_msg = _check_marlin_supported(num_bits, - group_size, - False, - min_capability=None, - has_zp=has_zp) +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None - raise ValueError("AWQ" + err_msg) + raise ValueError(err_msg) def verify_marlin_supports_shape(output_size_per_partition: int, @@ -245,7 +239,7 @@ def apply_gptq_marlin_linear( g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, - num_bits: int, + wtype: ScalarType, output_size_per_partition: int, input_size_per_partition: int, is_k_full: bool, @@ -261,7 +255,7 @@ def apply_gptq_marlin_linear( g_idx, g_idx_sort_indices, workspace, - num_bits, + wtype, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, @@ -283,7 +277,7 @@ def apply_awq_marlin_linear( g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, - num_bits: int, + quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, bias: Optional[torch.Tensor] = None, @@ -298,7 +292,7 @@ def apply_awq_marlin_linear( g_idx, g_idx_sort_indices, workspace, - num_bits, + quant_type, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 541d148c761fc..7d08ac6f87469 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -5,10 +5,12 @@ import numpy as np import torch +from vllm.scalar_type import ScalarType + from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points) -from .quant_utils import (get_pack_factor, quantize_weights, - quantize_weights_with_zp, sort_weights) +from .quant_utils import (get_pack_factor, gptq_quantize_weights, + quantize_weights, sort_weights) class MarlinWorkspace: @@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, +def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, act_order: bool): size_k, size_n = w.shape + num_bits = quant_type.size_bits # Normalize group_size if group_size == -1: @@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, assert group_size <= size_k # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, - act_order) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, return res_list -def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): +def 
awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, + group_size: int): size_k, size_n = w.shape # Normalize group_size @@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): num_groups = size_k // group_size # Quantize with zp - w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size) + w_ref, q_w, s, zp = quantize_weights(w, + quant_type, + group_size, + zero_points=True) # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, + quant_type.size_bits) # Create result res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 648c32249a571..17d09055b1eac 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -6,8 +6,10 @@ import numpy import torch +from vllm.scalar_type import ScalarType + from .marlin_utils_test import marlin_weights -from .quant_utils import quantize_weights +from .quant_utils import gptq_quantize_weights # This is PyTorch implementation of main part of reorder_meta() @@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False): print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") -def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): assert q_24.shape == (size_k, size_n) - # Remove zp to normalize over 0 - max_q_val = (1 << num_bits) - 1 - zp = (max_q_val + 1) // 2 - q_24_no_zp = q_24 - zp + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias # Compress q_24_no_zp = q_24_no_zp.t().contiguous() @@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): q_24_no_zp) q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - # Restore zp - q_24_comp = q_24_no_zp_comp + zp + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias # Resize meta to its actual shape (without moving any data) meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) @@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, def marlin_24_quantize( w: torch.Tensor, - num_bits: int, + quant_type: ScalarType, group_size: int, ): size_k, size_n = w.shape @@ -441,20 +441,18 @@ def marlin_24_quantize( w_24, mask_24 = inject_24(w, size_k, size_n) # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, - num_bits, - group_size, - act_order=False) + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False) # Compress quantized weight q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, - num_bits) + quant_type) size_k_comp = size_k // 2 # Reformat to marlin - weight_perm = get_weight_perm_24(num_bits) + weight_perm = get_weight_perm_24(quant_type.size_bits) marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, - num_bits, weight_perm) + quant_type.size_bits, weight_perm) marlin_24_s 
= marlin_permute_scales_24(s, size_k, size_n, group_size) # Create result diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 7ade8bf664ccc..7f9081b257705 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -4,7 +4,11 @@ import numpy import torch -SUPPORTED_NUM_BITS = [4, 8] +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # Note: this is a hack. We should update each model to register the @@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: def get_pack_factor(num_bits): - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits @@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): ) -def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, - act_order: bool): +def quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + zero_points: bool = False): + assert quant_type.is_integer(), \ + "Floating point quantization may work but has not been tested" + orig_device = w.device + orig_type = w.dtype size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" if group_size == -1: group_size = size_k assert group_size <= size_k - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - # Reshape to [groupsize, -1] if group_size < size_k: w = w.reshape((-1, group_size, size_n)) @@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, w = w.reshape((group_size, -1)) # Compute scale for each group - s = torch.max(torch.abs(w), 0, keepdim=True)[0] - s *= 2 / max_q_val # 2 => symmetric + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + maybe_w_zp = None # Quantize - q_w = torch.round(w / s).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias # Restore original shapes if group_size < size_k: @@ -119,90 +140,48 @@ def reshape_w(w): w = w.reshape((size_k, size_n)).contiguous() return w - q_w = reshape_w(q_w) + w_q = 
reshape_w(w_q) w_ref = reshape_w(w_ref) - s = s.reshape((-1, size_n)).contiguous() + w_s = w_s.reshape((-1, size_n)).contiguous() - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k) - - w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + if zero_points: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) return ( w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), + w_q.to(device=orig_device), + w_s.to(device=orig_device), + maybe_w_zp, ) -def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape +def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, + group_size: int, act_order: bool): + size_k, _ = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + f"Unsupported gptq type = {quant_type}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], f"Unsupported groupsize = {group_size}" - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - max_q_val = 2**num_bits - 1 - min_q_val = 0 + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - # Reshape to [groupsize, -1] - if group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max = torch.max(w, 0, keepdim=True)[0] - min = torch.min(w, 0, keepdim=True)[0] - s = (max - min).clamp(min=1e-5) / max_q_val - - # Compute zero-point for each group - zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int() - - # Quantize - q_w = torch.round(w / s).int() + zp - q_w = torch.clamp(q_w, min_q_val, max_q_val) - - # Compute ref (dequantized) - w_ref = (q_w - zp).half() * s - - # Restore original shapes - if group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) - s = s.reshape((-1, size_n)).contiguous() - zp = zp.reshape((-1, size_n)).contiguous() + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s.to(device=orig_device), - zp.to(device=orig_device), - ) + return w_ref, w_q, w_s, g_idx, rand_perm # QQQ employs different quant schemes for per-group and @@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ + f"Unsupported num_bits = {num_bits}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], 
f"Unsupported groupsize = {group_size}" diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py new file mode 100644 index 0000000000000..eb491dd1554a8 --- /dev/null +++ b/vllm/scalar_type.py @@ -0,0 +1,35 @@ +from ._core_ext import NanRepr, ScalarType + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, + NanRepr.EXTD_RANGE_MAX_MIN.value) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) + + # "gptq" types + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10