import hashlib
import os
from pathlib import Path
- from typing import Any, List, Optional
+ from typing import List, Optional

import torch
import torch.nn as nn
class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM):
    """
-   QEff class for loading models with mutltiple LoRA adapters.
+   QEff class for loading models with multiple LoRA adapters.
    Once exported and compiled, the qpc can perform mixed batch inference with provided prompt_to_lora_id_mapping.

    Args:
        :model (nn.Module): PyTorch model
        :base_model_name (str): Model card name for base model
        :adapter_weights (Dict): A dictionary contains lora_name to lora_weight mapping
        :adapter_configs (Dict): A dictionary contains lora_name to lora_configs mapping
-       :active_adapters (Set): A set of lora_names that are currently active
        :max_num_adapters (int): Total number of active adapters that to be exported and compiled
        :active_adapter_to_id (Dict): A dictionary contains active adapter's lora_name to lora_id mapping
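For orientation, the intended flow reads roughly as follows. This sketch is not part of the diff: the import path and the model/adapter identifiers are placeholders, and from_pretrained is assumed to be inherited from QEFFAutoModelForCausalLM.

    # Hypothetical usage sketch; identifiers below are placeholders, not taken from this diff.
    from QEfficient import QEffAutoLoraModelForCausalLM  # import path assumed

    model = QEffAutoLoraModelForCausalLM.from_pretrained("org/base-model-7b")

    # load_adapter() downloads (if needed), activates, and returns the adapter's lora_id;
    # lora_id 0 is reserved for the base model.
    id_a = model.load_adapter("org/lora-adapter-a", "adapter_a")
    id_b = model.load_adapter("org/lora-adapter-b", "adapter_b")

    # download_adapter() only caches weights/configs on CPU without activating them.
    model.download_adapter("org/lora-adapter-c", "adapter_c")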
@@ -65,7 +64,6 @@ def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwarg
        self.base_model_name = pretrained_model_name_or_path
        self.adapter_weights = {}
        self.adapter_configs = {}
-        self.active_adapters = set()
        self.max_num_adapters = 0
        self.active_adapter_to_id = {}
@@ -81,13 +79,13 @@ def model_hash(self) -> str:

        # create active adapter config dict
        active_adapter_configs = {}
-        for adpt in self.active_adapters:
+        for adpt in self.active_adapter_to_id.keys():
            active_adapter_configs[adpt] = self.adapter_configs[adpt].to_dict()
        mhash.update(to_hashable(active_adapter_configs))

        # create active adapter weight dict
        active_adapter_weights = {}
-        for adpt in self.active_adapters:
+        for adpt in self.active_adapter_to_id.keys():
            active_adapter_weights[adpt] = {key: value.tolist() for key, value in self.adapter_weights[adpt].items()}
        mhash.update(to_hashable(active_adapter_weights))
@@ -97,69 +95,78 @@ def model_hash(self) -> str:
        mhash = mhash.hexdigest()[:16]
        return mhash

-    def download_adapter(self, adapter_model_id: str, adapter_name: str):
+    def download_adapter(
+        self,
+        adapter_model_id: str,
+        adapter_name: str,
+        adapter_weight: Optional[dict] = None,
+        adapter_config: Optional[PeftConfig] = None,
+    ):
        """Loads a new adapter from huggingface hub or local path into CPU cache

        Args:
            :adapter_model_id (str): Adapter model ID from huggingface hub or local path
            :adapter_name (str): Adapter name to be used to set this adapter as current
        """
-        if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()):
-            logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}")

-        self.adapter_weights[adapter_name] = {
-            k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items()
-        }
-        self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id)
-
-    def load_adapter(self, adapter_model_id: str, adapter_name: str, **kwargs: Any):
-        "Load adapter into CPU cache and Sets active adapter from one of the loaded adapters"
-
-        # check if adapter name already exist, if so, overwrite it
+        # check if adapter name already loaded
        if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()):
-            logger.warning(f"Overwrite weights and configs for adapter name {adapter_name}")
-
-        adapter_weight = kwargs.pop("adapter_weight", None)
-        adapter_config = kwargs.pop("adapter_config", None)
-
-        if adapter_weight and adapter_config:  # if sufficiently get adapter weight and adpater config
-            self.adapter_weights[adapter_name] = adapter_weight
-            self.adapter_configs[adapter_name] = adapter_config
-        else:  # load from hugging face
-            self.adapter_weights[adapter_name] = {
-                k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items()
-            }
-            self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id)
-
-        # check if adapters has same target module and rank
-        assert (
-            list(self.adapter_configs.values())[0]
-            and self.adapter_configs[adapter_name].target_modules
-            == list(self.adapter_configs.values())[0].target_modules
-        ), "Not all adapters have the same target modules"
-
-        assert (
-            list(self.adapter_configs.values())[0]
-            and self.adapter_configs[adapter_name].r == list(self.adapter_configs.values())[0].r
-        ), "Not all adapters have the same ranks"
-
-        # set active adapter id to current max if adapter_name is new
-        if adapter_name not in self.active_adapter_to_id.keys():
-            self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1  # reserve 0 for base
+            logger.warning(f"{adapter_name} has been loaded. Skip download.")
+        else:
+            if adapter_weight and adapter_config:  # if both adapter weight and adapter config are provided
+                self.adapter_weights[adapter_name] = adapter_weight
+                self.adapter_configs[adapter_name] = adapter_config
+            else:  # download with adapter_model_id
+                self.adapter_weights[adapter_name] = {
+                    k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items()
+                }
+                self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id)
+
+    def load_adapter(
+        self,
+        adapter_model_id: str,
+        adapter_name: str,
+        adapter_weight: Optional[dict] = None,
+        adapter_config: Optional[PeftConfig] = None,
+    ):
+        "Load adapter into CPU cache and sets active adapter from one of the loaded adapters"

-        # add active adapter to set
-        self.active_adapters.add(adapter_name)
-        self.max_num_adapters = len(self.active_adapters)
+        # check if adapter name already exists and is activated
+        if adapter_name in self.active_adapter_to_id.keys():
+            logger.warning(f"{adapter_name} exists and is activated. Please provide a different adapter_name.")
+        else:
+            self.download_adapter(adapter_model_id, adapter_name, adapter_weight, adapter_config)
+
+            # starting from the second adapter_name, check that adapters have the same target modules and rank
+            if list(self.adapter_configs.values())[0] and (
+                self.adapter_configs[adapter_name].target_modules
+                != list(self.adapter_configs.values())[0].target_modules
+            ):
+                raise ValueError(
+                    f"{adapter_name} must have same target_modules as {list(self.adapter_configs.keys())[0]}"
+                )
+            if list(self.adapter_configs.values())[0] and (
+                self.adapter_configs[adapter_name].r != list(self.adapter_configs.values())[0].r
+            ):
+                raise ValueError(f"{adapter_name} must have same rank as {list(self.adapter_configs.keys())[0]}")
+
+            # set active adapter id to current max if adapter_name is new
+            if adapter_name not in self.active_adapter_to_id.keys():
+                self.active_adapter_to_id[adapter_name] = self.max_num_adapters + 1  # reserve 0 for base
+
+            # update the total count of active adapters
+            self.max_num_adapters = len(self.active_adapter_to_id)

        return self.active_adapter_to_id[adapter_name]

    def unload_adapter(self, adapter_name: str):
-        # remove from active list
-        if adapter_name not in self.active_adapters:
-            print(f"Adapter name {adapter_name} is not set active yet")
+        "Deactivate adapter and remove it from CPU cache"
+
+        # step1: remove from active list if it's there
+        if adapter_name not in self.active_adapter_to_id.keys():
+            logger.info(f"Adapter name {adapter_name} is not set active yet")
            return False

-        self.active_adapters.discard(adapter_name)
        self.max_num_adapters -= 1
        self.active_adapter_to_id.pop(adapter_name)
@@ -173,18 +180,11 @@ def unload_adapter(self, adapter_name: str):
        self.onnx_path = None
        self.qpc_path = None

-        # delete from cache
-        if adapter_name not in self.adapter_weights.keys() and adapter_name not in self.adapter_configs.keys():
-            print(f"Adapter name {adapter_name} is not loaded yet")
-            return False
-
-        if adapter_name in self.active_adapters:
-            print(f"Adapter name {adapter_name} is stil in active list, do delete_adapter() before unloading")
-            return False
-
-        self.adapter_weights.pop(adapter_name)
-        self.adapter_configs.pop(adapter_name)
-        logger.warning(f"Unloading {adapter_name} from CPU cache.")
+        # step2: delete from cache
+        if adapter_name in self.adapter_weights.keys() and adapter_name in self.adapter_configs.keys():
+            self.adapter_weights.pop(adapter_name)
+            self.adapter_configs.pop(adapter_name)
+            logger.warning(f"Unloading {adapter_name} from CPU cache.")

        return True
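Continuing the earlier sketch, the resulting unload semantics look roughly like this; illustrative only, not part of the diff.

    ok = model.unload_adapter("adapter_b")  # True: deactivated and removed from the CPU cache
    ok = model.unload_adapter("adapter_b")  # False: no longer in the active adapter set
    # Unloading also clears self.onnx_path / self.qpc_path, so the model must be re-exported and re-compiled.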
@@ -202,15 +202,10 @@ def load_adapter_weights_to_model(self):
                # stack all adapters weights
                a_tensor_list = list(range(self.max_num_adapters + 1))
                b_tensor_list = list(range(self.max_num_adapters + 1))
-                c_tensor_list = list(range(self.max_num_adapters + 1))
+                s_tensor_list = list(range(self.max_num_adapters + 1))

                for lora_name, lora_id in self.active_adapter_to_id.items():
-                    if (
-                        target_module == "q_proj"
-                        or target_module == "k_proj"
-                        or target_module == "v_proj"
-                        or target_module == "o_proj"
-                    ):
+                    if target_module in ["q_proj", "k_proj", "v_proj", "o_proj"]:
                        a_tensor_list[lora_id] = torch.from_numpy(
                            self.adapter_weights[lora_name][
                                f"base_model.model.model.layers.{i}.self_attn.{target_module}.lora_A.weight"
@@ -224,25 +219,25 @@ def load_adapter_weights_to_model(self):
                    else:
                        raise NotImplementedError("Target module not supported!!")

-                    c_tensor_list[lora_id] = torch.tensor(
+                    s_tensor_list[lora_id] = torch.tensor(
                        self.adapter_configs[lora_name].lora_alpha / self.adapter_configs[lora_name].r,
                        dtype=torch.float16,
                    )

                # dummy zero tensor for base model
                a_tensor_list[0] = torch.zeros_like(a_tensor_list[1])
                b_tensor_list[0] = torch.zeros_like(b_tensor_list[1])
-                c_tensor_list[0] = torch.zeros_like(c_tensor_list[1])
+                s_tensor_list[0] = torch.zeros_like(s_tensor_list[1])

                # stack weight tensors
-                stacked_lora_A = (
+                stacked_lora_a = (
                    torch.stack(a_tensor_list, dim=0).unsqueeze(1).transpose(2, 3)
                )  # <num_loras, 1, in_feature, r>
-                stacked_lora_B = (
+                stacked_lora_b = (
                    torch.stack(b_tensor_list, dim=0).unsqueeze(1).transpose(2, 3)
                )  # <num_loras, 1, r, out_feature>
-                stacked_lora_C = (
-                    torch.stack(c_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3)
+                stacked_lora_s = (
+                    torch.stack(s_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3)
                )  # <num_loras, 1, 1, 1>

                # stored weight to corresponding ops
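As a reading aid for the stacking above (not part of this diff): the per-adapter scaling is the usual LoRA factor lora_alpha / r, and the stacked tensors are shaped so that a per-prompt lora_id can select one adapter's weights. A toy sketch with assumed shapes, following the inline shape comments; the real per-prompt gather happens inside the replaced QEff LoRA modules, which are not shown in this hunk.

    import torch

    # Toy sizes; real weights are float16 and sized by the model config.
    num_loras, in_feature, r, out_feature = 3, 64, 8, 64
    stacked_lora_a = torch.randn(num_loras, 1, in_feature, r)   # <num_loras, 1, in_feature, r>
    stacked_lora_b = torch.randn(num_loras, 1, r, out_feature)  # <num_loras, 1, r, out_feature>
    stacked_lora_s = torch.ones(num_loras, 1, 1, 1)             # lora_alpha / r per adapter

    x = torch.randn(1, 7, in_feature)  # <batch, seq_len, in_feature>
    lora_id = 2  # chosen per prompt via prompt_to_lora_id_mapping; 0 means base-only

    # LoRA delta added on top of the base projection: ((x @ A) @ B) * (lora_alpha / r)
    delta = (x @ stacked_lora_a[lora_id]) @ stacked_lora_b[lora_id] * stacked_lora_s[lora_id]
    print(delta.shape)  # torch.Size([1, 7, 64])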
@@ -257,26 +252,18 @@ def load_adapter_weights_to_model(self):
                else:
                    raise NotImplementedError("Target module not supported!!")

-                module.lora_weight_A.copy_(stacked_lora_A)
-                module.lora_weight_B.copy_(stacked_lora_B)
-                module.lora_weight_C.copy_(stacked_lora_C)
+                module.lora_a_weights.copy_(stacked_lora_a)
+                module.lora_b_weights.copy_(stacked_lora_b)
+                module.lora_scalings.copy_(stacked_lora_s)

    def init_adapter_model(self):
        "Initialize the fixed lora model with multiple adapter weigths standby"

        # assume all adapters have same target_modules and ranks
-        assert self.max_num_adapters == len(self.active_adapters), "Inconsistent max_num_adapters and active_adapters"
-
-        assert list(self.adapter_configs.values())[0] and all(
-            list(self.adapter_configs.values())[i].target_modules
-            == list(self.adapter_configs.values())[0].target_modules
-            for i in range(self.max_num_adapters)
-        ), "Not all adapters have the same target modules"
-
-        assert list(self.adapter_configs.values())[0] and all(
-            list(self.adapter_configs.values())[i].r == list(self.adapter_configs.values())[0].r
-            for i in range(self.max_num_adapters)
-        ), "Not all adapters have the same ranks"
+        if self.max_num_adapters != len(self.active_adapter_to_id):
+            raise ValueError("Inconsistent max_num_adapters and active adapters")
+
+        # set lora rank
        self.lora_rank = list(self.adapter_configs.values())[0].r

        # do the module replacement
@@ -328,7 +315,7 @@ def export(self, **kwargs) -> str:
        if Path(onnx_path).is_file():
            self.onnx_path = onnx_path
-            print(f"Using existing onnx path:-{self.onnx_path}")
+            logger.info(f"Using existing onnx path:-{self.onnx_path}")
            return self.onnx_path

        # Export
@@ -405,14 +392,16 @@ def export_and_compile(
                mxint8=mxint8,
                full_batch_size=full_batch_size,
            )
-            print(f"Generated qpc:-{qpc_path}")
+            logger.info(f"Generated qpc:-{qpc_path}")
        else:
            self.qpc_path = qpc_path
-            print(f"Using existing qpc path:-{self.qpc_path}")
+            logger.info(f"Using existing qpc path:-{self.qpc_path}")

        return self.qpc_path

    def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs):
+        "Execute on Cloud AI 100 with prompt_to_lora_id_mapping passed in"
+
        assert isinstance(self.qpc_path, str), "Please run compile API first!"
        generation_len = kwargs.pop("generation_len", None)
        default_mapping = [0 for _ in range(len(prompts))]
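To close the loop on the docstring above, a hedged sketch of mixed-batch execution. It assumes model, id_a, and id_b from the earlier sketches, that export/compile has already produced a qpc, and that the per-prompt mapping is passed as a prompt_to_lora_id_mapping keyword argument (suggested by the docstring and the default_mapping fallback rather than shown as a signature here).

    prompts = [
        "Summarize this support ticket ...",  # routed to adapter_a
        "Translate this sentence ...",        # routed to adapter_b
        "A plain base-model prompt",          # no adapter
    ]
    mapping = [id_a, id_b, 0]  # one lora_id per prompt; 0 falls back to the base weights

    model.run_cloud_ai_100(
        prompts=prompts,
        device_id=[0],
        prompt_to_lora_id_mapping=mapping,  # kwarg name assumed from the docstring above
    )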