
Commit 162711b

Addressed review comments
Signed-off-by: Jou-An Chen <quic_jouachen@quicinc.com>
1 parent c3c00fc commit 162711b

4 files changed: +135 -123 lines changed

QEfficient/lora/auto.py

Lines changed: 86 additions & 97 deletions
@@ -8,7 +8,7 @@
 import hashlib
 import os
 from pathlib import Path
-from typing import Any, List, Optional
+from typing import List, Optional

 import torch
 import torch.nn as nn
@@ -26,15 +26,14 @@

 class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
     """
-    QEff class for loading models with mutltiple LoRA adapters.
+    QEff class for loading models with multiple LoRA adapters.
     Once exported and compiled, the qpc can perform mixed batch inference with provided prompt_to_lora_id_mapping.

     Args:
         :model (nn.Module): PyTorch model
         :base_model_name (str): Model card name for base model
         :adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping
         :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping
-        :active_adapters (Set): A set of lora_names that are currently active
         :max_num_adapters (int): Total number of active adapters that to be exported and compiled
         :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping
@@ -65,7 +64,6 @@ def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwarg
         self.base_model_name = pretrained_model_name_or_path
         self.adapter_weights = {}
         self.adapter_configs = {}
-        self.active_adapters = set()
         self.max_num_adapters = 0
         self.active_adapter_to_id = {}

@@ -81,13 +79,13 @@ def model_hash(self) -> str:

         # create active adapter config dict
         active_adapter_configs = {}
-        for adpt in self.active_adapters:
+        for adpt in self.active_adapter_to_id.keys():
             active_adapter_configs[adpt] = self.adapter_configs[adpt].to_dict()
         mhash.update(to_hashable(active_adapter_configs))

         # create active adapter weight dict
         active_adapter_weights = {}
-        for adpt in self.active_adapters:
+        for adpt in self.active_adapter_to_id.keys():
             active_adapter_weights[adpt] = {key: value.tolist() for key, value in self.adapter_weights[adpt].items()}
         mhash.update(to_hashable(active_adapter_weights))
@@ -97,69 +95,78 @@ def model_hash(self) -> str:
         mhash = mhash.hexdigest()[:16]
         return mhash

-    def download_adapter(self, adapter_model_id: str, adapter_name: str):
+    def download_adapter(
+        self,
+        adapter_model_id: str,
+        adapter_name: str,
+        adapter_weight: Optional[dict] = None,
+        adapter_config: Optional[PeftConfig] = None,
+    ):
         """Loads a new adapter from huggingface hub or local path into CPU cache

         Args:
             :adapter_model_id (str): Adapter model ID from huggingface hub or local path
             :adapter_name (str): Adapter name to be used to set this adapter as current
         """
-        if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()):
-            logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}")

-        self.adapter_weights[adapter_name] = {
-            k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items()
-        }
-        self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id)
-
-    def load_adapter(self, adapter_model_id: str, adapter_name: str, **kwargs: Any):
-        "Load adapter into CPU cache and Sets active adapter from one of the loaded adapters"
-
-        # check if adapter name already exist, if so, overwrite it
+        # check if adapter name already loaded
         if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()):
-            logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}")
-
-        adapter_weight = kwargs.pop("adapter_weight", None)
-        adapter_config = kwargs.pop("adapter_config", None)
-
-        if adapter_weight and adapter_config:  # if sufficiently get adapter weight and adpater config
-            self.adapter_weights[adapter_name] = adapter_weight
-            self.adapter_configs[adapter_name] = adapter_config
-        else:  # load from hugging face
-            self.adapter_weights[adapter_name] = {
-                k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items()
-            }
-            self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id)
-
-        # check if adapters has same target module and rank
-        assert (
-            list(self.adapter_configs.values())[0]
-            and self.adapter_configs[adapter_name].target_modules
-            == list(self.adapter_configs.values())[0].target_modules
-        ), "Not all adapters have the same target modules"
-
-        assert (
-            list(self.adapter_configs.values())[0]
-            and self.adapter_configs[adapter_name].r == list(self.adapter_configs.values())[0].r
-        ), "Not all adapters have the same ranks"
-
-        # set active adapter id to current max if adapter_name is new
-        if adapter_name not in self.active_adapter_to_id.keys():
-            self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1  # reserve 0 for base
+            logger.warning(f"{adapter_name} has been loaded. Skip download.")
+        else:
+            if adapter_weight and adapter_config:  # if adapter weight and adapter config are provided
+                self.adapter_weights[adapter_name] = adapter_weight
+                self.adapter_configs[adapter_name] = adapter_config
+            else:  # download with adapter_model_id
+                self.adapter_weights[adapter_name] = {
+                    k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items()
+                }
+                self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id)
+
+    def load_adapter(
+        self,
+        adapter_model_id: str,
+        adapter_name: str,
+        adapter_weight: Optional[dict] = None,
+        adapter_config: Optional[PeftConfig] = None,
+    ):
+        "Load adapter into CPU cache and set it as an active adapter"

-        # add active adapter to set
-        self.active_adapters.add(adapter_name)
-        self.max_num_adapters = len(self.active_adapters)
+        # check if adapter name already exists and is activated
+        if adapter_name in self.active_adapter_to_id.keys():
+            logger.warning(f"{adapter_name} exists and activated. Please provide a different adapter_name.")
+        else:
+            self.download_adapter(adapter_model_id, adapter_name, adapter_weight, adapter_config)
+
+            # starting from the second adapter_name, check if adapters have the same target modules and rank
+            if list(self.adapter_configs.values())[0] and (
+                self.adapter_configs[adapter_name].target_modules
+                != list(self.adapter_configs.values())[0].target_modules
+            ):
+                raise ValueError(
+                    f"{adapter_name} must have same target_modules as {list(self.adapter_configs.keys())[0]}"
+                )
+            if list(self.adapter_configs.values())[0] and (
+                self.adapter_configs[adapter_name].r != list(self.adapter_configs.values())[0].r
+            ):
+                raise ValueError(f"{adapter_name} must have same rank as {list(self.adapter_configs.keys())[0]}")
+
+            # set active adapter id to current max if adapter_name is new
+            if adapter_name not in self.active_adapter_to_id.keys():
+                self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1  # reserve 0 for base
+
+            # update the number of active adapters
+            self.max_num_adapters = len(self.active_adapter_to_id)

         return self.active_adapter_to_id[adapter_name]

     def unload_adapter(self, adapter_name: str):
-        # remove from active list
-        if adapter_name not in self.active_adapters:
-            print(f"Adapter name {adapter_name} is not set active yet")
+        "Deactivate adapter and remove it from CPU cache"
+
+        # step1: remove from active list if it's there
+        if adapter_name not in self.active_adapter_to_id.keys():
+            logger.info(f"Adapter name {adapter_name} is not set active yet")
             return False

-        self.active_adapters.discard(adapter_name)
         self.max_num_adapters -= 1
         self.active_adapter_to_id.pop(adapter_name)

@@ -173,18 +180,11 @@ def unload_adapter(self, adapter_name: str):
         self.onnx_path = None
         self.qpc_path = None

-        # delete from cache
-        if adapter_name not in self.adapter_weights.keys() and adapter_name not in self.adapter_configs.keys():
-            print(f"Adapter name {adapter_name} is not loaded yet")
-            return False
-
-        if adapter_name in self.active_adapters:
-            print(f"Adapter name {adapter_name} is stil in active list, do delete_adapter() before unloading")
-            return False
-
-        self.adapter_weights.pop(adapter_name)
-        self.adapter_configs.pop(adapter_name)
-        logger.warning(f"Unloading {adapter_name} from CPU cache.")
+        # step2: delete from cache
+        if adapter_name in self.adapter_weights.keys() and adapter_name in self.adapter_configs.keys():
+            self.adapter_weights.pop(adapter_name)
+            self.adapter_configs.pop(adapter_name)
+            logger.warning(f"Unloading {adapter_name} from CPU cache.")

         return True

@@ -202,15 +202,10 @@ def load_adapter_weights_to_model(self):
                 # stack all adapters weights
                 a_tensor_list = list(range(self.max_num_adapters + 1))
                 b_tensor_list = list(range(self.max_num_adapters + 1))
-                c_tensor_list = list(range(self.max_num_adapters + 1))
+                s_tensor_list = list(range(self.max_num_adapters + 1))

                 for lora_name, lora_id in self.active_adapter_to_id.items():
-                    if (
-                        target_module == "q_proj"
-                        or target_module == "k_proj"
-                        or target_module == "v_proj"
-                        or target_module == "o_proj"
-                    ):
+                    if target_module in ["q_proj", "k_proj", "v_proj", "o_proj"]:
                         a_tensor_list[lora_id] = torch.from_numpy(
                             self.adapter_weights[lora_name][
                                 f"base_model.model.model.layers.{i}.self_attn.{target_module}.lora_A.weight"
@@ -224,25 +219,25 @@ def load_adapter_weights_to_model(self):
                     else:
                         raise NotImplementedError("Target module not supported!!")

-                    c_tensor_list[lora_id] = torch.tensor(
+                    s_tensor_list[lora_id] = torch.tensor(
                         self.adapter_configs[lora_name].lora_alpha / self.adapter_configs[lora_name].r,
                         dtype=torch.float16,
                     )

                 # dummy zero tensor for base model
                 a_tensor_list[0] = torch.zeros_like(a_tensor_list[1])
                 b_tensor_list[0] = torch.zeros_like(b_tensor_list[1])
-                c_tensor_list[0] = torch.zeros_like(c_tensor_list[1])
+                s_tensor_list[0] = torch.zeros_like(s_tensor_list[1])

                 # stack weight tensors
-                stacked_lora_A = (
+                stacked_lora_a = (
                     torch.stack(a_tensor_list, dim=0).unsqueeze(1).transpose(2, 3)
                 )  # <num_loras, 1, in_feature, r>
-                stacked_lora_B = (
+                stacked_lora_b = (
                     torch.stack(b_tensor_list, dim=0).unsqueeze(1).transpose(2, 3)
                 )  # <num_loras, 1, r, out_feature>
-                stacked_lora_C = (
-                    torch.stack(c_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3)
+                stacked_lora_s = (
+                    torch.stack(s_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3)
                 )  # <num_loras, 1, 1, 1>

                 # stored weight to corresponding ops
@@ -257,26 +252,18 @@ def load_adapter_weights_to_model(self):
                 else:
                     raise NotImplementedError("Target module not supported!!")

-                module.lora_weight_A.copy_(stacked_lora_A)
-                module.lora_weight_B.copy_(stacked_lora_B)
-                module.lora_weight_C.copy_(stacked_lora_C)
+                module.lora_a_weights.copy_(stacked_lora_a)
+                module.lora_b_weights.copy_(stacked_lora_b)
+                module.lora_scalings.copy_(stacked_lora_s)

     def init_adapter_model(self):
         "Initialize the fixed lora model with multiple adapter weigths standby"

         # assume all adapters have same target_modules and ranks
-        assert self.max_num_adapters == len(self.active_adapters), "Inconsistent max_num_adapters and active_adapters"
-
-        assert list(self.adapter_configs.values())[0] and all(
-            list(self.adapter_configs.values())[i].target_modules
-            == list(self.adapter_configs.values())[0].target_modules
-            for i in range(self.max_num_adapters)
-        ), "Not all adapters have the same target modules"
-
-        assert list(self.adapter_configs.values())[0] and all(
-            list(self.adapter_configs.values())[i].r == list(self.adapter_configs.values())[0].r
-            for i in range(self.max_num_adapters)
-        ), "Not all adapters have the same ranks"
+        if self.max_num_adapters != len(self.active_adapter_to_id):
+            raise ValueError("Inconsistent max_num_adapters and active adapters")
+
+        # set lora rank
         self.lora_rank = list(self.adapter_configs.values())[0].r

         # do the module replacement
@@ -328,7 +315,7 @@ def export(self, **kwargs) -> str:

         if Path(onnx_path).is_file():
             self.onnx_path = onnx_path
-            print(f"Using existing onnx path:-{self.onnx_path}")
+            logger.info(f"Using existing onnx path:-{self.onnx_path}")
             return self.onnx_path

         # Export
@@ -405,14 +392,16 @@ def export_and_compile(
                 mxint8=mxint8,
                 full_batch_size=full_batch_size,
             )
-            print(f"Generated qpc:-{qpc_path}")
+            logger.info(f"Generated qpc:-{qpc_path}")
         else:
             self.qpc_path = qpc_path
-            print(f"Using existing qpc path:-{self.qpc_path}")
+            logger.info(f"Using existing qpc path:-{self.qpc_path}")

         return self.qpc_path

     def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs):
+        "Execute on cloud ai 100 with prompt_to_lora_id_mapping passed in"
+
         assert isinstance(self.qpc_path, str), "Please run compile API first!"
         generation_len = kwargs.pop("generation_len", None)
         default_mapping = [0 for _ in range(len(prompts))]

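Taken together, the revised auto.py methods form a small multi-adapter workflow: download/load adapters into the CPU cache, export and compile once for all active adapters, then run mixed-batch inference with a per-prompt lora_id (0 is reserved for the base model). The sketch below is illustrative only and not part of the commit: the import path, from_pretrained entry point, adapter model IDs, and compile arguments are assumptions; only load_adapter, unload_adapter, export_and_compile, and run_cloud_ai_100 with prompt_to_lora_id_mapping come from the code above.

# Hypothetical usage sketch; adapter IDs, compile arguments, and the
# from_pretrained entry point are assumed for illustration.
from QEfficient.lora.auto import QEffAutoLoraModelForCausalLM

qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# each load_adapter() call reserves a lora_id for that adapter (0 is kept for the base model)
id_gsm8k = qeff_model.load_adapter("predibase/gsm8k", "gsm8k")
id_tldr = qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen")

# export to ONNX and compile a single qpc holding all active adapters
qeff_model.export_and_compile(num_cores=16, device_group=[0])

# mixed-batch inference: one lora_id per prompt selects which adapter is applied
qeff_model.run_cloud_ai_100(
    prompts=["What is 15 * 12?", "Summarize the following article ..."],
    device_id=[0],
    prompt_to_lora_id_mapping=[id_gsm8k, id_tldr],
)

# deactivate an adapter and drop its weights from the CPU cache
qeff_model.unload_adapter("gsm8k")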
QEfficient/lora/layers.py

Lines changed: 28 additions & 19 deletions
@@ -17,38 +17,47 @@

 class LinearMultiLoRA(nn.Linear):
     def multilora_init(self, lora_rank, max_num_adapters):
+        if lora_rank < 1 or max_num_adapters < 1:
+            raise ValueError("lora_rank and max_num_adapters must be greater or equal to 1")
+
         self.max_num_adapters = max_num_adapters
         self.lora_rank = lora_rank

-        self.lora_weight_A = nn.Parameter(
+        self.lora_a_weights = nn.Parameter(
             self.weight.new_zeros(self.max_num_adapters + 1, 1, self.in_features, self.lora_rank)
         )
-        self.lora_weight_A.requires_grad = False
-        self.lora_weight_B = nn.Parameter(
+        self.lora_a_weights.requires_grad = False
+        self.lora_b_weights = nn.Parameter(
             self.weight.new_zeros(self.max_num_adapters + 1, 1, self.lora_rank, self.out_features)
         )
-        self.lora_weight_B.requires_grad = False
-        self.lora_weight_C = torch.full((self.max_num_adapters + 1, 1, 1, 1), 1.0, dtype=torch.float)
+        self.lora_b_weights.requires_grad = False
+        self.lora_scalings = torch.full((self.max_num_adapters + 1, 1, 1, 1), 1.0, dtype=torch.float)

-        nn.init.kaiming_uniform_(self.lora_weight_A, a=math.sqrt(5))
-        nn.init.zeros_(self.lora_weight_B)
+        nn.init.kaiming_uniform_(self.lora_a_weights, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_b_weights)

     def forward(self, x: torch.Tensor, lora_ids: torch.Tensor):
         result = F.linear(x, self.weight, bias=self.bias)

         # multilora implementation: lora_ids <batch_size, 1>
-        other_indices_A = torch.arange(self.lora_weight_A.shape[2]).view(1, 1, -1)
-        A_embedding = CtxGatherFuncCB.apply(self.lora_weight_A, lora_ids, other_indices_A)  # <num_loras, 1, feature, r>
-        other_indices_B = torch.arange(self.lora_weight_B.shape[2]).view(1, 1, -1)
-        B_embedding = CtxGatherFuncCB.apply(self.lora_weight_B, lora_ids, other_indices_B)  # <num_loras, 1, r, feature>
-        other_indices_C = torch.arange(self.lora_weight_C.shape[2]).view(1, 1, -1)
-        C_embedding = CtxGatherFuncCB.apply(self.lora_weight_C, lora_ids, other_indices_C)  # <num_loras, 1, 1, 1>
-
-        A_embedding = A_embedding.squeeze(1)
-        B_embedding = B_embedding.squeeze(1)
-        C_embedding = C_embedding.squeeze(1)
-
-        result = result + x @ A_embedding @ B_embedding * C_embedding
+        other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1)
+        selected_lora_a_weights = CtxGatherFuncCB.apply(
+            self.lora_a_weights, lora_ids, other_indices_a
+        )  # <num_loras, 1, feature, r>
+        other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1)
+        selected_lora_b_weights = CtxGatherFuncCB.apply(
+            self.lora_b_weights, lora_ids, other_indices_b
+        )  # <num_loras, 1, r, feature>
+        other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1)
+        selected_lora_scalings = CtxGatherFuncCB.apply(
+            self.lora_scalings, lora_ids, other_indices_s
+        )  # <num_loras, 1, 1, 1>
+
+        selected_lora_a_weights = selected_lora_a_weights.squeeze(1)
+        selected_lora_b_weights = selected_lora_b_weights.squeeze(1)
+        selected_lora_scalings = selected_lora_scalings.squeeze(1)
+
+        result = result + x @ selected_lora_a_weights @ selected_lora_b_weights * selected_lora_scalings

         return result

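The layers.py change renames the stacked LoRA tensors but keeps the gather-then-matmul structure of the mixed-batch forward pass: gather the A, B, and scaling tensors by per-prompt lora_id, then add x @ A @ B * scaling to the base projection. The plain-PyTorch sketch below is illustrative only and not from the commit; it assumes CtxGatherFuncCB behaves like direct indexing over the adapter dimension, and the function name is hypothetical. Tensor shapes follow the comments in the diff.

# Illustrative reference for the math in LinearMultiLoRA.forward (assumption:
# CtxGatherFuncCB is replaced here by plain indexing for readability).
import torch
import torch.nn.functional as F


def multilora_forward_reference(x, weight, lora_a, lora_b, lora_s, lora_ids):
    """
    x        : <batch, seq, in_features>
    weight   : <out_features, in_features>         base linear weight
    lora_a   : <num_loras + 1, 1, in_features, r>  stacked A matrices (index 0 = zeros for base)
    lora_b   : <num_loras + 1, 1, r, out_features> stacked B matrices
    lora_s   : <num_loras + 1, 1, 1, 1>            per-adapter scaling (lora_alpha / r)
    lora_ids : <batch, 1>                          adapter id per prompt, 0 selects the base model
    """
    result = F.linear(x, weight)  # base projection

    ids = lora_ids.view(-1)
    a = lora_a[ids].squeeze(1)  # <batch, in_features, r>
    b = lora_b[ids].squeeze(1)  # <batch, r, out_features>
    s = lora_s[ids].squeeze(1)  # <batch, 1, 1>

    # per-sample low-rank update: x @ A @ B * scaling
    return result + x @ a @ b * s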