
Commit cd313e8

Author: xusenlin
Merge commit cd313e8 (2 parents: 9bf7014 + e4cf600)

File tree: 19 files changed (+751, -360 lines)


api/adapter/model.py

Lines changed: 101 additions & 29 deletions
@@ -1,8 +1,7 @@
 import json
 import os
 import sys
-from typing import List
-from typing import Optional
+from typing import List, Optional, Any, Dict, Tuple
 
 import torch
 from loguru import logger
@@ -15,6 +14,7 @@
     AutoModelForCausalLM,
     BitsAndBytesConfig,
     PreTrainedTokenizer,
+    PreTrainedModel,
 )
 from transformers.utils.versions import require_version
 
@@ -25,31 +25,59 @@
 
 
 class BaseModelAdapter:
-    """The base and the default model adapter."""
+    """ The base and default model adapter. """
 
     model_names = []
 
     def match(self, model_name) -> bool:
+        """
+        Check if the given model name matches any of the predefined model names.
+
+        Args:
+            model_name (str): The model name to check.
+
+        Returns:
+            bool: True if the model name matches any of the predefined model names, False otherwise.
+        """
+
         return any(m in model_name for m in self.model_names) if self.model_names else True
 
     def load_model(
         self,
         model_name_or_path: Optional[str] = None,
         adapter_model: Optional[str] = None,
-        **kwargs
-    ):
-        """ Load model through transformers. """
-        model_name_or_path = self.default_model_name_or_path if model_name_or_path is None else model_name_or_path
-        tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
-        tokenizer_kwargs.update(self.tokenizer_kwargs)
+        **kwargs: Any,
+    ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+        """
+        Load a model and tokenizer based on the provided model name or path.
+
+        Args:
+            model_name_or_path (str, optional): The name or path of the model. Defaults to None.
+            adapter_model (str, optional): The adapter model to load the tokenizer from. Defaults to None.
+            **kwargs: Additional keyword arguments.
 
+        Returns:
+            Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+        """
+
+        model_name_or_path = model_name_or_path or self.default_model_name_or_path
+        tokenizer_kwargs = {"trust_remote_code": True, "use_fast": False}
+        tokenizer_kwargs |= self.tokenizer_kwargs
+
+        # load a tokenizer from adapter model if it exists.
         if adapter_model is not None:
             try:
-                tokenizer = self.tokenizer_class.from_pretrained(adapter_model, **tokenizer_kwargs)
+                tokenizer = self.tokenizer_class.from_pretrained(
+                    adapter_model, **tokenizer_kwargs,
+                )
             except OSError:
-                tokenizer = self.tokenizer_class.from_pretrained(model_name_or_path, **tokenizer_kwargs)
+                tokenizer = self.tokenizer_class.from_pretrained(
+                    model_name_or_path, **tokenizer_kwargs,
+                )
         else:
-            tokenizer = self.tokenizer_class.from_pretrained(model_name_or_path, **tokenizer_kwargs)
+            tokenizer = self.tokenizer_class.from_pretrained(
+                model_name_or_path, **tokenizer_kwargs,
+            )
 
         config_kwargs = self.model_kwargs
         device = kwargs.get("device", "cuda")
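Note: the reworked `load_model` prefers a tokenizer bundled with the adapter checkpoint and falls back to the base model's tokenizer when the adapter directory has none; the new `tokenizer_kwargs |= self.tokenizer_kwargs` merge requires Python 3.9 or later. A minimal standalone sketch of the same fallback pattern, using `AutoTokenizer` and placeholder paths rather than anything from this repository:

```python
from typing import Optional

from transformers import AutoTokenizer, PreTrainedTokenizer


def load_tokenizer_with_fallback(
    base_model_path: str, adapter_path: Optional[str] = None
) -> PreTrainedTokenizer:
    """Try the adapter directory first; fall back to the base model on OSError."""
    kwargs = {"trust_remote_code": True, "use_fast": False}
    if adapter_path is not None:
        try:
            # Succeeds only if the adapter checkpoint ships its own tokenizer files.
            return AutoTokenizer.from_pretrained(adapter_path, **kwargs)
        except OSError:
            pass
    return AutoTokenizer.from_pretrained(base_model_path, **kwargs)
```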
@@ -109,12 +137,10 @@ def load_model(
 
         use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
         if use_ptuning_v2 and adapter_model:
-            prefix_encoder_file = open(f'{adapter_model}/config.json', 'r')
-            prefix_encoder_config = json.loads(prefix_encoder_file.read())
-            prefix_encoder_file.close()
-
-            config.pre_seq_len = prefix_encoder_config['pre_seq_len']
-            config.prefix_projection = prefix_encoder_config['prefix_projection']
+            with open(f"{adapter_model}/config.json", "r") as prefix_encoder_file:
+                prefix_encoder_config = json.loads(prefix_encoder_file.read())
+            config.pre_seq_len = prefix_encoder_config["pre_seq_len"]
+            config.prefix_projection = prefix_encoder_config["prefix_projection"]
 
         # Load and prepare pretrained models (without valuehead).
         model = self.model_class.from_pretrained(
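The P-Tuning v2 branch now reads the adapter's `config.json` inside a `with` block, so the file handle is closed even if parsing fails. A small sketch of the same idea, as a hypothetical helper that uses `json.load` directly (equivalent to `json.loads(f.read())`):

```python
import json


def read_prefix_encoder_config(adapter_dir: str) -> dict:
    """Read the P-Tuning v2 fields from an adapter's config.json."""
    with open(f"{adapter_dir}/config.json", "r") as f:
        cfg = json.load(f)  # same result as json.loads(f.read())
    return {
        "pre_seq_len": cfg["pre_seq_len"],
        "prefix_projection": cfg["prefix_projection"],
    }
```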
@@ -148,16 +174,37 @@ def load_model(
 
         return model, tokenizer
 
-    def load_lora_model(self, model, adapter_model, model_kwargs):
+    def load_lora_model(
+        self, model: PreTrainedModel, adapter_model: str, model_kwargs: Dict,
+    ) -> PeftModel:
+        """
+        Load a LoRA model.
+
+        This function loads a LoRA model using the specified pretrained model and adapter model.
+
+        Args:
+            model (PreTrainedModel): The base pretrained model.
+            adapter_model (str): The name or path of the adapter model.
+            model_kwargs (dict): Additional keyword arguments for the model.
+
+        Returns:
+            PeftModel: The loaded LoRA model.
+        """
         return PeftModel.from_pretrained(
             model,
             adapter_model,
             torch_dtype=model_kwargs.get("torch_dtype", torch.float16),
         )
 
     def load_adapter_model(
-        self, model, tokenizer, adapter_model, is_chatglm, model_kwargs, **kwargs
-    ):
+        self,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        adapter_model: str,
+        is_chatglm: bool,
+        model_kwargs: Dict,
+        **kwargs: Any,
+    ) -> PreTrainedModel:
         use_ptuning_v2 = kwargs.get("use_ptuning_v2", False)
         resize_embeddings = kwargs.get("resize_embeddings", False)
         if adapter_model and resize_embeddings and not is_chatglm:
@@ -173,10 +220,11 @@ def load_adapter_model(
 
         if use_ptuning_v2:
             prefix_state_dict = torch.load(os.path.join(adapter_model, "pytorch_model.bin"))
-            new_prefix_state_dict = {}
-            for k, v in prefix_state_dict.items():
-                if k.startswith("transformer.prefix_encoder."):
-                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+            new_prefix_state_dict = {
+                k[len("transformer.prefix_encoder.") :]: v
+                for k, v in prefix_state_dict.items()
+                if k.startswith("transformer.prefix_encoder.")
+            }
             model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
             model.transformer.prefix_encoder.float()
         else:
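The explicit loop is replaced by a dict comprehension with identical behavior: keep only the keys under `transformer.prefix_encoder.` and strip that prefix. A toy demonstration, where the values are placeholder strings rather than real tensors:

```python
prefix = "transformer.prefix_encoder."
state_dict = {
    "transformer.prefix_encoder.embedding.weight": "tensor_a",  # kept, prefix stripped
    "transformer.word_embeddings.weight": "tensor_b",           # dropped, no prefix
}

new_state_dict = {
    k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
}
assert new_state_dict == {"embedding.weight": "tensor_a"}
```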
@@ -213,13 +261,21 @@ def default_model_name_or_path(self):
 
 
 def register_model_adapter(cls):
-    """Register a model adapter."""
+    """ Register a model adapter. """
     model_adapters.append(cls())
 
 
 @cache
 def get_model_adapter(model_name: str) -> BaseModelAdapter:
-    """Get a model adapter for a model name."""
+    """
+    Get a model adapter for a given model name.
+
+    Args:
+        model_name (str): The name of the model.
+
+    Returns:
+        ModelAdapter: The model adapter that matches the given model name.
+    """
     for adapter in model_adapters:
         if adapter.match(model_name):
             return adapter
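Since `match` does a substring check and `get_model_adapter` returns the first registered adapter that matches, registration order matters: specific adapters must come before any catch-all. The sketch below is a simplified standalone re-implementation of the registry pattern with hypothetical adapter classes, not the repository's own adapters:

```python
from typing import List

model_adapters = []


class BaseModelAdapter:
    model_names: List[str] = []

    def match(self, model_name: str) -> bool:
        # Substring match; an empty model_names list matches everything.
        return any(m in model_name for m in self.model_names) if self.model_names else True


class LlamaAdapter(BaseModelAdapter):
    model_names = ["llama"]


class DefaultAdapter(BaseModelAdapter):
    model_names = []


def register_model_adapter(cls):
    model_adapters.append(cls())


def get_model_adapter(model_name: str) -> BaseModelAdapter:
    for adapter in model_adapters:
        if adapter.match(model_name):
            return adapter


# Specific adapters first, catch-all last: the first match wins.
register_model_adapter(LlamaAdapter)
register_model_adapter(DefaultAdapter)

assert isinstance(get_model_adapter("llama-2-7b-chat"), LlamaAdapter)
assert isinstance(get_model_adapter("some-unknown-model"), DefaultAdapter)
```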
@@ -233,8 +289,23 @@ def load_model(
     quantize: Optional[int] = 16,
     device: Optional[str] = "cuda",
     load_in_8bit: Optional[bool] = False,
-    **kwargs
-):
+    **kwargs: Any,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    """
+    Load a pre-trained model and tokenizer.
+
+    Args:
+        model_name (str): The name of the model.
+        model_name_or_path (Optional[str], optional): The path or name of the pre-trained model. Defaults to None.
+        adapter_model (Optional[str], optional): The name of the adapter model. Defaults to None.
+        quantize (Optional[int], optional): The quantization level. Defaults to 16.
+        device (Optional[str], optional): The device to load the model on. Defaults to "cuda".
+        load_in_8bit (Optional[bool], optional): Whether to load the model in 8-bit mode. Defaults to False.
+        **kwargs (Any): Additional keyword arguments.
+
+    Returns:
+        Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
+    """
     model_name = model_name.lower()
 
     if "tiger" in model_name:
@@ -496,6 +567,7 @@ def default_model_name_or_path(self):
 
 
 register_model_adapter(ChatglmModelAdapter)
+register_model_adapter(Chatglm3ModelAdapter)
 register_model_adapter(LlamaModelAdapter)
 register_model_adapter(MossModelAdapter)
 register_model_adapter(PhoenixModelAdapter)
