Commit
Merge pull request #49 from jianzhnie/dev
update alpaca
Showing 25 changed files with 805 additions and 344 deletions.
@@ -135,3 +135,5 @@ assets/
work_dir/
work_dirs/
prompt_data/
prompt_data
work_dir_lora/
@@ -0,0 +1,87 @@
""" | ||
Apply the LoRA weights on top of a base model. | ||
Usage: | ||
python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B | ||
Dependency: | ||
pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b | ||
""" | ||
import argparse | ||
from typing import Tuple | ||
|
||
import torch | ||
from peft import PeftModel | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer | ||
|
||
|
||
def apply_lora( | ||
base_model_path: str, | ||
lora_path: str, | ||
load_8bit: bool = False, | ||
target_model_path: str = None, | ||
save_target_model: bool = False | ||
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: | ||
"""Applies the LoRA adapter to a base model and saves the resulting target model (optional). | ||
Args: | ||
base_model_path (str): The path to the base model to which the LoRA adapter will be applied. | ||
lora_path (str): The path to the LoRA adapter. | ||
target_model_path (str): The path where the target model will be saved (if `save_target_model=True`). | ||
save_target_model (bool, optional): Whether to save the target model or not. Defaults to False. | ||
Returns: | ||
Tuple[AutoModelForCausalLM, AutoTokenizer]: A tuple containing the target model and its tokenizer. | ||
""" | ||
# Load the base model and tokenizer | ||
print(f'Loading the base model from {base_model_path}') | ||
base_model = AutoModelForCausalLM.from_pretrained( | ||
base_model_path, | ||
load_in_8bit=load_8bit, | ||
torch_dtype=torch.float16, | ||
device_map='auto', | ||
) | ||
|
||
# Load the tokenizer | ||
if base_model.config.model_type == 'llama': | ||
# Due to the name of Transformers' LlamaTokenizer, we have to do this | ||
base_tokenizer = LlamaTokenizer.from_pretrained( | ||
base_model_path, | ||
padding_side='right', | ||
use_fast=True, | ||
) | ||
else: | ||
base_tokenizer = AutoTokenizer.from_pretrained( | ||
base_model_path, | ||
padding_side='right', | ||
use_fast=True, | ||
) | ||
|
||
# Load the LoRA adapter | ||
print(f'Loading the LoRA adapter from {lora_path}') | ||
model = PeftModel.from_pretrained( | ||
base_model, | ||
lora_path, | ||
torch_dtype=torch.float16, | ||
) | ||
|
||
if save_target_model and target_model_path is not None: | ||
print(f'Saving the target model to {target_model_path}') | ||
model.save_pretrained(target_model_path) | ||
base_tokenizer.save_pretrained(target_model_path) | ||
|
||
return model, base_tokenizer | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--base-model-path', type=str, required=True) | ||
parser.add_argument('--target-model-path', type=str, required=True) | ||
parser.add_argument('--lora-path', type=str, required=True) | ||
parser.add_argument('--save-target-model', type=bool, default=False) | ||
|
||
args = parser.parse_args() | ||
|
||
apply_lora(args.base_model_path, args.target_model_path, args.lora_path, | ||
args.save_target_model) |
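For reference, a minimal sketch of calling `apply_lora` from Python rather than the command line; the paths below are placeholders, not paths taken from this commit:

from chatgpt.models.apply_lora import apply_lora

# Placeholder paths -- substitute your own base model and LoRA adapter.
model, tokenizer = apply_lora(
    base_model_path='path/to/llama-7b',
    lora_path='path/to/lora-adapter',
    load_8bit=False,
    target_model_path='path/to/merged-model',
    save_target_model=True,
)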
@@ -0,0 +1,70 @@
""" | ||
Helpers to support streaming generate output. | ||
Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py | ||
""" | ||
import traceback | ||
from queue import Queue | ||
from threading import Thread | ||
|
||
import transformers | ||
|
||
|
||
class Stream(transformers.StoppingCriteria): | ||
def __init__(self, callback_func=None): | ||
self.callback_func = callback_func | ||
|
||
def __call__(self, input_ids, scores) -> bool: | ||
if self.callback_func is not None: | ||
self.callback_func(input_ids[0]) | ||
return False | ||
|
||
|
||
class Iteratorize: | ||
""" | ||
Transforms a function that takes a callback | ||
into a lazy iterator (generator). | ||
""" | ||
def __init__(self, func, kwargs={}, callback=None): | ||
self.mfunc = func | ||
self.c_callback = callback | ||
self.q = Queue() | ||
self.sentinel = object() | ||
self.kwargs = kwargs | ||
self.stop_now = False | ||
|
||
def _callback(val): | ||
if self.stop_now: | ||
raise ValueError | ||
self.q.put(val) | ||
|
||
def gentask(): | ||
try: | ||
ret = self.mfunc(callback=_callback, **self.kwargs) | ||
except ValueError: | ||
pass | ||
except: | ||
traceback.print_exc() | ||
pass | ||
|
||
self.q.put(self.sentinel) | ||
if self.c_callback: | ||
self.c_callback(ret) | ||
|
||
self.thread = Thread(target=gentask) | ||
self.thread.start() | ||
|
||
def __iter__(self): | ||
return self | ||
|
||
def __next__(self): | ||
obj = self.q.get(True, None) | ||
if obj is self.sentinel: | ||
raise StopIteration | ||
else: | ||
return obj | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
self.stop_now = True |
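The two helpers are meant to be used together: `Stream` fires a callback on every decoding step, and `Iteratorize` turns that callback into a lazy iterator. A minimal streaming sketch, assuming a `model`, `tokenizer`, and `input_ids` already exist (none of these are defined in this diff):

import torch

def generate_with_callback(callback=None, **kwargs):
    # Stream invokes `callback` with the partial sequence at every decoding step.
    kwargs['stopping_criteria'] = transformers.StoppingCriteriaList(
        [Stream(callback_func=callback)])
    with torch.no_grad():
        model.generate(**kwargs)

with Iteratorize(generate_with_callback,
                 kwargs={'input_ids': input_ids, 'max_new_tokens': 64}) as gen:
    for output_ids in gen:
        # `output_ids` is the sequence generated so far (prompt + new tokens).
        print(tokenizer.decode(output_ids), end='\r')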
This file was deleted.
File renamed without changes.
@@ -0,0 +1,171 @@
import argparse
import sys
from typing import Union

import torch

sys.path.append('../../')
from transformers import GenerationConfig

from chatgpt.models.apply_lora import apply_lora


class Prompter(object):
    def __init__(self) -> None:
        self.PROMPT_DICT = {
            'prompt_input':
            ('Below is an instruction that describes a task, paired with an input that provides further context. '
             'Write a response that appropriately completes the request.\n\n'
             '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:'
             ),
            'prompt_no_input':
            ('Below is an instruction that describes a task. '
             'Write a response that appropriately completes the request.\n\n'
             '### Instruction:\n{instruction}\n\n### Response:'),
        }
        self.response_split = '### Response:'

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        response: Union[None, str] = None,
    ):
        prompt_input, prompt_no_input = self.PROMPT_DICT[
            'prompt_input'], self.PROMPT_DICT['prompt_no_input']
        if input is not None:
            prompt_text = prompt_input.format(instruction=instruction,
                                              input=input)
        else:
            prompt_text = prompt_no_input.format(instruction=instruction)

        if response:
            prompt_text = f'{prompt_text}{response}'
        return prompt_text

    def get_response(self, output: str) -> str:
        return output.split(self.response_split)[1].strip()


def args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path',
                        default=None,
                        type=str,
                        required=True,
                        help='Path to pre-trained model or shortcut name')
    parser.add_argument('--lora_model_name_or_path',
                        default=None,
                        type=str,
                        required=True,
                        help='Path to the LoRA adapter weights')
    parser.add_argument('--stop_token',
                        type=str,
                        default=None,
                        help='Token at which text generation is stopped')
    parser.add_argument(
        '--temperature',
        type=float,
        default=1.0,
        help='1.0 has no effect, lower tends toward greedy sampling')
    parser.add_argument(
        '--repetition_penalty',
        type=float,
        default=1.0,
        help='primarily useful for CTRL model; in that case, use 1.2')
    parser.add_argument('--top_k', type=int, default=0)
    parser.add_argument('--top_p', type=float, default=0.9)
    # num_beams must be >= 1; 1 means no beam search.
    parser.add_argument('--num_beams', type=int, default=1)
    parser.add_argument('--prefix',
                        type=str,
                        default='',
                        help='Text added prior to input.')
    parser.add_argument('--padding_text',
                        type=str,
                        default='',
                        help='Deprecated, the use of `--prefix` is preferred.')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='random seed for initialization')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        help='Avoid using CUDA when available')
    parser.add_argument('--num_return_sequences',
                        type=int,
                        default=1,
                        help='The number of samples to generate.')
    parser.add_argument(
        '--load_8bit',
        action='store_true',
        help='Whether to load the model in 8-bit instead of 16/32-bit',
    )
    args = parser.parse_args()

    args.device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    return args


def complete_prompts(model, tokenizer, generation_config, prompter,
                     prompt_text, device):
    inputs = tokenizer(prompt_text, return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=128,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    return prompter.get_response(output)


def main(args):
    generation_config = GenerationConfig(
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_beams=args.num_beams,
        do_sample=True,
        no_repeat_ngram_size=6,
        repetition_penalty=1.8,
        num_return_sequences=args.num_return_sequences)

    model, tokenizer = apply_lora(args.model_name_or_path,
                                  args.lora_model_name_or_path,
                                  load_8bit=args.load_8bit)
    prompter = Prompter()

    instruction_list = [
        'Tell me about alpacas.',
        'Tell me about the president of Mexico in 2019.',
        'Tell me about the king of France in 2019.',
        'List all Canadian provinces in alphabetical order.',
        'Write a Python program that prints the first 10 Fibonacci numbers.',
        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' "
        "instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples "
        "of both three and five print 'FizzBuzz'.",
        "Tell me five words that rhyme with 'shock'.",
        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
        'Count up from 1 to 500.',
    ]
    # testing code for readme
    for instruction in instruction_list:
        prompt_text = prompter.generate_prompt(instruction, input=None)
        result = complete_prompts(model,
                                  tokenizer,
                                  generation_config,
                                  prompter,
                                  prompt_text,
                                  device=args.device)
        print(result)


if __name__ == '__main__':
    args = args_parser()
    main(args)
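As a closing note, a short sketch of what `Prompter.generate_prompt` renders and how `get_response` recovers the answer; the model output below is a made-up string, not a real generation:

prompter = Prompter()
prompt = prompter.generate_prompt('Tell me about alpacas.', input=None)
# The prompt ends with '### Response:', which is also the marker get_response() splits on.

fake_output = prompt + ' Alpacas are domesticated South American camelids.'
print(prompter.get_response(fake_output))
# -> Alpacas are domesticated South American camelids.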