
Commit 2a543d6

Add LoRA support for Mixtral (#2831)
* add mixtral lora support
* formatting
* fix incorrectly ported logic
* polish tests
* minor fixes and refactoring
* minor fixes
* formatting
* rename and remove redundant logic
* refactoring
* refactoring
* minor fix
* minor refactoring
* fix code smell
1 parent 317b29d commit 2a543d6

10 files changed: +251 −121 lines changed

tests/lora/conftest.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -121,6 +121,11 @@ def sql_lora_files():
     return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
 
 
+@pytest.fixture(scope="session")
+def mixtral_lora_files():
+    return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
```
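For reference, `scope="session"` makes pytest evaluate the fixture once per test run and hand every test the same cached download path. A minimal self-contained sketch of the same pattern (the fixture name and repo id here are hypothetical):

```python
import os

import pytest
from huggingface_hub import snapshot_download


@pytest.fixture(scope="session")
def demo_lora_files():
    # Runs once per pytest session: downloads the adapter repo (or reuses
    # the local Hugging Face cache) and returns the local directory path.
    return snapshot_download(repo_id="some-org/some-lora-adapter")


def test_adapter_config_present(demo_lora_files):
    # Every test receives the same already-downloaded path.
    assert os.path.exists(os.path.join(demo_lora_files, "adapter_config.json"))
```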

tests/lora/test_lora_manager.py

Lines changed: 47 additions & 35 deletions

```diff
@@ -11,25 +11,35 @@
                              RowParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
+from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, LoRAMapping)
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
 from vllm.model_executor.layers.linear import RowParallelLinear
 
+EMBEDDING_MODULES = {
+    "embed_tokens": "input_embeddings",
+    "lm_head": "output_embeddings",
+}
+
+EMBEDDING_PADDING_MODULES = ["lm_head"]
+
 
 def test_from_lora_tensors(sql_lora_files):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
-    lora_model = LoRAModel.from_lora_tensors(1,
-                                             8,
-                                             16,
-                                             tensors,
-                                             "cuda",
-                                             embeddings=new_embeddings)
+    lora_model = LoRAModel.from_lora_tensors(
+        1,
+        8,
+        16,
+        tensors,
+        "cuda",
+        embeddings=new_embeddings,
+        embedding_modules=EMBEDDING_MODULES,
+        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
     for module_name, lora in lora_model.loras.items():
         assert lora.module_name == module_name
         assert lora.rank == 8
```
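The import change above is the heart of this hunk: `EMBEDDING_MODULES` and `EMBEDDING_PADDING_MODULES` are no longer global constants in `vllm.lora.models`; callers now pass them into `LoRAModel.from_lora_tensors` explicitly, so each architecture (here, Mixtral) can declare its own embedding modules. A sketch of the assumed semantics of the suffix mapping, inferred from the test constants rather than the library internals:

```python
from typing import Optional

# Module-name suffix -> key of the extra-embeddings tensor that extends it
# (assumed semantics; mirrors the constants defined in the test above).
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}
EMBEDDING_PADDING_MODULES = ["lm_head"]  # modules whose weights get padded


def embeddings_tensor_key(module_name: str) -> Optional[str]:
    # Match on the trailing component so fully qualified names resolve too.
    for suffix, key in EMBEDDING_MODULES.items():
        if module_name.endswith(suffix):
            return key
    return None


assert embeddings_tensor_key("model.embed_tokens") == "input_embeddings"
assert embeddings_tensor_key("model.layers.0.mlp.gate_proj") is None
```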
```diff
@@ -90,14 +100,11 @@ def create_packed_lora(
 
 def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
-    manager = LoRAModelManager(model,
-                               1,
-                               1,
-                               1,
-                               LoRAConfig(max_lora_rank=8,
-                                          max_cpu_loras=8,
-                                          max_loras=8),
-                               lora_target_modules=["dense1", "layer1.dense2"])
+    model.supported_lora_modules = ["dense1", "layer1.dense2"]
+    model.packed_modules_mapping = {}
+    manager = LoRAModelManager(
+        model, 1, 1, 1,
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
     model = manager.model
 
     assert isinstance(model.get_submodule("dense1"),
```
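The same refactoring shows up here: the `lora_target_modules` constructor argument is gone, and `LoRAModelManager` instead reads `supported_lora_modules` and `packed_modules_mapping` off the model object, which the tests emulate by setting plain attributes on `dummy_model`. A hedged sketch of the per-model declaration this enables (class and module names are illustrative, not vLLM's):

```python
class MyModelForCausalLM:
    # Fused module -> the per-projection sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
    # Modules the LoRA manager is allowed to wrap for this architecture.
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
```

With the metadata on the model class, adding LoRA support for a new architecture such as Mixtral becomes a matter of declaring these attributes instead of threading module lists through every manager constructor.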
```diff
@@ -111,16 +118,14 @@ def test_replace_submodules(dist_init, dummy_model):
 
 def test_lora_model_manager(dist_init, dummy_model):
     model = dummy_model
+    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
+    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
     model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
     model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
     manager = LoRAModelManager(
-        model,
-        2,
-        2,
-        2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
-        lora_target_modules=["dense1", "dense2", "lm_head"])
+        model, 2, 2, 2,
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_lora(model_lora1)
     assert manager.activate_lora(1)
@@ -159,16 +164,14 @@ def test_lora_model_manager(dist_init, dummy_model):
 
 def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     model = dummy_model
+    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
+    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
     model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
     model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
     manager = LRUCacheLoRAModelManager(
-        model,
-        2,
-        2,
-        2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
-        lora_target_modules=["dense1", "dense2", "lm_head"])
+        model, 2, 2, 2,
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_lora(model_lora1)
     assert manager.activate_lora(1)
@@ -212,14 +215,15 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
     # This tests just the LRU cache functionality, everything else is
     # tested in test_lora_model_manager
     model = dummy_model
+    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
+    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
     model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
     model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
     model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
     manager = LRUCacheLoRAModelManager(
         model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
-        ["dense1", "dense2", "lm_head"])
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
 
     assert all(x is None for x in manager.lora_index_to_id)
 
```
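The LRU test pins down the eviction policy rather than any vLLM internals: with room for two active adapters, activating a third evicts the least recently used one. A generic sketch of that policy (not vLLM's implementation):

```python
from collections import OrderedDict


class LRUAdapterCache:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.active = OrderedDict()  # adapter_id -> adapter payload

    def activate(self, adapter_id, payload=None):
        # Re-activating an adapter just marks it most recently used.
        if adapter_id in self.active:
            self.active.move_to_end(adapter_id)
            return None
        evicted = None
        if len(self.active) >= self.capacity:
            # Drop the least recently used entry to make room.
            evicted, _ = self.active.popitem(last=False)
        self.active[adapter_id] = payload
        return evicted


cache = LRUAdapterCache(capacity=2)
assert cache.activate(1) is None
assert cache.activate(2) is None
assert cache.activate(3) == 1  # adapter 1 was the least recently used
```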
```diff
@@ -289,8 +293,9 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                        sql_lora_files):
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_lora_manager = LRUCacheWorkerLoRAManager(
-        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
-        torch.device("cuda"))
+        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
 
     mapping = LoRAMapping([], [])
```
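Both worker-manager hunks replace `config.vocab_size` with `unpadded_vocab_size - lora_extra_vocab_size`, i.e. the base vocabulary before the slots reserved for LoRA-added tokens. A worked example with illustrative numbers (256 matches the usual default for `lora_extra_vocab_size`, but treat it as an assumption here):

```python
base_vocab_size = 32000                    # e.g. the Llama-2 tokenizer
lora_extra_vocab_size = 256                # slots reserved for LoRA tokens
unpadded_vocab_size = base_vocab_size + lora_extra_vocab_size  # 32256

# The managers above are handed the base size back out:
assert unpadded_vocab_size - lora_extra_vocab_size == base_vocab_size
```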
```diff
@@ -362,8 +367,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
     # Should remove every LoRA not specified in the request.
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_lora_manager = WorkerLoRAManager(
-        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
-        torch.device("cuda"))
+        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
 
     mapping = LoRAMapping([], [])
```
```diff
@@ -428,6 +434,13 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
 
 def test_packed_loras(dist_init, dummy_model_gate_up):
     model = dummy_model_gate_up
+    model.supported_lora_modules = ["gate_up_proj"]
+    model.packed_modules_mapping = {
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
     model_lora = create_packed_lora(
         1,
         model,
```
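The new `packed_modules_mapping` entry tells the manager that separately trained `gate_proj` and `up_proj` adapter weights must be fused into a single adapter for the model's combined `gate_up_proj` layer. A hedged sketch of what that packing means for a fused output projection (illustrative shapes only; vLLM's `PackedLoRALayerWeights` handles this internally):

```python
import torch

rank, hidden, inter = 8, 16, 32
# Independent low-rank adapters for gate_proj and up_proj: delta = B @ A.
a_gate, b_gate = torch.randn(rank, hidden), torch.randn(inter, rank)
a_up, b_up = torch.randn(rank, hidden), torch.randn(inter, rank)

# gate_up_proj concatenates gate and up outputs, so the packed adapter
# keeps one (A, B) pair per slice and applies each to its output slice.
packed = [(a_gate, b_gate), (a_up, b_up)]


def packed_delta(x: torch.Tensor) -> torch.Tensor:
    # Per-slice deltas, concatenated along the output dimension in the
    # same gate-then-up order the fused layer uses.
    return torch.cat([b @ (a @ x) for a, b in packed], dim=0)


assert packed_delta(torch.randn(hidden)).shape == (2 * inter,)
```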
```diff
@@ -443,8 +456,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
 
     manager = LoRAModelManager(
         model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
-        ["gate_up_proj"])
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
     model = manager.model
 
     assert isinstance(model.get_submodule("gate_up_proj"),
```

tests/lora/test_mixtral.py

Lines changed: 53 additions & 0 deletions

```diff
@@ -0,0 +1,53 @@
+import pytest
+import torch
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+
+def do_sample(llm, lora_path: str, lora_id: int):
+    prompts = [
+        "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
+        "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
+        "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
+    ]
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+@pytest.mark.parametrize("tp_size", [4])
+def test_mixtral_lora(mixtral_lora_files, tp_size):
+    if torch.cuda.device_count() < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+
+    llm = vllm.LLM(MODEL_PATH,
+                   enable_lora=True,
+                   max_num_seqs=16,
+                   max_loras=4,
+                   tensor_parallel_size=tp_size,
+                   worker_use_ray=True)
+
+    expected_lora_output = [
+        "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
+        "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
+        "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])",
+    ]
+
+    assert do_sample(llm, mixtral_lora_files,
+                     lora_id=1) == expected_lora_output
+    assert do_sample(llm, mixtral_lora_files,
+                     lora_id=2) == expected_lora_output
```
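Stripped of its assertions, the new test doubles as a usage example for Mixtral with LoRA. A minimal offline-inference sketch distilled from it (resource-heavy: it loads Mixtral-8x7B across four GPUs; the adapter path and name are placeholders):

```python
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("mistralai/Mixtral-8x7B-Instruct-v0.1",
               enable_lora=True,
               tensor_parallel_size=4,
               worker_use_ray=True)

outputs = llm.generate(
    ["Here is the target sentence: ..."],
    vllm.SamplingParams(temperature=0, max_tokens=256),
    # LoRARequest(lora_name, lora_int_id, lora_local_path); a nonzero id
    # selects the adapter, exactly as in do_sample above.
    lora_request=LoRARequest("my-adapter", 1, "/path/to/mixtral-lora"))
print(outputs[0].outputs[0].text)
```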
