Skip to content

Commit

Permalink
refactor get_punica_wrapper() into platform
Browse files Browse the repository at this point in the history
Signed-off-by: Shanshan Shen <467638484@qq.com>
  • Loading branch information
shen-shanshan committed Dec 26, 2024
1 parent 12b41b1 commit 8fd9677
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 13 deletions.
14 changes: 1 addition & 13 deletions vllm/lora/punica_wrapper/punica_selector.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,7 @@
from vllm.platforms import current_platform
from vllm.utils import print_info_once

from .punica_base import PunicaWrapperBase


def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
if current_platform.is_cuda_alike():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
print_info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)
elif current_platform.is_hpu():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
print_info_once("Using PunicaWrapperHPU.")
return PunicaWrapperHPU(*args, **kwargs)
else:
raise NotImplementedError
return current_platform.get_punica_wrapper(*args, **kwargs)
7 changes: 7 additions & 0 deletions vllm/platforms/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch

from vllm.logger import init_logger
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.utils import print_info_once

from .interface import Platform, PlatformEnum, _Backend

Expand Down Expand Up @@ -106,3 +108,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.")
return False

@classmethod
def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
print_info_once("PunicaWrapperCPU is not implemented yet.")
raise NotImplementedError
8 changes: 8 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import vllm._C # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
from vllm.utils import print_info_once

from .interface import DeviceCapability, Platform, PlatformEnum

Expand Down Expand Up @@ -148,6 +151,11 @@ def get_current_memory_usage(cls,
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)

@classmethod
def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
print_info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
Expand Down
8 changes: 8 additions & 0 deletions vllm/platforms/hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import torch

from vllm.logger import init_logger
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
from vllm.utils import print_info_once

from .interface import Platform, PlatformEnum, _Backend

Expand Down Expand Up @@ -58,3 +61,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on HPU.")
return False

@classmethod
def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
print_info_once("Using PunicaWrapperHPU.")
return PunicaWrapperHPU(*args, **kwargs)
8 changes: 8 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch

from vllm.logger import init_logger
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase

if TYPE_CHECKING:
from vllm.config import VllmConfig
Expand Down Expand Up @@ -247,6 +248,13 @@ def get_current_memory_usage(cls,
"""
raise NotImplementedError

@classmethod
def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
"""
Return the punica wrapper for the current platform.
"""
raise NotImplementedError


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
Expand Down
8 changes: 8 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
from vllm.utils import print_info_once

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

Expand Down Expand Up @@ -117,3 +120,8 @@ def get_current_memory_usage(cls,
) -> float:
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)

@classmethod
def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
print_info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)

0 comments on commit 8fd9677

Please sign in to comment.