 import json
 import os
 import sys
-from typing import List
-from typing import Optional
+from typing import List, Optional, Any, Dict, Tuple

 import torch
 from loguru import logger
@@ -15,6 +14,7 @@
     AutoModelForCausalLM,
     BitsAndBytesConfig,
     PreTrainedTokenizer,
+    PreTrainedModel,
 )
 from transformers.utils.versions import require_version

@@ -25,31 +25,59 @@


 class BaseModelAdapter:
-    """The base and the default model adapter."""
+    """ The base and default model adapter. """

     model_names = []

     def match(self, model_name) -> bool:
+        """
+        Check if the given model name matches any of the predefined model names.
+
+        Args:
+            model_name (str): The model name to check.
+
+        Returns:
+            bool: True if the model name matches any of the predefined model names, False otherwise.
+        """
+
         return any(m in model_name for m in self.model_names) if self.model_names else True

     def load_model(
         self,
         model_name_or_path: Optional[str] = None,
         adapter_model: Optional[str] = None,
-        **kwargs
-    ):
-        """ Load model through transformers. """
-        model_name_or_path = self.default_model_name_or_path if model_name_or_path is None else model_name_or_path
-        tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
-        tokenizer_kwargs.update(self.tokenizer_kwargs)
+        **kwargs: Any,
+    ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+        """
+        Load a model and tokenizer based on the provided model name or path.
+
+        Args:
+            model_name_or_path (str, optional): The name or path of the model. Defaults to None.
+            adapter_model (str, optional): The adapter model to load the tokenizer from. Defaults to None.
+            **kwargs: Additional keyword arguments.

+        Returns:
+            Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+        """
+
+        model_name_or_path = model_name_or_path or self.default_model_name_or_path
+        tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
+        tokenizer_kwargs |= self.tokenizer_kwargs
+
+        # load a tokenizer from adapter model if it exists.
         if adapter_model is not None:
             try:
-                tokenizer = self.tokenizer_class.from_pretrained(adapter_model, **tokenizer_kwargs)
+                tokenizer = self.tokenizer_class.from_pretrained(
+                    adapter_model, **tokenizer_kwargs,
+                )
             except OSError:
-                tokenizer = self.tokenizer_class.from_pretrained(model_name_or_path, **tokenizer_kwargs)
+                tokenizer = self.tokenizer_class.from_pretrained(
+                    model_name_or_path, **tokenizer_kwargs,
+                )
         else:
-            tokenizer = self.tokenizer_class.from_pretrained(model_name_or_path, **tokenizer_kwargs)
+            tokenizer = self.tokenizer_class.from_pretrained(
+                model_name_or_path, **tokenizer_kwargs,
+            )

         config_kwargs = self.model_kwargs
         device = kwargs.get("device", "cuda")
@@ -109,12 +137,10 @@ def load_model(

         use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
         if use_ptuning_v2 and adapter_model:
-            prefix_encoder_file = open(f'{adapter_model}/config.json', 'r')
-            prefix_encoder_config = json.loads(prefix_encoder_file.read())
-            prefix_encoder_file.close()
-
-            config.pre_seq_len = prefix_encoder_config['pre_seq_len']
-            config.prefix_projection = prefix_encoder_config['prefix_projection']
+            with open(f"{adapter_model}/config.json", "r") as prefix_encoder_file:
+                prefix_encoder_config = json.loads(prefix_encoder_file.read())
+            config.pre_seq_len = prefix_encoder_config["pre_seq_len"]
+            config.prefix_projection = prefix_encoder_config["prefix_projection"]

         # Load and prepare pretrained models (without valuehead).
         model = self.model_class.from_pretrained(
@@ -148,16 +174,37 @@ def load_model(

         return model, tokenizer

-    def load_lora_model(self, model, adapter_model, model_kwargs):
+    def load_lora_model(
+        self, model: PreTrainedModel, adapter_model: str, model_kwargs: Dict,
+    ) -> PeftModel:
+        """
+        Load a LoRA model.
+
+        This function loads a LoRA model using the specified pretrained model and adapter model.
+
+        Args:
+            model (PreTrainedModel): The base pretrained model.
+            adapter_model (str): The name or path of the adapter model.
+            model_kwargs (dict): Additional keyword arguments for the model.
+
+        Returns:
+            PeftModel: The loaded LoRA model.
+        """
         return PeftModel.from_pretrained(
             model,
             adapter_model,
             torch_dtype=model_kwargs.get("torch_dtype", torch.float16),
         )

     def load_adapter_model(
-        self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs
-    ):
+        self,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        adapter_model: str,
+        is_chatglm: bool,
+        model_kwargs: Dict,
+        **kwargs: Any,
+    ) -> PreTrainedModel:
         use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
         resize_embeddings = kwargs.get("resize_embeddings", False)
         if adapter_model and resize_embeddings and not is_chatglm:
@@ -173,10 +220,11 @@ def load_adapter_model(

         if use_ptuning_v2:
             prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
-            new_prefix_state_dict = {}
-            for k, v in prefix_state_dict.items():
-                if k.startswith("transformer.prefix_encoder."):
-                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+            new_prefix_state_dict = {
+                k[len("transformer.prefix_encoder.") :]: v
+                for k, v in prefix_state_dict.items()
+                if k.startswith("transformer.prefix_encoder.")
+            }
             model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
             model.transformer.prefix_encoder.float()
         else:
@@ -213,13 +261,21 @@ def default_model_name_or_path(self):


 def register_model_adapter(cls):
-    """Register a model adapter."""
+    """ Register a model adapter. """
     model_adapters.append(cls())


 @cache
 def get_model_adapter(model_name: str) -> BaseModelAdapter:
-    """Get a model adapter for a model name."""
+    """
+    Get a model adapter for a given model name.
+
+    Args:
+        model_name (str): The name of the model.
+
+    Returns:
+        ModelAdapter: The model adapter that matches the given model name.
+    """
     for adapter in model_adapters:
         if adapter.match(model_name):
             return adapter
@@ -233,8 +289,23 @@ def load_model(
     quantize: Optional[int] = 16,
     device: Optional[str] = "cuda",
     load_in_8bit: Optional[bool] = False,
-    **kwargs
-):
+    **kwargs: Any,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    """
+    Load a pre-trained model and tokenizer.
+
+    Args:
+        model_name (str): The name of the model.
+        model_name_or_path (Optional[str], optional): The path or name of the pre-trained model. Defaults to None.
+        adapter_model (Optional[str], optional): The name of the adapter model. Defaults to None.
+        quantize (Optional[int], optional): The quantization level. Defaults to 16.
+        device (Optional[str], optional): The device to load the model on. Defaults to "cuda".
+        load_in_8bit (Optional[bool], optional): Whether to load the model in 8-bit mode. Defaults to False.
+        **kwargs (Any): Additional keyword arguments.
+
+    Returns:
+        Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+    """
     model_name = model_name.lower()

     if "tiger" in model_name:
@@ -496,6 +567,7 @@ def default_model_name_or_path(self):


 register_model_adapter(ChatglmModelAdapter)
+register_model_adapter(Chatglm3ModelAdapter)
 register_model_adapter(LlamaModelAdapter)
 register_model_adapter(MossModelAdapter)
 register_model_adapter(PhoenixModelAdapter)
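
For orientation, a minimal usage sketch of the module-level load_model() entry point as it looks after this change. The import path and the checkpoint id below are assumptions made for illustration, not part of the commit; substitute the project's actual module path and a locally available model.

# Hypothetical usage of the refactored loader; the import path is assumed.
from api.models import load_model, get_model_adapter  # assumption: real module path may differ

# Adapters are picked by substring match against each adapter's model_names,
# so "chatglm3" should resolve to the newly registered Chatglm3ModelAdapter.
adapter = get_model_adapter("chatglm3")

model, tokenizer = load_model(
    "chatglm3",                              # model_name, lower-cased before matching
    model_name_or_path="THUDM/chatglm3-6b",  # assumed checkpoint id; a local path also works
    adapter_model=None,                      # or a LoRA / p-tuning-v2 adapter directory
    device="cuda",
    load_in_8bit=False,
)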