
Commit 086c8f8

Author: xusenlin
Commit message: Add apply lora
1 parent: 8481667

File tree

6 files changed: +73, -14 lines

api/utils/protocol.py

Lines changed: 15 additions & 0 deletions

@@ -78,6 +78,14 @@ class ChatCompletionCreateParams(BaseModel):
     or exclusive selection of the relevant token.
     """
 
+    logprobs: Optional[bool] = False
+    """Whether to return log probabilities of the output tokens or not.
+
+    If true, returns the log probabilities of each output token returned in the
+    `content` of `message`. This option is currently not available on the
+    `gpt-4-vision-preview` model.
+    """
+
     max_tokens: Optional[int] = None
     """The maximum number of [tokens](/tokenizer) to generate in the chat completion.
 
@@ -146,6 +154,13 @@ class ChatCompletionCreateParams(BaseModel):
     functions the model may generate JSON inputs for.
     """
 
+    top_logprobs: Optional[int] = None
+    """
+    An integer between 0 and 5 specifying the number of most likely tokens to return
+    at each token position, each with an associated log probability. `logprobs` must
+    be set to `true` if this parameter is used.
+    """
+
     top_p: Optional[float] = 1.0
     """
     An alternative to sampling with temperature, called nucleus sampling, where the
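The two new fields mirror the OpenAI chat-completions parameters of the same names. A minimal client-side sketch, assuming this API server is running with an OpenAI-compatible endpoint; the base URL, model name, and API key below are placeholders:

```python
from openai import OpenAI

# Placeholder endpoint and model name; point these at your own deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

resp = client.chat.completions.create(
    model="qwen-7b-chat",
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,    # ask for per-token log probabilities
    top_logprobs=3,   # top-3 alternatives per position; requires logprobs=True
)
print(resp.choices[0].logprobs)
```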

examples/qwen-7b-chat/get_weather.py

Lines changed: 1 addition & 1 deletion

@@ -134,4 +134,4 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5):
         logger.info("\n=========== next conversation ===========")
 
 query = "波士顿天气如何?"
-run_conversation(query, functions=functions, stream=True)
+run_conversation(query, functions=functions, stream=False)

libs/langchain_llm/langchain_llm/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -2,13 +2,15 @@
     HuggingFaceLLM,
     ChatHuggingFace,
 )
-from ._vllm import XVLLM as VLLM
 from ._vllm import ChatVLLM
+from ._vllm import XVLLM as VLLM
+from .utils import apply_lora
 
 
 __all__ = [
     "HuggingFaceLLM",
     "ChatHuggingFace",
     "VLLM",
     "ChatVLLM",
+    "apply_lora"
 ]
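With the re-export in place, the new merge helper is importable from the package root alongside the LLM wrappers; for example:

```python
# apply_lora now sits next to the LLM wrappers at the package root.
from langchain_llm import HuggingFaceLLM, ChatHuggingFace, VLLM, ChatVLLM, apply_lora
```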

libs/langchain_llm/langchain_llm/_huggingface.py

Lines changed: 0 additions & 12 deletions

@@ -170,18 +170,6 @@ def _validate_environment(values: Dict) -> Dict:
         values["context_length"] = get_context_length(values["model"].config)
         logger.info(f"Context length is set to : {values['context_length']}")
 
-        # fix the tokenizer by adding the end-of-sequence (eos) token and the padding (pad) token if they are missing.
-        if values["tokenizer"].eos_token_id is None:
-            values["tokenizer"].eos_token = "<|endoftext|>"
-            logger.info(f"Add eos token: {values['tokenizer'].eos_token}")
-
-        if values["tokenizer"].pad_token_id is None:
-            if values["tokenizer"].unk_token_id is not None:
-                values["tokenizer"].pad_token = values["tokenizer"].unk_token
-            else:
-                values["tokenizer"].pad_token = values["tokenizer"].eos_token
-            logger.info(f"Add pad token: {values['tokenizer'].pad_token}")
-
         return values
 
     @property

libs/langchain_llm/langchain_llm/adapters/patcher.py

Lines changed: 11 additions & 0 deletions

@@ -112,6 +112,17 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info(f"Add eos token: {tokenizer.eos_token}")
+
+    if tokenizer.pad_token_id is None:
+        if tokenizer.unk_token_id is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+        logger.info(f"Add pad token: {tokenizer.pad_token}")
+
 
 def patch_config(
     config: "PretrainedConfig",
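The eos/pad fallback deleted from `_huggingface.py` above now lives in `patch_tokenizer`, so every code path that patches a tokenizer inherits it. A standalone sketch of the same fallback logic; the GPT-2 checkpoint is just an arbitrary public tokenizer that ships without a pad token:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
assert tok.pad_token_id is None  # gpt2 has an eos token but no pad token

# Same fallback order as the new patch_tokenizer logic: prefer unk, else eos.
if tok.pad_token_id is None:
    tok.pad_token = tok.unk_token if tok.unk_token_id is not None else tok.eos_token

print(tok.pad_token)  # "<|endoftext|>" (gpt2's unk/eos token)
```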
libs/langchain_llm/langchain_llm/utils.py (new file)

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+from typing import Optional
+
+import torch
+from loguru import logger
+from peft import PeftModel
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
+
+
+def apply_lora(
+    base_model_path: str,
+    lora_path: str,
+    target_model_path: str,
+    max_shard_size: Optional[str] = "2GB",
+    safe_serialization: Optional[bool] = True,
+) -> PreTrainedModel:
+
+    logger.info(f"Loading the base model from {base_model_path}")
+    base = AutoModelForCausalLM.from_pretrained(
+        base_model_path,
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+    )
+    base_tokenizer = AutoTokenizer.from_pretrained(
+        base_model_path,
+        use_fast=False,
+        trust_remote_code=True,
+    )
+
+    logger.info(f"Loading the LoRA adapter from {lora_path}")
+
+    lora_model = PeftModel.from_pretrained(base, lora_path)
+
+    logger.info("Applying the LoRA")
+    model = lora_model.merge_and_unload()
+
+    logger.info(f"Saving the target model to {target_model_path}")
+    model.save_pretrained(
+        target_model_path,
+        max_shard_size=max_shard_size,
+        safe_serialization=safe_serialization,
+    )
+    base_tokenizer.save_pretrained(target_model_path)
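A usage sketch for the new helper; all paths below are placeholders. One caveat worth noting: the function is annotated as returning a `PreTrainedModel`, but the diff contains no `return` statement, so callers should reload the merged checkpoint from `target_model_path` rather than relying on the call's return value:

```python
from langchain_llm import apply_lora
from transformers import AutoModelForCausalLM

apply_lora(
    base_model_path="/path/to/base-model",      # placeholder
    lora_path="/path/to/lora-adapter",          # placeholder
    target_model_path="/path/to/merged-model",  # merged weights land here
    max_shard_size="2GB",
    safe_serialization=True,
)

# As written, apply_lora returns None despite its annotation, so reload
# the merged model from disk for inference.
merged = AutoModelForCausalLM.from_pretrained(
    "/path/to/merged-model", torch_dtype="auto", trust_remote_code=True
)
```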
