
Commit 02b8ec7

Akshat-Tripathi authored and yaochengji committed
[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend (vllm-project#14238)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 60993f4 commit 02b8ec7

File tree

19 files changed: +929 −46 lines

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
     && echo TEST_12 \
     && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
+    # Disable the TPU LoRA tests until the feature is activated
+    # && echo TEST_13 \
+    # && pytest -s -v /workspace/vllm/tests/tpu/lora/" \


 # TODO: This test fails because it uses RANDOM_SEED sampling

tests/lora/conftest.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def dist_init():
     temp_file = tempfile.mkstemp()[1]

     backend = "nccl"
-    if current_platform.is_cpu():
+    if current_platform.is_cpu() or current_platform.is_tpu():
         backend = "gloo"

     init_distributed_environment(world_size=1,

tests/tpu/lora/__init__.py

Whitespace-only changes.

tests/tpu/lora/test_lora.py

Lines changed: 124 additions & 0 deletions (new file)

# SPDX-License-Identifier: Apache-2.0
import pytest

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """
    Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
    for all tests in this file
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        yield


def setup_vllm(num_loras: int) -> vllm.LLM:
    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                    num_scheduler_steps=1,
                    max_model_len=256,
                    max_seq_len_to_capture=256,
                    max_num_seqs=8,
                    enable_lora=True,
                    max_loras=num_loras,
                    max_lora_rank=8)


def test_single_lora():
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=1.
    """

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=256,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    answer = output.strip()[0]

    assert answer.isdigit()
    assert int(answer) == 1


def test_lora_hotswapping():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """

    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text
        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1


def test_multi_lora():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(4)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text

        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(output.strip()[0]) == i + 1
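
Note: the adapters themselves are not part of this commit. For context, the "standard huggingface peft training script" referenced in the comments above might look roughly like the sketch below. The LoraConfig values, tokenization details, and output path are illustrative assumptions rather than code from this PR; only the prompt/answer pair, the 100 steps, and the batch size of 100 come from the comments, and r=8 is chosen merely to stay within the test's max_lora_rank=8.

# Hypothetical sketch of the PEFT LoRA training script described above.
# Hyperparameters, tokenization details and paths are assumptions.
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments)

ANSWER = "1"  # "x" in the adapter naming scheme

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16,
                                         task_type="CAUSAL_LM"))


def tokenize(example):
    tokens = tokenizer(example["text"], truncation=True, max_length=32,
                       padding="max_length")
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens


# Every training sample is the same prompt/answer pair, as described above.
dataset = Dataset.from_dict({"text": [f"What is 1+1? \n{ANSWER}"] * 100})
dataset = dataset.map(tokenize, remove_columns=["text"])

Trainer(model=model,
        args=TrainingArguments(output_dir="lora_out",
                               per_device_train_batch_size=100,
                               max_steps=100),
        train_dataset=dataset).train()
model.save_pretrained("Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")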

tests/tpu/lora/test_pallas_kernels.py

Lines changed: 73 additions & 0 deletions (new file)

# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

# Required to register the custom ops
import vllm.lora.ops.xla_ops.pallas  # noqa # pylint: disable=unused-import

N_TOKENS = [16, 1024, 4096]
HIDDEN_SIZES = [1024, 2048, 4096]

DTYPES = [torch.bfloat16]
NUM_LORA = [1, 4, 16]
RANKS = [32, 256, 512]


def generate_test_data(T, D, L, N, seed, dtype=torch.float32):
    """
    Inputs: (All integers)
    T: Total number of tokens
    D: Input dim
    L: LoRA Dim
    N: N LoRAs

    Outputs:
    inputs: torch.Tensor - shape (T, D)
    loras: torch.Tensor - shape (N, 1, L, D)
    idxs: torch.Tensor - shape (T, ) - all values must be in [0, N)

    ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T
    """
    torch.manual_seed(seed)

    inputs = torch.randn((T, D), device="xla", dtype=dtype)
    loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype)
    idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla")

    ref_output = ref_bgmv(inputs, loras, idxs)
    return inputs, loras, idxs, ref_output


def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor):
    selected_loras = loras[idxs]
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(axis=1)

    batch_size, output_size, input_size = selected_loras.shape
    return (selected_loras @ inputs.reshape(
        (batch_size, input_size, 1))).reshape((batch_size, output_size))


# Parameterize tests with various shapes and dtypes
@pytest.mark.parametrize("T", N_TOKENS)
@pytest.mark.parametrize("D", HIDDEN_SIZES)
@pytest.mark.parametrize("L", RANKS)
@pytest.mark.parametrize("N", NUM_LORA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", [0])
def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed):
    if op_type == "expand":
        D, L = L, D

    inputs, loras, idxs, ref_output = generate_test_data(
        T, D, L, N, seed, dtype)

    # Run bgmv
    output = torch.ops.xla.bgmv(inputs, loras, idxs)

    # Make sure we have no NaNs
    assert not torch.any(torch.isnan(output))

    # Compare with reference output
    assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2)
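
Note: as a plain-PyTorch illustration of the semantics that ref_bgmv checks (token t is multiplied by the LoRA matrix selected by idxs[t]), here is a minimal CPU-only sketch. Shapes and values are made up for illustration; no XLA device or Pallas kernel is involved.

import torch

# Minimal CPU illustration of the bgmv (batched grouped matrix-vector)
# semantics verified above: token t is multiplied by loras[idxs[t]].
T, D, L, N = 4, 8, 3, 2               # tokens, input dim, LoRA dim, num LoRAs
inputs = torch.randn(T, D)            # (T, D)
loras = torch.randn(N, 1, L, D)       # (N, 1, L, D), as in the test
idxs = torch.tensor([0, 1, 1, 0])     # per-token adapter index

selected = loras[idxs].squeeze(1)     # (T, L, D)
out = torch.einsum("tld,td->tl", selected, inputs)  # (T, L)

# Equivalent per-token loop:
expected = torch.stack(
    [loras[idxs[t], 0] @ inputs[t] for t in range(T)])
assert torch.allclose(out, expected, atol=1e-6)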

vllm/config.py

Lines changed: 3 additions & 2 deletions

@@ -2694,8 +2694,8 @@ class LoRAConfig:
     lora_extra_vocab_size: int = 256
     """Maximum size of extra vocabulary that can be present in a LoRA adapter
     (added to the base model vocabulary)."""
-    # This is a constant.
-    lora_vocab_padding_size: ClassVar[int] = 256
+    lora_vocab_padding_size: ClassVar[int] = current_platform\
+        .get_lora_vocab_padding_size()
     long_lora_scaling_factors: Optional[tuple[float, ...]] = None
     """Specify multiple scaling factors (which can be different from base model
     scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters

@@ -2723,6 +2723,7 @@ def compute_hash(self) -> str:
         factors.append(self.fully_sharded_loras)
         factors.append(self.lora_dtype)
         factors.append(self.lora_extra_vocab_size)
+        factors.append(self.lora_vocab_padding_size)
         factors.append(self.long_lora_scaling_factors)
         factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode(),
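
Note: get_lora_vocab_padding_size() is a new platform hook introduced by this commit, and its per-platform definitions are not shown in this excerpt. A rough sketch of the shape such a hook could take follows; the default of 256 matches the previously hard-coded constant, while the TPU override value is an assumption for illustration only.

# Hypothetical sketch of the platform hook referenced above; the TPU value
# is assumed, not taken from this diff.
class Platform:

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        # Previous behaviour: a fixed LoRA vocab padding of 256.
        return 256


class TpuPlatform(Platform):

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        # TPUs may prefer a different (e.g. smaller) padding; value assumed.
        return 1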

vllm/lora/fully_sharded_layers.py

Lines changed: 29 additions & 10 deletions

@@ -16,6 +16,7 @@
                               MergedQKVParallelLinearWithLoRA,
                               QKVParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
+from vllm.platforms import current_platform

 if TYPE_CHECKING:
     pass

@@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
         device=x.device,
     )

-    layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
+    shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink(
+        buffers, x, layer.lora_a_stacked, 1.0)
+
+    if not current_platform.can_update_inplace():
+        buffers = shrunk_buffers
+
     buffers = tensor_model_parallel_all_gather(buffers)
-    layer.punica_wrapper.add_expand(output,
-                                    buffers,
-                                    layer.lora_b_stacked,
-                                    layer.lora_bias_stacked,
-                                    layer.output_slices,
-                                    offset_start=0,
-                                    add_input=True)
+
+    lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand(
+        output,
+        buffers,
+        layer.lora_b_stacked,
+        layer.lora_bias_stacked,
+        layer.output_slices,
+        offset_start=0,
+        add_input=True)
+
+    if not current_platform.can_update_inplace():
+        output = lora_output

     output = output.view(*out_orig_shape)
     # now have column partitioned and packed output

@@ -292,7 +303,11 @@ def apply(self,
             device=x.device,
         )

-        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
+        shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink(
+            buffer, x, self.lora_a_stacked, 1.0)
+        if not current_platform.can_update_inplace():
+            buffer = shrunk_buffer
+
         buffer = tensor_model_parallel_all_reduce(buffer)

         # following S-LoRA, allows the fusing of all_gather and all_reduce

@@ -304,7 +319,7 @@ def apply(self,
         # NOTE offset are based on the rank.
         shard_size = self.lora_b_stacked[0].shape[2]
         offset_start = self.tp_rank * shard_size
-        self.punica_wrapper.add_expand(
+        lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand(
             output,
             buffer,
             self.lora_b_stacked,

@@ -313,6 +328,10 @@ def apply(self,
             offset_start=offset_start,
             add_input=True,
         )
+
+        if not current_platform.can_update_inplace():
+            output = lora_output
+
         output = output.view(*out_orig_shape)
         return output
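
Note: the recurring pattern in this file (and in layers.py below) is that the punica wrapper may return a fresh tensor instead of mutating its output argument, since XLA tensors cannot be updated in place; when current_platform.can_update_inplace() is false, the caller rebinds the local name to the returned tensor. A minimal self-contained sketch of that calling convention follows, using simplified stand-in names rather than the real vLLM classes.

from typing import Optional

import torch


class DemoPunicaWrapper:
    """Simplified stand-in for the punica wrapper calling convention."""

    def __init__(self, inplace: bool):
        self.inplace = inplace

    def add_shrink(self, buffer: torch.Tensor, x: torch.Tensor,
                   lora_a: torch.Tensor,
                   scale: float) -> Optional[torch.Tensor]:
        result = scale * (x @ lora_a.T)
        if self.inplace:
            buffer.copy_(result)  # CUDA-style: mutate in place, return None
            return None
        return result             # XLA-style: return a new tensor


def apply_shrink(wrapper, x, lora_a, can_update_inplace: bool):
    buffer = torch.zeros(x.shape[0], lora_a.shape[0])
    shrunk: Optional[torch.Tensor] = wrapper.add_shrink(buffer, x, lora_a, 1.0)
    if not can_update_inplace:
        buffer = shrunk           # rebind, since `buffer` was not mutated
    return buffer


x = torch.randn(4, 16)
lora_a = torch.randn(8, 16)
out_gpu_style = apply_shrink(DemoPunicaWrapper(True), x, lora_a, True)
out_tpu_style = apply_shrink(DemoPunicaWrapper(False), x, lora_a, False)
assert torch.allclose(out_gpu_style, out_tpu_style)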

vllm/lora/layers.py

Lines changed: 37 additions & 16 deletions

@@ -261,10 +261,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 full_lora_a_embeddings.shape[1],
                 -1,
             )
-        self.punica_wrapper.add_lora_embedding(full_output,
-                                               full_lora_a_embeddings,
-                                               self.lora_b_stacked,
-                                               add_input=True)
+
+        lora_output: Optional[
+            torch.Tensor] = self.punica_wrapper.add_lora_embedding(
+                full_output,
+                full_lora_a_embeddings,
+                self.lora_b_stacked,
+                add_input=True)
+
+        if not current_platform.can_update_inplace():
+            full_output = lora_output
+
         return full_output.view_as(full_output_org)

     @classmethod

@@ -410,10 +417,13 @@ def apply(self,
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)

-        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
-                                            self.lora_b_stacked,
-                                            self.lora_bias_stacked, 1.0,
-                                            self.output_slices)
+        lora_output: Optional[
+            torch.Tensor] = self.punica_wrapper.add_lora_linear(
+                output, x, self.lora_a_stacked, self.lora_b_stacked,
+                self.lora_bias_stacked, 1.0, self.output_slices)
+        if not current_platform.can_update_inplace():
+            output = lora_output
+
         return output

     @property

@@ -1133,15 +1143,23 @@ def _get_logits(
         torch.matmul(self.embeddings_tensors,
                      hidden_states.T,
                      out=lora_logits[:-1])
-        lora_logits[-1] = float("-inf")
+
+        neg_inf, pos_inf = current_platform.get_infinity_values(
+            lora_logits.dtype)
+
+        lora_logits[-1] = neg_inf
         lora_logits = lora_logits.mT
         indices_padded = self.punica_wrapper.sampler_indices_padded
+
+        if current_platform.is_tpu():
+            indices_padded = indices_padded[:logits.size(0)]
+
         lora_logits = (lora_logits.reshape(
             lora_logits.shape[0] * lora_logits.shape[1],
             lora_logits.shape[2],
-        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
-                                                      posinf=float("inf"),
-                                                      neginf=float("-inf")))
+        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
+                                                      posinf=pos_inf,
+                                                      neginf=neg_inf))

         # HPU needs special handling to prune out dummy samples.
         if current_platform.is_hpu():

@@ -1151,10 +1169,13 @@ def _get_logits(
             self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
             lora_logits.shape[1]] = lora_logits

-        # LogitsProcessorWithLoRA always using bgmv
-        self.punica_wrapper.add_lora_logits(logits, hidden_states,
-                                            self.lora_a_stacked,
-                                            self.lora_b_stacked, 1.0)
+        lora_output: Optional[
+            torch.Tensor] = self.punica_wrapper.add_lora_logits(
+                logits, hidden_states, self.lora_a_stacked,
+                self.lora_b_stacked, 1.0)
+
+        if not current_platform.can_update_inplace():
+            logits = lora_output

         # Remove paddings in vocab (if any).
         logits = logits[:, :self.base_layer.vocab_size]
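
Note: get_infinity_values() is another new platform hook used above; the idea is that backends which handle IEEE infinities poorly can substitute finite sentinels for a given dtype. A rough sketch of the concept follows; the finite-extremes fallback shown for the TPU case is an assumption for illustration, not the hook's confirmed implementation.

import torch


# Sketch of a hook returning the (neg, pos) "infinity" sentinels per dtype;
# the TPU-style finite fallback below is an assumed example.
def get_infinity_values_default(dtype: torch.dtype) -> tuple[float, float]:
    return float("-inf"), float("inf")


def get_infinity_values_tpu(dtype: torch.dtype) -> tuple[float, float]:
    # Use the dtype's finite extremes instead of IEEE infinities.
    info = torch.finfo(dtype)
    return info.min, info.max


neg_inf, pos_inf = get_infinity_values_tpu(torch.bfloat16)
logits = torch.tensor([float("nan"), 1.0], dtype=torch.bfloat16)
print(logits.nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf))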
