[Bugfix] Explicitly set LoRA triton kernel device #13043

Status: Closed · wants to merge 3 commits

Changes from all commits
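No description accompanied this PR, so the following summary is inferred from the title and the diff. Triton launches kernels on the current CUDA device, not on the device of the tensors it is handed, so invoking the LoRA kernels with inputs on a non-default GPU (say, cuda:1 while the current device is cuda:0) can launch on the wrong device. The fix adds a small _set_cuda_device helper to vllm/lora/ops/triton_ops/utils.py and calls it with inputs.device immediately before every LoRA kernel launch. The same two-line change is applied to all five kernel wrappers below; a minimal sketch of the failure mode follows the first file's diff.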
vllm/lora/ops/triton_ops/bgmv_expand.py (2 additions, 1 deletion)

@@ -12,7 +12,7 @@

 from vllm.utils import direct_register_custom_op

-from .utils import get_lora_op_configs
+from .utils import _set_cuda_device, get_lora_op_configs


 @triton.jit
@@ -142,6 +142,7 @@ def _bgmv_expand(
         META["SPLIT_N"],
         batches,
     )
+    _set_cuda_device(inputs.device)
     _bgmv_expand_kernel[grid](
         inputs,
         lora_b_weights,
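To make the pattern concrete, here is a minimal sketch of the underlying issue and the fix. It is not from vLLM; the kernel and names are illustrative. Triton launches a kernel on the current CUDA device, so aligning that device with the input tensor's device before the launch is exactly what each wrapper in this PR now does.

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Each program copies one BLOCK-sized chunk of src into dst.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

def copy_on(device: str) -> torch.Tensor:
    src = torch.arange(1024, device=device, dtype=torch.float32)
    dst = torch.empty_like(src)
    # Without this line, launching with tensors on e.g. cuda:1 while the
    # current device is cuda:0 can fail or target the wrong GPU.
    torch.cuda.set_device(src.device)
    _copy_kernel[(4,)](src, dst, src.numel(), BLOCK=256)
    return dst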
vllm/lora/ops/triton_ops/bgmv_expand_slice.py (2 additions, 1 deletion)

@@ -12,7 +12,7 @@

 from vllm.utils import direct_register_custom_op

-from .utils import get_lora_op_configs
+from .utils import _set_cuda_device, get_lora_op_configs


 @triton.jit
@@ -158,6 +158,7 @@ def _bgmv_expand_slice(
         META["SPLIT_N"],
         batches,
     )
+    _set_cuda_device(inputs.device)
     _bgmv_expand_slice_kernel[grid](
         inputs,
         lora_b_weights,
vllm/lora/ops/triton_ops/bgmv_shrink.py (2 additions, 1 deletion)

@@ -12,7 +12,7 @@

 from vllm.utils import direct_register_custom_op

-from .utils import get_lora_op_configs
+from .utils import _set_cuda_device, get_lora_op_configs


 @triton.jit
@@ -124,6 +124,7 @@ def _bgmv_shrink(
         META["SPLIT_K"],
         batches,
     )
+    _set_cuda_device(inputs.device)
     _bgmv_shrink_kernel[grid](
         inputs,
         lora_a_weights,
vllm/lora/ops/triton_ops/sgmv_expand.py (2 additions, 1 deletion)

@@ -14,7 +14,7 @@

 from vllm.utils import direct_register_custom_op

-from .utils import _get_lora_b_ptr
+from .utils import _get_lora_b_ptr, _set_cuda_device


 @triton.jit
@@ -218,6 +218,7 @@ def _sgmv_expand(
         batches,
         len(lora_b_weights),
     )
+    _set_cuda_device(inputs.device)
     _sgmv_expand_kernel[grid](
         inputs,
         lora_ptr_tensor,
vllm/lora/ops/triton_ops/sgmv_shrink.py (2 additions, 1 deletion)

@@ -14,7 +14,7 @@

 from vllm.utils import direct_register_custom_op

-from .utils import _get_lora_a_ptr
+from .utils import _get_lora_a_ptr, _set_cuda_device


 @triton.jit
@@ -184,6 +184,7 @@ def _sgmv_shrink(
         SPLIT_K * len(lora_a_weights),
         batches,
     )
+    _set_cuda_device(inputs.device)
     _sgmv_shrink_kernel[grid](
         inputs,
         lora_ptr_tensor,
vllm/lora/ops/triton_ops/utils.py (9 additions, 0 deletions)

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import functools
+from functools import lru_cache
 from typing import Dict, List, Tuple

 import torch
@@ -50,6 +51,14 @@ def get_lora_op_configs(op_type: str, batch: int,
     return config


+@lru_cache
+def _set_cuda_device(device: torch.device):
+    """
+    Sets the current CUDA device.
+    """
+    torch.cuda.set_device(device)
+
+
 _LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
 _LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
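The @lru_cache wrapper means torch.cuda.set_device is only actually invoked the first time each distinct device is seen; later calls with the same device hit the cache and are effectively free. A small self-contained demonstration of that behavior (not part of the PR; the traced function stands in for the real helper):

from functools import lru_cache

calls = []

@lru_cache
def _set_device_traced(device: str) -> None:
    # Stand-in for torch.cuda.set_device(device).
    calls.append(device)

_set_device_traced("cuda:0")
_set_device_traced("cuda:0")  # cache hit: the body does not run again
_set_device_traced("cuda:1")
assert calls == ["cuda:0", "cuda:1"]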