-import json
 import os
 import sys
 from typing import List, Optional, Any, Dict, Tuple
@@ -135,12 +134,8 @@ def load_model(
         setattr(config, "bf16", dtype == "bfloat16")
         config_kwargs.pop("torch_dtype", None)
 
-        use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
-        if use_ptuning_v2 and adapter_model:
-            with open(f"{adapter_model}/config.json", "r") as prefix_encoder_file:
-                prefix_encoder_config = json.loads(prefix_encoder_file.read())
-                config.pre_seq_len = prefix_encoder_config["pre_seq_len"]
-                config.prefix_projection = prefix_encoder_config["prefix_projection"]
+        if kwargs.get("using_ptuning_v2", False) and adapter_model:
+            config.pre_seq_len = kwargs.get("pre_seq_len", 128)
 
         # Load and prepare pretrained models (without valuehead).
         model = self.model_class.from_pretrained(
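With this change, P-tuning v2 no longer reads pre_seq_len and prefix_projection from the adapter's config.json; the caller passes using_ptuning_v2 and pre_seq_len through kwargs, with pre_seq_len defaulting to 128. A minimal, self-contained sketch of the new behaviour (the helper function and the SimpleNamespace config are illustrative stand-ins, not code from this repository; only the kwargs names come from the diff):

# Illustrative stand-in for the config-patching step above.
from types import SimpleNamespace

def apply_ptuning_v2_config(config, adapter_model, **kwargs):
    # pre_seq_len now comes from kwargs (default 128) instead of being
    # read out of {adapter_model}/config.json as in the removed lines.
    if kwargs.get("using_ptuning_v2", False) and adapter_model:
        config.pre_seq_len = kwargs.get("pre_seq_len", 128)
    return config

config = apply_ptuning_v2_config(
    SimpleNamespace(), "path/to/ptuning-checkpoint",
    using_ptuning_v2=True, pre_seq_len=64,
)
print(config.pre_seq_len)  # 64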
@@ -205,7 +200,7 @@ def load_adapter_model(
         model_kwargs: Dict,
         **kwargs: Any,
     ) -> PreTrainedModel:
-        use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
+        using_ptuning_v2 = kwargs.get("using_ptuning_v2", False)
         resize_embeddings = kwargs.get("resize_embeddings", False)
         if adapter_model and resize_embeddings and not is_chatglm:
             model_vocab_size = model.get_input_embeddings().weight.size(0)
@@ -218,10 +213,10 @@ def load_adapter_model(
                 logger.info("Resize model embeddings to fit tokenizer")
                 model.resize_token_embeddings(tokenzier_vocab_size)
 
-        if use_ptuning_v2:
+        if using_ptuning_v2:
             prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
             new_prefix_state_dict = {
-                k[len("transformer.prefix_encoder.") :]: v
+                k[len("transformer.prefix_encoder."):]: v
                 for k, v in prefix_state_dict.items()
                 if k.startswith("transformer.prefix_encoder.")
             }
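The hunk ends with the stripped state dict. For reference, a small self-contained sketch of what the comprehension does and how its result is typically applied afterwards in ChatGLM-style P-tuning v2 loading; the toy tensors and the model.transformer.prefix_encoder attribute path are assumptions, not shown in this diff:

import torch

# Toy checkpoint mimicking a P-tuning v2 adapter's pytorch_model.bin; only the
# "transformer.prefix_encoder.*" entries survive, with that prefix stripped.
prefix_state_dict = {
    "transformer.prefix_encoder.embedding.weight": torch.zeros(128, 1024),
    "transformer.word_embeddings.weight": torch.zeros(10, 1024),  # dropped
}
new_prefix_state_dict = {
    k[len("transformer.prefix_encoder."):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith("transformer.prefix_encoder.")
}
print(list(new_prefix_state_dict))  # ['embedding.weight']

# Typical next step in the ChatGLM recipe (outside this diff, assumed):
# model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
# model.transformer.prefix_encoder.float()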