Commit 553a867

Fix bug for p-tuning

Author: xusenlin
1 parent b5d1b02 commit 553a867

File tree

    api/adapter/model.py
    api/config.py
    api/models.py

3 files changed: +10 -11 lines changed

api/adapter/model.py

Lines changed: 5 additions & 10 deletions

@@ -1,4 +1,3 @@
-import json
 import os
 import sys
 from typing import List, Optional, Any, Dict, Tuple
@@ -135,12 +134,8 @@ def load_model(
         setattr(config, "bf16", dtype == "bfloat16")
         config_kwargs.pop("torch_dtype", None)

-        use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
-        if use_ptuning_v2 and adapter_model:
-            with open(f"{adapter_model}/config.json", "r") as prefix_encoder_file:
-                prefix_encoder_config = json.loads(prefix_encoder_file.read())
-            config.pre_seq_len = prefix_encoder_config["pre_seq_len"]
-            config.prefix_projection = prefix_encoder_config["prefix_projection"]
+        if kwargs.get("using_ptuning_v2", False) and adapter_model:
+            config.pre_seq_len = kwargs.get("pre_seq_len", 128)

         # Load and prepare pretrained models (without valuehead).
         model = self.model_class.from_pretrained(
@@ -205,7 +200,7 @@ def load_adapter_model(
         model_kwargs: Dict,
         **kwargs: Any,
     ) -> PreTrainedModel:
-        use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
+        using_ptuning_v2 = kwargs.get("using_ptuning_v2", False)
         resize_embeddings = kwargs.get("resize_embeddings", False)
         if adapter_model and resize_embeddings and not is_chatglm:
             model_vocab_size = model.get_input_embeddings().weight.size(0)
@@ -218,10 +213,10 @@ def load_adapter_model(
             logger.info("Resize model embeddings to fit tokenizer")
             model.resize_token_embeddings(tokenzier_vocab_size)

-        if use_ptuning_v2:
+        if using_ptuning_v2:
            prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
            new_prefix_state_dict = {
-                k[len("transformer.prefix_encoder.") :]: v
+                k[len("transformer.prefix_encoder."):]: v
                for k, v in prefix_state_dict.items()
                if k.startswith("transformer.prefix_encoder.")
            }
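
For orientation, here is a minimal standalone sketch of what the p-tuning v2 path does after this commit, written against the public transformers API rather than the repo's actual loader classes. The function name load_ptuning_v2_sketch and the final load_state_dict call onto model.transformer.prefix_encoder are assumptions (the usual ChatGLM recipe); the kwargs handling and the prefix-encoder filtering mirror the hunks above.

import os

import torch
from transformers import AutoConfig, AutoModel


def load_ptuning_v2_sketch(model_name_or_path: str, adapter_model: str, **kwargs):
    # Hypothetical helper, not the repo's code.
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

    # After this commit, pre_seq_len comes from kwargs (default 128) instead of
    # being read from the adapter's config.json.
    if kwargs.get("using_ptuning_v2", False) and adapter_model:
        config.pre_seq_len = kwargs.get("pre_seq_len", 128)

    model = AutoModel.from_pretrained(
        model_name_or_path, config=config, trust_remote_code=True
    )

    if kwargs.get("using_ptuning_v2", False) and adapter_model:
        # Keep only the prefix-encoder tensors and strip their module prefix,
        # exactly as in the load_adapter_model hunk above.
        prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
        new_prefix_state_dict = {
            k[len("transformer.prefix_encoder."):]: v
            for k, v in prefix_state_dict.items()
            if k.startswith("transformer.prefix_encoder.")
        }
        # Loading into model.transformer.prefix_encoder is the usual ChatGLM
        # p-tuning v2 recipe; the exact attribute path depends on the model.
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

    return model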

api/config.py

Lines changed: 4 additions & 0 deletions

@@ -116,6 +116,10 @@ class Settings(BaseModel):
         default=get_bool_env("USING_PTUNING_V2"),
         description="Whether to load the model using ptuning_v2."
     )
+    pre_seq_len: Optional[bool] = Field(
+        default=get_bool_env("PRE_SEQ_LEN"),
+        description="PRE_SEQ_LEN for ptuning_v2."
+    )

     # context related
     context_length: Optional[int] = Field(
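
The new pre_seq_len setting is driven by the PRE_SEQ_LEN environment variable (the diff declares it as Optional[bool] via get_bool_env), and its value ends up in config.pre_seq_len as shown in api/adapter/model.py above. Below is a self-contained sketch of the same env-driven pydantic pattern, using a hypothetical get_env helper and an int-typed field purely for illustration.

import os
from typing import Optional

from pydantic import BaseModel, Field


def get_env(name: str, default=None):
    # Hypothetical stand-in for the repo's environment helpers.
    return os.environ.get(name, default)


class Settings(BaseModel):
    pre_seq_len: Optional[int] = Field(
        default=get_env("PRE_SEQ_LEN"),
        description="PRE_SEQ_LEN for ptuning_v2.",
        # validate_default coerces the string read from the environment
        # (e.g. "128") into an int; None stays None when the variable is unset.
        validate_default=True,
    )


# e.g. export PRE_SEQ_LEN=128 before starting the API process.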

api/models.py

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ def create_generate_model():
     apply_ntk_scaling_patch(SETTINGS.alpha)

     include = {
-        "model_name", "quantize", "device", "device_map", "num_gpus",
+        "model_name", "quantize", "device", "device_map", "num_gpus", "pre_seq_len",
         "load_in_8bit", "load_in_4bit", "using_ptuning_v2", "dtype", "resize_embeddings"
     }
     kwargs = SETTINGS.model_dump(include=include)
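
model_dump(include=...) (pydantic v2) is what turns the settings into the keyword arguments forwarded to the loader, so without "pre_seq_len" in the include set the value added in api/config.py would never reach load_model. A small self-contained illustration follows; the field values and the trimmed include set are stand-ins, only the names mirror the diff.

from typing import Optional

from pydantic import BaseModel, ConfigDict


class Settings(BaseModel):
    # Silence pydantic v2's warning about fields starting with "model_".
    model_config = ConfigDict(protected_namespaces=())

    model_name: str = "chatglm2-6b"
    using_ptuning_v2: bool = False
    pre_seq_len: Optional[int] = 128
    alpha: Optional[float] = None  # not in `include`, so never forwarded


SETTINGS = Settings()
include = {"model_name", "using_ptuning_v2", "pre_seq_len"}
kwargs = SETTINGS.model_dump(include=include)
# kwargs == {"model_name": "chatglm2-6b", "using_ptuning_v2": False, "pre_seq_len": 128}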
