Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Jul 22, 2024
1 parent 78050f6 commit 2b9ef42
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
10 changes: 8 additions & 2 deletions csrc/core/scalar_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ namespace vllm {
// can be used as a argument for custom operators, helping to simplify these
// interfaces.
//
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
public:
enum NanRepr : int64_t {
Expand Down Expand Up @@ -104,7 +108,7 @@ class ScalarType {
max_exponent += 1;
}

// adjust the exponent to match that off a double
// adjust the exponent to match that of a double
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
// is the exponent bits), there is some precedent for non-standard biases,
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
Expand Down Expand Up @@ -214,7 +218,9 @@ class ScalarType {

bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent &&
bias == other.bias && _signed == other._signed;
bias == other.bias && _signed == other._signed &&
finite_values_only == other.finite_values_only &&
nan_repr == other.nan_repr;
}
};

Expand Down
3 changes: 1 addition & 2 deletions vllm/_core_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

if core_C_available:
try:
# ruff: noqa: F401
import vllm._core_C
import vllm._core_C # ruff: noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._core_C with %r", e)

Expand Down
3 changes: 3 additions & 0 deletions vllm/_core_ext.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ class ScalarType:
that torch.dtype currently does not support). It is also cabaable of
representing types with a bias, i.e. the stored_value = value + bias, this
is useful for quantized types (e.g. standard GPTQ 4bit uses a bias of 8).
The implementation for this class can be found in csrc/core/scalar_type.hpp,
these type definitions should be kept in snyc with that file.
"""

def __init__(self, exponent: int, mantissa: int, bias: int,
Expand Down

0 comments on commit 2b9ef42

Please sign in to comment.