
Commit d380f11

jeejeelee authored and xuebwang-amd committed
[Core] Optimize LoRA weight loading (vllm-project#25403)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 6ed283f commit d380f11

File tree

10 files changed, +83 -83 lines changed

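Note: the hunks below all make one mechanical change: LoRA A matrices are now stored as (rank, input_dim) and B matrices as (output_dim, rank), the same orientation as the stacked GPU buffers, so set_lora can copy them without a transpose. A minimal sketch of the two layouts (illustrative only, with made-up sizes; not part of the diff):

import torch

rank, input_dim, output_dim = 8, 4096, 11008
x = torch.randn(2, input_dim)

# Old layout: lora_a is (input_dim, rank), lora_b is (rank, output_dim).
old_a = torch.randn(input_dim, rank)
old_b = torch.randn(rank, output_dim)
y_old = x @ old_a @ old_b

# New layout: lora_a is (rank, input_dim), lora_b is (output_dim, rank),
# so the reference math gains a .T on both factors (as in the updated tests).
new_a = old_a.T.contiguous()
new_b = old_b.T.contiguous()
y_new = x @ new_a.T @ new_b.T

assert torch.allclose(y_old, y_new, atol=1e-5)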

tests/lora/test_layers.py

Lines changed: 14 additions & 12 deletions

@@ -164,8 +164,8 @@ def populate_loras(
                     weight=layer_weights,
                     generate_embeddings_tensor=generate_embeddings_tensor,
                 )
-                sublora.lora_b = sublora.lora_b[:, (sublora_len *
-                                                    i):(sublora_len * (i + 1))]
+                sublora.lora_b = sublora.lora_b[(sublora_len *
+                                                 i):(sublora_len * (i + 1)), :]
                 sublora.optimize()
                 subloras.append(sublora)

@@ -304,9 +304,9 @@ def create_random_embedding_layer():
             result = embedding(input_)
             after_a = F.embedding(
                 input_,
-                lora.lora_a,
+                lora.lora_a.T,
             )
-            result += (after_a @ lora.lora_b)
+            result += (after_a @ lora.lora_b.T)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)

@@ -445,9 +445,9 @@ def create_random_embedding_layer():
             result = expanded_embedding(input_)
             after_a = F.embedding(
                 original_input_,
-                lora.lora_a,
+                lora.lora_a.T,
             )
-            result += (after_a @ lora.lora_b)
+            result += (after_a @ lora.lora_b.T)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)

@@ -575,7 +575,7 @@ def _pretest():
                                       lm_head=linear,
                                       embedding_bias=None)
             result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
         logits_processor.org_vocab_size = vocab_size

@@ -692,9 +692,10 @@ def create_random_linear_replicated_layer():

         expected_results: list[torch.Tensor] = []
         for input_, lora_id in zip(inputs, prompt_mapping):
+
             lora = lora_dict[lora_id]
             result = linear(input_)[0]
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)

@@ -817,7 +818,7 @@ def create_random_linear_parallel_layer():
         for input_, lora_id in zip(inputs, prompt_mapping):
             lora = lora_dict[lora_id]
             result = linear(input_)[0]
-            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)

@@ -965,9 +966,10 @@ class FakeConfig:
             result = linear(input_)[0]
             subloras = sublora_dict[lora_id]
             for i, sublora in enumerate(subloras):
-                result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
-                       (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
-                                    sublora.scaling)
+                result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
+                       (i + 1)] += (
+                           input_ @ sublora.lora_a.T @ sublora.lora_b.T *
+                           sublora.scaling)
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
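In the packed-module test above, each sub-LoRA's B matrix is now sliced along its output dimension (dim 0) rather than dim 1. A minimal sketch of that slicing, with made-up sizes (not from the commit):

import torch

rank, sublora_len, n_slices = 8, 128, 3
lora_b = torch.randn(n_slices * sublora_len, rank)  # new (output_dim, rank) layout

subloras = [
    lora_b[sublora_len * i:sublora_len * (i + 1), :] for i in range(n_slices)
]
assert all(s.shape == (sublora_len, rank) for s in subloras)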

tests/lora/test_lora_manager.py

Lines changed: 6 additions & 6 deletions

@@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
         assert lora.lora_b is not None
         assert lora.lora_a.device == torch.device(device)
         assert lora.lora_b.device == torch.device(device)
-        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
+        assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
                 ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
-        assert lora.lora_a.shape[1] == 8
+        assert lora.lora_a.shape[0] == 8
         embeddings_module = next(
             (k for k in EMBEDDING_MODULES if k in module_name), None)
         if embeddings_module:

@@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
             name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device=device),
-            torch.rand([8, w.shape[0]], device=device),
+            torch.rand([8, w.shape[1]], device=device),
+            torch.rand([w.shape[0], 8], device=device),
         )
     return LoRAModel(lora_id, 8, loras)

@@ -109,8 +109,8 @@ def create_packed_lora(
             replaced_module_name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device=device),
-            torch.rand([8, w.shape[0] // len(replaced_module_names)],
+            torch.rand([8, w.shape[1]], device=device),
+            torch.rand([w.shape[0] // len(replaced_module_names), 8],
                        device=device),
         )
     return LoRAModel(lora_id, 8, loras)

tests/lora/utils.py

Lines changed: 4 additions & 4 deletions

@@ -36,10 +36,10 @@ def init_random_lora(
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([weight.shape[1], rank],
+            lora_a=torch.rand([rank, weight.shape[1]],
                               dtype=weight.dtype,
                               device=self._device),
-            lora_b=torch.rand([rank, weight.shape[0]],
+            lora_b=torch.rand([weight.shape[0], rank],
                               dtype=weight.dtype,
                               device=self._device),
         )

@@ -67,8 +67,8 @@ def init_lora(
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([input_dim, rank], device="cuda"),
-            lora_b=torch.rand([rank, output_dim], device="cuda"),
+            lora_a=torch.rand([rank, input_dim], device="cuda"),
+            lora_b=torch.rand([output_dim, input_dim], device="cuda"),
             embeddings_tensor=embeddings_tensor,
         )
         self.set_module_lora(module_name, lora)

vllm/lora/layers/base_linear.py

Lines changed: 5 additions & 5 deletions

@@ -121,18 +121,18 @@ def set_lora(
             lora_bias = self.slice_bias(lora_bias)

         self.lora_a_stacked[0][index,
-                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-                                   lora_a.T, non_blocking=True)
+                               0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
+                                   lora_a, non_blocking=True)
         self.lora_b_stacked[0][index,
-                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                   lora_b.T, non_blocking=True)
+                               0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
+                                   lora_b, non_blocking=True)
         if lora_bias is not None:

             self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                           self.lora_bias_stacked)
             assert len(self.lora_bias_stacked)
             self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
-                lora_bias.T, non_blocking=True)
+                lora_bias, non_blocking=True)

     def apply(self,
               x: torch.Tensor,
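With both factors already in the buffer orientation, the copies above need no .T. A rough sketch of the pattern; the stacked-buffer shapes are assumptions consistent with the indexing in this hunk, not taken verbatim from vLLM:

import torch

max_loras, max_rank, input_size, output_size = 4, 16, 64, 32
lora_a_stacked = torch.zeros(max_loras, 1, max_rank, input_size)
lora_b_stacked = torch.zeros(max_loras, 1, output_size, max_rank)

index, rank = 0, 8
lora_a = torch.randn(rank, input_size)    # new layout: (rank, input_size)
lora_b = torch.randn(output_size, rank)   # new layout: (output_size, rank)

# Direct copies: weights and buffers now share a layout.
lora_a_stacked[index, 0, :lora_a.shape[0], :lora_a.shape[1]].copy_(lora_a)
lora_b_stacked[index, 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(lora_b)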

vllm/lora/layers/column_parallel_linear.py

Lines changed: 33 additions & 34 deletions

@@ -99,21 +99,21 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         if self.is_merged_col_linear:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size // 2
-            offset = lora_b.shape[-1] // 2
+            offset = lora_b.shape[0] // 2

-            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
-                                 shard_size]
-            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
-                                  (tp_rank + 1) * shard_size]
-            lora_b = torch.cat([left_weight, right_weight], dim=1)
+            left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
+                                 shard_size, :]
+            right_weight = lora_b[offset + tp_rank * shard_size:offset +
+                                  (tp_rank + 1) * shard_size, :]
+            lora_b = torch.cat([left_weight, right_weight], dim=0)
         # Applicable to cases where the base_layer is
         # ColumnParallelLinear.
         else:
             tensor_model_parallel_rank = get_tensor_model_parallel_rank()
             shard_size = self.output_size
             start_idx = tensor_model_parallel_rank * shard_size
             end_idx = (tensor_model_parallel_rank + 1) * shard_size
-            lora_b = lora_b[:, start_idx:end_idx]
+            lora_b = lora_b[start_idx:end_idx, :]
         return lora_b

     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:

@@ -251,9 +251,8 @@ def slice_lora_b(
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (lora_b_i := lora_b[i]) is not None:
-                sliced_lora_b[i] = lora_b_i[:,
-                                            shard_size * shard_id:shard_size *
-                                            (shard_id + 1)]
+                sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
+                                            (shard_id + 1), :]
         return sliced_lora_b

     def slice_bias(

@@ -285,12 +284,12 @@ def set_lora(
         for i in range(self.n_slices):
             if (lora_a_i := lora_a[i]) is not None:
                 self.lora_a_stacked[i][
-                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
-                        lora_a_i.T, non_blocking=True)
+                    index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
+                        lora_a_i, non_blocking=True)
             if (lora_b_i := lora_b[i]) is not None:
                 self.lora_b_stacked[i][
-                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
-                        lora_b_i.T, non_blocking=True)
+                    index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
+                        lora_b_i, non_blocking=True)

         if lora_bias is not None:
             self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],

@@ -299,7 +298,7 @@ def set_lora(
                 if (lora_bias_i := lora_bias[i]) is not None:
                     self.lora_bias_stacked[i][index,
                                               0, :lora_bias_i.shape[0]].copy_(
-                                                  lora_bias_i.T,
+                                                  lora_bias_i,
                                                   non_blocking=True)

     @classmethod

@@ -345,18 +344,18 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
         self.q_shard_id = tp_rank
         self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
-        lora_b_q = lora_b[:, self.q_proj_shard_size *
+        lora_b_q = lora_b[self.q_proj_shard_size *
                           self.q_shard_id:self.q_proj_shard_size *
-                          (self.q_shard_id + 1)]
+                          (self.q_shard_id + 1), :]
         k_offset = self.q_proj_total_size
-        lora_b_k = lora_b[:, k_offset +
+        lora_b_k = lora_b[k_offset +
                           self.kv_proj_shard_size * self.kv_shard_id:k_offset +
-                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
         v_offset = k_offset + self.kv_proj_total_size
-        lora_b_v = lora_b[:, v_offset +
+        lora_b_v = lora_b[v_offset +
                           self.kv_proj_shard_size * self.kv_shard_id:v_offset +
-                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+                          self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
+        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
         return lora_b

     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:

@@ -465,7 +464,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
         start_idx = tp_rank * shard_size
-        lora_a = lora_a[:, start_idx:start_idx + shard_size]
+        lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a

     def apply(self,

@@ -508,10 +507,10 @@ def slice_lora_a(
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[0][:, output_start_idx:output_start_idx +
-                      output_shard_size] if lora_a[0] is not None else None,
-            lora_a[1][:, output_start_idx:output_start_idx +
-                      output_shard_size] if lora_a[1] is not None else None,
+            lora_a[0][output_start_idx:output_start_idx +
+                      output_shard_size, :] if lora_a[0] is not None else None,
+            lora_a[1][output_start_idx:output_start_idx +
+                      output_shard_size, :] if lora_a[1] is not None else None,
         ]
         return lora_a

@@ -551,7 +550,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.lora_a_stacked[0].shape[2]
         start_idx = tp_rank * shard_size
-        lora_a = lora_a[:, start_idx:start_idx + shard_size]
+        lora_a = lora_a[start_idx:start_idx + shard_size, :]
         return lora_a

     def apply(self,

@@ -589,12 +588,12 @@ def slice_lora_a(
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[0][:, start_idx[0]:start_idx[0] +
-                      shard_size[0]] if lora_a[0] is not None else None,
-            lora_a[1][:, start_idx[1]:start_idx[1] +
-                      shard_size[1]] if lora_a[1] is not None else None,
-            lora_a[2][:, start_idx[2]:start_idx[2] +
-                      shard_size[2]] if lora_a[2] is not None else None,
+            lora_a[0][start_idx[0]:start_idx[0] +
+                      shard_size[0], :] if lora_a[0] is not None else None,
+            lora_a[1][start_idx[1]:start_idx[1] +
+                      shard_size[1], :] if lora_a[1] is not None else None,
+            lora_a[2][start_idx[2]:start_idx[2] +
+                      shard_size[2], :] if lora_a[2] is not None else None,
         ]
         return lora_a
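Because lora_b is now (output_size, rank), tensor-parallel slicing in this file takes contiguous rows per rank instead of columns. A small sketch with illustrative sizes (not from vLLM):

import torch

output_size, rank, tp_size = 1024, 8, 4
lora_b = torch.randn(output_size, rank)
shard_size = output_size // tp_size

shards = [
    lora_b[r * shard_size:(r + 1) * shard_size, :] for r in range(tp_size)
]
assert torch.equal(torch.cat(shards, dim=0), lora_b)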

vllm/lora/layers/logits_processor.py

Lines changed: 4 additions & 4 deletions

@@ -140,11 +140,11 @@ def set_lora(
     ):
         self.reset_lora(index)
         self.lora_a_stacked[index,
-                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-                                lora_a.T, non_blocking=True)
+                            0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
+                                lora_a, non_blocking=True)
         self.lora_b_stacked[index,
-                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                lora_b.T, non_blocking=True)
+                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
+                                lora_b, non_blocking=True)
         if embeddings_tensor is not None:
             self.embeddings_tensors[
                 index,

vllm/lora/layers/row_parallel_linear.py

Lines changed: 2 additions & 2 deletions

@@ -39,7 +39,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         shard_size = self.input_size
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
-        lora_a = lora_a[start_idx:end_idx, :]
+        lora_a = lora_a[:, start_idx:end_idx]
         return lora_a

     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:

@@ -122,7 +122,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         shard_size = self.lora_b_stacked[0].shape[2]
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
-        lora_b = lora_b[:, start_idx:end_idx]
+        lora_b = lora_b[start_idx:end_idx, :]
         return lora_b

     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
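For row-parallel layers the sharded dimension is the input, so with lora_a stored as (rank, input_size) the slice moves to dim 1, as in the first hunk above. A small sketch with made-up sizes:

import torch

rank, input_size, tp_size = 8, 1024, 4
lora_a = torch.randn(rank, input_size)
shard_size = input_size // tp_size

tp_rank = 1  # example rank
lora_a_shard = lora_a[:, tp_rank * shard_size:(tp_rank + 1) * shard_size]
assert lora_a_shard.shape == (rank, shard_size)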

vllm/lora/layers/vocal_parallel_embedding.py

Lines changed: 6 additions & 4 deletions

@@ -95,11 +95,13 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
-        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
-            lora_a, non_blocking=True)
+        # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
+        # so we need transpose here
+        self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+            lora_a.T, non_blocking=True)
         self.lora_b_stacked[index,
-                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                lora_b.T, non_blocking=True)
+                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
+                                lora_b, non_blocking=True)
         if embeddings_tensor is not None:
             self.embeddings_tensors[
                 index,
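The embedding path keeps one transpose because its A factor is consumed by F.embedding, which indexes rows by token id, as the NOTE in the hunk says. A sketch of why; the stacked-buffer shape here is an illustrative assumption, not vLLM's actual allocation:

import torch
import torch.nn.functional as F

num_embeddings, rank = 1000, 8
lora_a = torch.randn(rank, num_embeddings)          # layout after this change
lora_a_stacked = torch.zeros(1, num_embeddings, rank)

# Rows of the stacked buffer must correspond to token ids, so the incoming
# (rank, num_embeddings) tensor is transposed on copy.
lora_a_stacked[0, :lora_a.shape[1], :lora_a.shape[0]].copy_(lora_a.T)

token_ids = torch.tensor([1, 5, 7])
after_a = F.embedding(token_ids, lora_a_stacked[0])  # shape (3, rank)
assert after_a.shape == (3, rank)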

vllm/lora/lora_weights.py

Lines changed: 2 additions & 2 deletions

@@ -86,11 +86,11 @@ def create_dummy_lora_weights(
             embeddings_tensor_dim: Optional[int] = None,
             bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
-        lora_a = torch.zeros([input_dim, rank],
+        lora_a = torch.zeros([rank, input_dim],
                              dtype=dtype,
                              device=device,
                              pin_memory=pin_memory)
-        lora_b = torch.zeros([rank, output_dim],
+        lora_b = torch.zeros([output_dim, rank],
                              dtype=dtype,
                              device=device,
                              pin_memory=pin_memory)

0 commit comments
