diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 4e190d837e862..2bba9430e4d84 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -180,12 +180,12 @@ class ScalarType { /* * generally follows: https://github.com/jax-ml/ml_dtypes * for floating point types (leading f): - * - trailing f: means finite values only (no infinities) - * - trailing n: means nans are supported - * - no-trailing letters: means it follows IEEE 754 conventions * - E_: exponent size * - M_: mantissa size - * for integer types: + * - no-trailing letters: means it follows IEEE 754 conventions + * - trailing f: means finite values only (no infinities) + * - trailing n: means nans are supported (non-standard encoding) + * for integer types (leading s/u): * - leading s: means signed * - leading u: means unsigned * - number following s/u: number of bits @@ -337,17 +337,17 @@ using ScalarTypeTorchPtr = c10::intrusive_ptr; /* * generally follows: https://github.com/jax-ml/ml_dtypes - * for floating point types (leading F): - * - trailing f: means finite values only (no infinities) - * - trailing n: means nans are supported - * - no-trailing letters: means it follows IEEE 754 conventions + * for floating point types (leading f): * - E_: exponent size * - M_: mantissa size - * for integer types: - * - leading S: means signed - * - leading U: means unsigned - * - number following S/U: number of bits - * - BX: indicates a non-zero bias of X + * - no-trailing letters: means it follows IEEE 754 conventions + * - trailing f: means finite values only (no infinities) + * - trailing n: means nans are supported (non-standard encoding) + * for integer types (leading s/u): + * - leading s: means signed + * - leading u: means unsigned + * - number following s/u: number of bits + * - bX: indicates a non-zero bias of X */ static inline constexpr auto kS4 = ScalarType::s(4); static inline constexpr auto kU4 = ScalarType::u(4); diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 556401a455b28..09897b8521a5f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -285,9 +285,10 @@ def apply( size_k = x_2d.shape[1] size_n = scales.shape[1] - output_2d = ops.gptq_marlin_24_gemm( - x_2d, qweight, meta, scales, workspace, - self.quant_config.quant_type.size_bits, size_m, size_n, size_k) + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, + self.quant_config.quant_type, + size_m, size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 87d6ecfe005d0..2d26928f91ded 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -12,12 +12,12 @@ class NanRepr(Enum): # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f): -# - trailing f: means finite values only (no infinities) -# - trailing n: means nans are supported -# - no-trailing letters: means it follows IEEE 754 conventions # - E_: exponent size # - M_: mantissa size -# for integer types: +# - no-trailing letters: means it follows IEEE 754 conventions +# - trailing f: means finite values only (no infinities) +# - trailing n: means nans are supported (non-standard encoding) +# for integer types (leading s/u): # - leading s: means signed # - leading u: means unsigned # - number following s/u: number of bits