
Commit d4d1a60

[Lora]Load tuned multi-lora kernel configs from json files (#26319)
Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>
Signed-off-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent db1764e commit d4d1a60

File tree: 4 files changed, +198 -16 lines changed
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Multi-LoRA Tuning

**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`. Without this, the shrink/expand kernels will use default configurations.

## Tuning Process

Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to [Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).

**Step 1**

Define the search space. An example search space:

```python
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]
```

**Step 2**

Get all hidden_state sizes and num_slices that the target model uses for a specific TP size.

For example, we can acquire this information by adding a few debug prints in [add_lora_linear](https://github.com/li2haipeng/vllm/blob/multi_lora_v01011/vllm/lora/punica_wrapper/punica_gpu.py#L192):

```python
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
print(f"num_slices: {len(output_slices)}")
for i in range(len(output_slices)):
    print(f"a{i} shape: {lora_a_stacked[i].shape}")
    print(f"b{i} shape: {lora_b_stacked[i].shape}")
print("y_shape", y.shape)
```
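
These shapes map directly onto the lookup keys used by `get_lora_op_configs` in `vllm/lora/ops/triton_ops/utils.py`: `m` is the flattened token count, while `k` and `n` swap roles between shrink and expand. A minimal sketch of that mapping (`lookup_dims` is a helper invented for this example, and the numbers are illustrative only):

```python
def lookup_dims(op_type: str, batch: int, hidden_size: int, rank: int) -> tuple[int, int, int]:
    """Map benchmarked shapes to the (m, k, n) keys used in the config JSON."""
    # Mirrors get_lora_op_configs: m is the batch (flattened token count);
    # shrink reduces hidden_size -> rank, expand goes rank -> hidden_size.
    m = batch
    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)
    return m, k, n

# Illustrative numbers: 256 flattened tokens, hidden size 4096, LoRA rank 16.
print(lookup_dims("shrink", batch=256, hidden_size=4096, rank=16))  # (256, 4096, 16)
print(lookup_dims("expand", batch=256, hidden_size=4096, rank=16))  # (256, 16, 4096)
```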

**Step 3**

Benchmark the shrink/expand kernel runtime under each kernel configuration generated from the pre-defined search space, performing a grid search to find the optimal configuration. vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py) can be used to search for configurations for different shapes.
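
As a rough illustration of the grid search (a sketch only; `benchmark_lora.py` is the actual tool, and `benchmark_config` below is a hypothetical stand-in for whatever timing harness you use):

```python
import itertools

# Search space from Step 1 (split_k is only consumed by the shrink kernel).
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]

def grid_search(benchmark_config):
    """Try every candidate config and keep the fastest one.

    `benchmark_config` is a user-supplied callable that runs the kernel with
    the given config and returns a latency in milliseconds (hypothetical).
    """
    best_cfg, best_ms = None, float("inf")
    for block_m, block_n, block_k, num_warps, num_stages, num_ctas, split_k in itertools.product(
        block_m_range, block_n_range, block_k_range,
        num_warps_range, num_stage_range, num_ctas_range, split_k_range,
    ):
        cfg = {
            "block_m": block_m, "block_n": block_n, "block_k": block_k,
            "num_warps": num_warps, "num_stages": num_stages,
            "num_ctas": num_ctas, "split_k": split_k,
        }
        ms = benchmark_config(cfg)
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg
```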

## Config Files

### File Name

For `shrink`, the config file is named `{gpu_name}_SHRINK.json`, e.g. `NVIDIA_H200_SHRINK.json`.

For `expand`, the config file is named `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.

The `gpu_name` can be detected automatically by calling `torch.cuda.get_device_name()`; spaces and dashes in the name are replaced with underscores.
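
A minimal sketch of the file name the loader will look for on the current GPU, following the same normalization as `load_lora_op_config` (requires a CUDA device; `expected_config_filename` is just a helper for this example):

```python
import torch

def expected_config_filename(op_type: str, add_inputs: bool | None = None) -> str:
    """Build the config file name that the loader expects for this GPU."""
    gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("-", "_")
    if op_type == "shrink":
        return f"{gpu_name}_SHRINK.json"
    return f"{gpu_name}_EXPAND_{str(add_inputs).upper()}.json"

# e.g. on an H200 this prints "NVIDIA_H200_SHRINK.json" and "NVIDIA_H200_EXPAND_TRUE.json"
print(expected_config_filename("shrink"))
print(expected_config_filename("expand", add_inputs=True))
```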
48+
49+
### Json Structure
50+
51+
Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]`
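
For illustration, one entry of a shrink config might look like the following after `json.load` (the numeric keys and leaf values are made up; only the nesting order and field names follow the lookup in `get_lora_op_configs`):

```python
# Illustrative shrink config: config_data[max_loras][num_slices][m][k][n] -> kernel config.
# All numeric keys are stored as strings because the file is JSON.
config_data = {
    "1": {                      # max_loras
        "2": {                  # num_slices
            "256": {            # m (flattened token count)
                "4096": {       # k (hidden_size for shrink)
                    "16": {     # n (rank for shrink)
                        "block_m": 32,
                        "block_n": 16,
                        "block_k": 256,
                        "split_k": 64,
                        "num_warps": 4,
                        "num_ctas": 1,
                        "num_stages": 2,
                    }
                }
            }
        }
    }
}
```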

vllm/lora/ops/triton_ops/lora_expand_op.py

Lines changed: 16 additions & 7 deletions
@@ -10,7 +10,7 @@
 import torch
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
@@ -201,12 +201,21 @@ def _lora_expand(
     NUM_SLICES = len(lora_b_weights)
 
     # Triton kernel configs.
-    BLOCK_M = 64
-    BLOCK_N = 128
-    BLOCK_K = 16
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
+    kernel_config = get_lora_op_configs(
+        op_type="expand",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=MAX_N,
+        rank=K,
+        num_slices=NUM_SLICES,
+        add_inputs=add_inputs,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_CTAS = kernel_config["num_ctas"]
+    NUM_STAGES = kernel_config["num_stages"]
 
     EVEN_K = K % BLOCK_K == 0  # type: ignore
 
vllm/lora/ops/triton_ops/lora_shrink_op.py

Lines changed: 16 additions & 9 deletions
@@ -10,7 +10,7 @@
 import torch
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
@@ -177,14 +177,21 @@ def _lora_shrink(
     MAX_LORAS = lora_ids.size(0)
 
     # Triton kernel configs
-    BLOCK_M = 32
-    BLOCK_N = 16
-    BLOCK_K = 256 if M < 128 else 32
-    SPLIT_K = 64 if M < 128 else 8
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
-
+    kernel_config = get_lora_op_configs(
+        "shrink",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=K,
+        rank=N,
+        num_slices=NUM_SLICES,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    SPLIT_K = kernel_config["split_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_STAGES = kernel_config["num_stages"]
+    NUM_CTAS = kernel_config["num_ctas"]
     EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore
 
     # TODO (varun): This grid formulation maximizes parallelization at the

vllm/lora/ops/triton_ops/utils.py

Lines changed: 115 additions & 0 deletions
@@ -1,8 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import functools
+import json
+from pathlib import Path
+from typing import Any
+
 import torch
 
+from vllm import envs
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
 _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
 
@@ -133,3 +143,108 @@ def _get_lora_b_ptr(
             MAX_N,
         )
     return _LORA_B_PTR_DICT.get(key)
+
+
+@functools.lru_cache
+def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
+    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
+    if user_defined_config_folder is not None:
+        gpu_name = torch.cuda.get_device_name()
+        gpu_name = gpu_name.replace(" ", "_")
+        gpu_name = gpu_name.replace("-", "_")
+
+        config_fname = None
+        if op_type == "shrink":
+            config_fname = f"{gpu_name}_{op_type.upper()}.json"
+        else:
+            assert op_type == "expand"
+            config_fname = (
+                f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
+            )
+
+        config_path = Path(f"{user_defined_config_folder}/{config_fname}")
+        if not config_path.exists():
+            logger.warning_once(f"No LoRA kernel configs found in {config_path}")
+            return None
+
+        # Load json
+        logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
+        with open(str(config_path)) as f:
+            config_data = json.load(f)
+    else:
+        config_data = None
+
+    return config_data
+
+
+@functools.lru_cache
+def get_lora_op_configs(
+    op_type: str,
+    max_loras: int,
+    batch: int,
+    hidden_size: int,
+    rank: int,
+    num_slices: int,
+    add_inputs: bool | None = None,
+) -> dict[str, int | None]:
+    assert op_type in ["shrink", "expand"]
+
+    # default config
+    default = {}
+    if op_type == "shrink":
+        default = {
+            "block_m": 32,
+            "block_n": 16,
+            "block_k": 256 if batch < 128 else 32,
+            "split_k": 64 if batch < 128 else 8,
+            "num_warps": 4,
+            "num_ctas": 1,
+            "num_stages": 2,
+            "max_nreg": None,
+        }
+    else:
+        default = {
+            "block_m": 64,
+            "block_n": 128,
+            "block_k": 16,
+            "num_warps": 4,
+            "num_ctas": 1,
+            "num_stages": 2,
+            "max_nreg": None,
+        }
+    m = batch
+
+    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)
+
+    config_data: Any
+    config_data = load_lora_op_config(op_type, add_inputs)
+    if not config_data:
+        logger.warning_once("Using default LoRA kernel configs")
+        return default
+
+    # config is structured as config_data[max_loras][num_slices][m][k][n] = {}
+    # slice by max_loras
+    config_data = (
+        config_data.get(str(max_loras))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
+    )
+    # slice by num_slices
+    config_data = config_data[str(num_slices)]
+    # slice by m
+    config_data = (
+        config_data.get(str(m))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
+    )
+    # slice by k
+    config_data = (
+        config_data.get(str(k))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
+    )
+    # slice by n
+    config_data = (
+        config_data.get(str(n))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
+    )
+
+    assert config_data is not None
+    return config_data
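
For reference, a minimal sketch of how the new lookup behaves when `VLLM_TUNED_CONFIG_FOLDER` is not set (shapes are illustrative; assumes an environment where this module imports cleanly):

```python
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs

# With VLLM_TUNED_CONFIG_FOLDER unset, the lookup logs a warning and returns
# the hard-coded defaults that the kernels previously used inline.
cfg = get_lora_op_configs(
    op_type="shrink",
    max_loras=1,
    batch=256,          # flattened token count (m)
    hidden_size=4096,   # k for shrink
    rank=16,            # n for shrink
    num_slices=2,
)
print(cfg["block_m"], cfg["split_k"])  # 32 and 8 with the defaults for batch >= 128
```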

0 commit comments
