diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 2bba9430e4d84..4344f96ee488d 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -14,6 +14,10 @@ namespace vllm { // can be used as a argument for custom operators, helping to simplify these // interfaces. // +// The type definitions on the Python side can be found in: vllm/_core_ext.pyi +// these type definitions should be kept up to date with any Python API changes +// here. +// class ScalarType { public: enum NanRepr : int64_t { @@ -104,7 +108,7 @@ class ScalarType { max_exponent += 1; } - // adjust the exponent to match that off a double + // adjust the exponent to match that of a double // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e // is the exponent bits), there is some precedent for non-standard biases, // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes @@ -214,7 +218,9 @@ class ScalarType { bool operator==(ScalarType const& other) const { return mantissa == other.mantissa && exponent == other.exponent && - bias == other.bias && _signed == other._signed; + bias == other.bias && _signed == other._signed && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; } }; diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py index 8d6684d193db4..37ab191b66e43 100644 --- a/vllm/_core_ext.py +++ b/vllm/_core_ext.py @@ -10,8 +10,7 @@ if core_C_available: try: - # ruff: noqa: F401 - import vllm._core_C + import vllm._core_C # ruff: noqa: F401 except ImportError as e: logger.warning("Failed to import from vllm._core_C with %r", e) diff --git a/vllm/_core_ext.pyi b/vllm/_core_ext.pyi index 3b95bd3be1872..45fe96ec681ca 100644 --- a/vllm/_core_ext.pyi +++ b/vllm/_core_ext.pyi @@ -7,6 +7,9 @@ class ScalarType: that torch.dtype currently does not support). It is also cabaable of representing types with a bias, i.e. the stored_value = value + bias, this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias of 8). + + The implementation for this class can be found in csrc/core/scalar_type.hpp, + these type definitions should be kept in snyc with that file. """ def __init__(self, exponent: int, mantissa: int, bias: int,