Commit 3bda907

Merge pull request oobabooga#366 from oobabooga/lora
Add LoRA support
2 parents 4c13067 + 614dad0

File tree

10 files changed: +82 −12 lines changed

css/main.css

Lines changed: 10 additions & 1 deletion
@@ -1,12 +1,15 @@
 .tabs.svelte-710i53 {
     margin-top: 0
 }
+
 .py-6 {
     padding-top: 2.5rem
 }
+
 .dark #refresh-button {
     background-color: #ffffff1f;
 }
+
 #refresh-button {
     flex: none;
     margin: 0;
@@ -17,22 +20,28 @@
     border-radius: 10px;
     background-color: #0000000d;
 }
+
 #download-label, #upload-label {
     min-height: 0
 }
+
 #accordion {
 }
+
 .dark svg {
     fill: white;
 }
+
 svg {
     display: unset !important;
     vertical-align: middle !important;
     margin: 5px;
 }
+
 ol li p, ul li p {
     display: inline-block;
 }
-#main, #parameters, #chat-settings, #interface-mode {
+
+#main, #parameters, #chat-settings, #interface-mode, #lora {
     border: 0;
 }

download-model.py

Lines changed: 11 additions & 6 deletions
@@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
     classifications = []
     has_pytorch = False
     has_safetensors = False
+    is_lora = False
     while True:
         content = requests.get(f"{base}{page}{cursor.decode()}").content

@@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):

         for i in range(len(dict)):
             fname = dict[i]['path']
+            if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
+                is_lora = True

-            is_pytorch = re.match("pytorch_model.*\.bin", fname)
+            is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
             is_tokenizer = re.match("tokenizer.*\.model", fname)
             is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
                         has_pytorch = True
                         classifications.append('pytorch')

+
         cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
         cursor = base64.b64encode(cursor)
         cursor = cursor.replace(b'=', b'%3D')
@@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
             if classifications[i] == 'pytorch':
                 links.pop(i)

-    return links
+    return links, is_lora

 if __name__ == '__main__':
     model = args.MODEL
@@ -159,15 +163,16 @@ def get_download_links_from_huggingface(model, branch):
         except ValueError as err_branch:
             print(f"Error: {err_branch}")
             sys.exit()
+
+    links, is_lora = get_download_links_from_huggingface(model, branch)
+    base_folder = 'models' if not is_lora else 'loras'
     if branch != 'main':
-        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+        output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
     else:
-        output_folder = Path("models") / model.split('/')[-1]
+        output_folder = Path(base_folder) / model.split('/')[-1]
     if not output_folder.exists():
         output_folder.mkdir()

-    links = get_download_links_from_huggingface(model, branch)
-
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
     pool = multiprocessing.Pool(processes=args.threads)
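
Note: the detection logic above reduces to a small standalone check. A minimal sketch, with a hypothetical file listing in place of the paged Hub API response used by the real function:

import re

# Hypothetical file listing for an adapter repo on the Hugging Face Hub;
# the real code pages through the Hub API as in the function above.
files = ['README.md', 'adapter_config.json', 'adapter_model.bin']

is_lora = False
for fname in files:
    # Seeing either adapter file once is enough to classify the repo as a LoRA.
    if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
        is_lora = True
    # The widened regex also counts adapter_model.bin as a pytorch-style weight file.
    is_pytorch = re.match(r"(pytorch|adapter)_model.*\.bin", fname)

base_folder = 'loras' if is_lora else 'models'
print(base_folder)  # -> loras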

loras/place-your-loras-here.txt

Whitespace-only changes.

modules/LoRA.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from pathlib import Path
+
+from peft import PeftModel
+
+import modules.shared as shared
+from modules.models import load_model
+
+
+def add_lora_to_model(lora_name):
+
+    # Is there a more efficient way of returning to the base model?
+    if lora_name == "None":
+        print("Reloading the model to remove the LoRA...")
+        shared.model, shared.tokenizer = load_model(shared.model_name)
+    else:
+        print(f"Adding the LoRA {lora_name} to the model...")
+        shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"))
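
Note: outside the web UI, the same pattern looks like this. A minimal sketch assuming a small causal LM and an adapter directory trained against it (both names are illustrative placeholders, not files shipped with this commit):

from pathlib import Path

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_name = "facebook/opt-350m"  # illustrative base model
model = AutoModelForCausalLM.from_pretrained(base_name)
tokenizer = AutoTokenizer.from_pretrained(base_name)

# PeftModel.from_pretrained wraps the base model with the adapter weights;
# the base weights themselves are untouched, which is why add_lora_to_model
# can only "remove" a LoRA by reloading the base model from scratch.
model = PeftModel.from_pretrained(model, Path("loras/my-adapter"))  # hypothetical adapter dir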

modules/callbacks.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@

 import modules.shared as shared

+
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

modules/chat.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@
 import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_chat_html
-from modules.text_generation import encode, generate_reply, get_max_prompt_length
+from modules.text_generation import (encode, generate_reply,
+                                     get_max_prompt_length)


 # This gets the new line characters right.

modules/shared.py

Lines changed: 7 additions & 1 deletion
@@ -2,7 +2,8 @@

 model = None
 tokenizer = None
-model_name = ""
+model_name = "None"
+lora_name = "None"
 soft_prompt_tensor = None
 soft_prompt = False
 is_RWKV = False
@@ -52,6 +53,10 @@
         '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
         '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
         'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
+    },
+    'lora_prompts': {
+        'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
+        'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
 }

@@ -67,6 +72,7 @@ def str2bool(v):

 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
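
Note: the keys of lora_prompts (like those of presets and prompts) are regular expressions matched case-insensitively against the LoRA name. The lookup idiom used in server.py reduces to this self-contained sketch (the name here is made up):

import re

lora_prompts = {
    'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
    'alpaca-lora-7b': 'Below is an instruction that describes a task.',  # shortened here
}

lora_name = 'Alpaca-LoRA-7B'  # made-up capitalization, to show why both sides are lowercased

# First key whose lowercased pattern matches the lowercased name wins;
# 'default' is the fallback when nothing matches.
key = next((k for k in lora_prompts if re.match(k.lower(), lora_name.lower())), 'default')
print(key)  # -> alpaca-lora-7b
default_text = lora_prompts[key]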

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ flexgen==0.1.7
 gradio==3.18.0
 markdown
 numpy
+peft==0.2.0
 requests
 rwkv==0.4.2
 safetensors==0.3.0
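
Note: peft (pinned at 0.2.0) is the only new dependency; it provides the PeftModel class used in modules/LoRA.py, so existing installs need pip install -r requirements.txt run again before the new code paths will import.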

server.py

Lines changed: 28 additions & 1 deletion
@@ -15,6 +15,7 @@
 import modules.shared as shared
 import modules.ui as ui
 from modules.html_generator import generate_chat_html
+from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply

@@ -48,6 +49,9 @@ def get_available_extensions():
 def get_available_softprompts():
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)

+def get_available_loras():
+    return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
@@ -59,6 +63,17 @@ def load_model_wrapper(selected_model):

     return selected_model

+def load_lora_wrapper(selected_lora):
+    shared.lora_name = selected_lora
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
+
+    if not shared.args.cpu:
+        gc.collect()
+        torch.cuda.empty_cache()
+    add_lora_to_model(selected_lora)
+
+    return selected_lora, default_text
+
 def load_preset_values(preset_menu, return_dict=False):
     generate_params = {
         'do_sample': True,
@@ -145,6 +160,10 @@ def create_settings_menus(default_preset):
                 shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                 shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')

+    with gr.Row():
+        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+        ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+
     with gr.Accordion('Soft prompt', open=False):
         with gr.Row():
             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -156,6 +175,7 @@ def create_settings_menus(default_preset):

     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
     shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+    shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
     shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])

@@ -181,6 +201,7 @@ def set_interface_arguments(interface_mode, extensions, cmd_active):
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_softprompts = get_available_softprompts()
+available_loras = get_available_loras()

 # Default extensions
 extensions_module.available_extensions = get_available_extensions()
@@ -213,10 +234,16 @@ def set_interface_arguments(interface_mode, extensions, cmd_active):
         print()
         shared.model_name = available_models[i]
 shared.model, shared.tokenizer = load_model(shared.model_name)
+if shared.args.lora:
+    print(shared.args.lora)
+    shared.lora_name = shared.args.lora
+    add_lora_to_model(shared.lora_name)

 # Default UI settings
 default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
-default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
+if default_text == '':
+    default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
 suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
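
Note: taken end to end, the intended flow appears to be: download an adapter repo (download-model.py now routes it into loras/ via is_lora), then start the server with both flags. An illustrative invocation, assuming a Hub adapter repo such as tloen/alpaca-lora-7b and a matching local base model (names are examples, not requirements of this commit):

python download-model.py tloen/alpaca-lora-7b
python server.py --model llama-7b --lora alpaca-lora-7b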

settings-template.json

Lines changed: 5 additions & 2 deletions
@@ -23,13 +23,16 @@
     "presets": {
         "default": "NovelAI-Sphinx Moth",
         "pygmalion-*": "Pygmalion",
-        "RWKV-*": "Naive",
-        "(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)"
+        "RWKV-*": "Naive"
     },
     "prompts": {
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "(rosey|chip|joi)_.*_instruct.*": "User: \n",
         "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
+    },
+    "lora_prompts": {
+        "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
+        "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
 }
