[CI/Build] Avoid CUDA initialization #8534

Merged · 22 commits · Sep 18, 2024
Changes from 1 commit
Replace int() with to_int
DarkLight1337 committed Sep 18, 2024
commit 23aa200e083c38fc0df12f9ee3a0bb52e7142a22
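The commit replaces implicit `int(...)` conversions with the explicit `to_int()` method on the capability value returned by `current_platform.get_device_capability()`. A minimal sketch of that helper, assuming `DeviceCapability` is a `NamedTuple` with `major` and `minor` fields as the `vllm/platforms/interface.py` hunk at the bottom suggests:

```python
from typing import NamedTuple, Optional


class DeviceCapability(NamedTuple):
    """Sketch of the capability value returned by get_device_capability()."""
    major: int
    minor: int

    def to_int(self) -> int:
        # Compute capability 8.6 becomes the integer 86.
        return self.major * 10 + self.minor


# Hypothetical usage mirroring the call sites below.
capability: Optional[DeviceCapability] = DeviceCapability(major=8, minor=6)
min_capability = 80
print(capability is not None and capability.to_int() >= min_capability)  # True
```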
2 changes: 1 addition & 1 deletion tests/quantization/utils.py
@@ -12,4 +12,4 @@ def is_quant_method_supported(quant_method: str) -> bool:

     min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()

-    return int(capability) >= min_capability
+    return capability.to_int() >= min_capability
@@ -116,10 +116,10 @@ def get_config_filenames(cls) -> List[str]:
     def _check_scheme_supported(self,
                                 min_capability: int,
                                 error: bool = True) -> bool:
-        capability = current_platform.get_device_capability()
+        capability_tuple = current_platform.get_device_capability()

-        if capability is not None:
-            capability = capability[0] * 10 + capability[1]
+        if capability_tuple is not None:
+            capability = capability_tuple.to_int()
             supported = capability >= min_capability
             if error and not supported:
                 raise RuntimeError(
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -29,7 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
                                        device_capability: Optional[int] = None
                                        ):
     if device_capability is None:
-        device_capability = int(current_platform.get_device_capability() or -1)
+        capability_tuple = current_platform.get_device_capability()
+        device_capability = (-1 if capability_tuple is None else
+                             capability_tuple.to_int())

     if device_capability < 80:
         return []
@@ -51,8 +53,9 @@ def _check_marlin_supported(
         device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:

     if device_capability is None:
-        major, minor = current_platform.get_device_capability()
-        device_capability = major * 10 + minor
+        capability_tuple = current_platform.get_device_capability()
+        device_capability = (-1 if capability_tuple is None else
+                             capability_tuple.to_int())

     supported_types = query_marlin_supported_quant_types(
         has_zp, device_capability)
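Both marlin hunks use the same fallback: when `get_device_capability()` returns None (no usable device), the code substitutes -1 so that every capability check treats the platform as unsupported. A self-contained sketch of that behaviour, with a hypothetical stub standing in for the platform call:

```python
from typing import NamedTuple, Optional


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def to_int(self) -> int:
        return self.major * 10 + self.minor


def get_device_capability() -> Optional[DeviceCapability]:
    # Hypothetical stub: returns None when no CUDA device can be queried.
    return None


capability_tuple = get_device_capability()
device_capability = (-1 if capability_tuple is None else
                     capability_tuple.to_int())

# Marlin needs compute capability >= 8.0 (integer 80); -1 always fails the check,
# so query_marlin_supported_quant_types would return [].
print(device_capability < 80)  # True
```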
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool:
     # cutlass is not supported on Rocm
     if is_hip():
         return False
-    capability = current_platform.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
+
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()

     return ops.cutlass_scaled_mm_supports_fp8(capability)
6 changes: 3 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -97,10 +97,10 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        capability = current_platform.get_device_capability()  # type: ignore
+        capability_tuple = current_platform.get_device_capability()

-        if capability is not None:
-            capability = capability[0] * 10 + capability[1]
+        if capability_tuple is not None:
+            capability = capability_tuple.to_int()
             if capability < quant_config.get_min_capability():
                 raise ValueError(
                     f"The quantization method {model_config.quantization} "
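The loader keeps its existing guard: convert the capability to an int, compare it against the quantization config's minimum, and raise if it falls short; only the tuple-to-int step changes. A hedged sketch of that guard, assuming `get_min_capability()` returns an integer such as 80:

```python
from typing import NamedTuple, Optional


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def to_int(self) -> int:
        return self.major * 10 + self.minor


def check_quant_capability(capability_tuple: Optional[DeviceCapability],
                           min_capability: int,
                           quantization: str) -> None:
    # Sketch of the guard in _get_quantization_config: only enforce the
    # minimum when a device capability could actually be determined.
    if capability_tuple is not None:
        capability = capability_tuple.to_int()
        if capability < min_capability:
            raise ValueError(
                f"The quantization method {quantization} is not supported "
                f"for the current GPU (capability {capability}, "
                f"required {min_capability}).")


try:
    check_quant_capability(DeviceCapability(7, 5), 80, "fp8")
except ValueError as exc:
    print(exc)
```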
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_vl.py
@@ -207,7 +207,7 @@ def __init__(
             selected_backend = backend_name_to_enum(backend_by_env_var)
         if selected_backend is None:
             # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.get_device_capability()[0] >= 8
+            device_available = current_platform.has_device_capability(major=8)
             if device_available:
                 from transformers.utils import is_flash_attn_2_available
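In qwen2_vl.py the manual `get_device_capability()[0] >= 8` check becomes `has_device_capability(major=8)`, which also avoids indexing into a possibly-None result. The real helper lives on the platform class and is not shown in this diff; the sketch below is only an assumption of how such a check could look, given the keyword call in the hunk above:

```python
from typing import NamedTuple, Optional


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def to_int(self) -> int:
        return self.major * 10 + self.minor


def get_device_capability() -> Optional[DeviceCapability]:
    # Hypothetical stub: pretend we found a Turing GPU (compute capability 7.5).
    return DeviceCapability(major=7, minor=5)


def has_device_capability(major: int, minor: int = 0) -> bool:
    # Assumed semantics: False when no device is present, otherwise compare
    # against the requested (major, minor) capability.
    capability = get_device_capability()
    if capability is None:
        return False
    return capability.to_int() >= DeviceCapability(major, minor).to_int()


# Turing (7.5) is below 8.0, so the model falls back to xformers.
print(has_device_capability(major=8))  # False
```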
3 changes: 0 additions & 3 deletions vllm/platforms/interface.py
@@ -25,9 +25,6 @@ def to_int(self) -> int:
         """
         return self.major * 10 + self.minor

-    def __int__(self) -> int:
-        return self.to_int()
-

 class Platform:
     _enum: PlatformEnum
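Dropping `__int__` means a stray `int(capability)` now raises instead of silently converting, which matches the commit's goal of routing every conversion through the explicit `to_int()`. A quick illustration with the same sketched class:

```python
from typing import NamedTuple


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def to_int(self) -> int:
        return self.major * 10 + self.minor


cap = DeviceCapability(major=8, minor=0)
print(cap.to_int())  # 80: the explicit conversion

try:
    int(cap)  # without __int__ this raises
except TypeError as exc:
    print(f"int() no longer works: {exc}")
```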