Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Jul 18, 2024
1 parent 67bd816 commit 78050f6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
26 changes: 13 additions & 13 deletions csrc/core/scalar_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,12 @@ class ScalarType {
/*
* generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f):
* - trailing f: means finite values only (no infinities)
* - trailing n: means nans are supported
* - no-trailing letters: means it follows IEEE 754 conventions
* - E_: exponent size
* - M_: mantissa size
* for integer types:
* - no-trailing letters: means it follows IEEE 754 conventions
* - trailing f: means finite values only (no infinities)
* - trailing n: means nans are supported (non-standard encoding)
* for integer types (leading s/u):
* - leading s: means signed
* - leading u: means unsigned
* - number following s/u: number of bits
Expand Down Expand Up @@ -337,17 +337,17 @@ using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;

/*
* generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading F):
* - trailing f: means finite values only (no infinities)
* - trailing n: means nans are supported
* - no-trailing letters: means it follows IEEE 754 conventions
* for floating point types (leading f):
* - E_: exponent size
* - M_: mantissa size
* for integer types:
* - leading S: means signed
* - leading U: means unsigned
* - number following S/U: number of bits
* - BX: indicates a non-zero bias of X
* - no-trailing letters: means it follows IEEE 754 conventions
* - trailing f: means finite values only (no infinities)
* - trailing n: means nans are supported (non-standard encoding)
* for integer types (leading s/u):
* - leading s: means signed
* - leading u: means unsigned
* - number following s/u: number of bits
* - bX: indicates a non-zero bias of X
*/
static inline constexpr auto kS4 = ScalarType::s(4);
static inline constexpr auto kU4 = ScalarType::u(4);
Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/quantization/gptq_marlin_24.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,10 @@ def apply(
size_k = x_2d.shape[1]
size_n = scales.shape[1]

output_2d = ops.gptq_marlin_24_gemm(
x_2d, qweight, meta, scales, workspace,
self.quant_config.quant_type.size_bits, size_m, size_n, size_k)
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace,
self.quant_config.quant_type,
size_m, size_n, size_k)

output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

Expand Down
8 changes: 4 additions & 4 deletions vllm/scalar_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ class NanRepr(Enum):

# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f):
# - trailing f: means finite values only (no infinities)
# - trailing n: means nans are supported
# - no-trailing letters: means it follows IEEE 754 conventions
# - E_: exponent size
# - M_: mantissa size
# for integer types:
# - no-trailing letters: means it follows IEEE 754 conventions
# - trailing f: means finite values only (no infinities)
# - trailing n: means nans are supported (non-standard encoding)
# for integer types (leading s/u):
# - leading s: means signed
# - leading u: means unsigned
# - number following s/u: number of bits
Expand Down

0 comments on commit 78050f6

Please sign in to comment.