
Commit

Merge pull request #49 from jianzhnie/dev
update alpaca
jianzhnie authored May 23, 2023
2 parents 641ed6d + ba5dd99 commit 3c11ac3
Showing 25 changed files with 805 additions and 344 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -135,3 +135,5 @@ assets/
work_dir/
work_dirs/
prompt_data/
prompt_data
work_dir_lora/
321 changes: 150 additions & 171 deletions README.md

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions chatgpt/models/apply_lora.py
@@ -0,0 +1,87 @@
"""
Apply the LoRA weights on top of a base model.
Usage:
python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B
Dependency:
pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b
"""
import argparse
from typing import Optional, Tuple

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer


def apply_lora(
    base_model_path: str,
    lora_path: str,
    load_8bit: bool = False,
    target_model_path: Optional[str] = None,
    save_target_model: bool = False
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Apply a LoRA adapter to a base model and optionally save the resulting target model.
    Args:
        base_model_path (str): Path to the base model the LoRA adapter will be applied to.
        lora_path (str): Path to the LoRA adapter.
        load_8bit (bool, optional): Whether to load the base model in 8-bit. Defaults to False.
        target_model_path (str, optional): Path where the target model will be saved
            (used when `save_target_model=True`). Defaults to None.
        save_target_model (bool, optional): Whether to save the target model. Defaults to False.
    Returns:
        Tuple[AutoModelForCausalLM, AutoTokenizer]: The model with the LoRA adapter applied and its tokenizer.
    """
# Load the base model and tokenizer
print(f'Loading the base model from {base_model_path}')
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map='auto',
)

# Load the tokenizer
if base_model.config.model_type == 'llama':
        # AutoTokenizer may fail to resolve some LLaMA checkpoints because of the
        # casing of the LlamaTokenizer class name, so load LlamaTokenizer explicitly.
base_tokenizer = LlamaTokenizer.from_pretrained(
base_model_path,
padding_side='right',
use_fast=True,
)
else:
base_tokenizer = AutoTokenizer.from_pretrained(
base_model_path,
padding_side='right',
use_fast=True,
)

# Load the LoRA adapter
print(f'Loading the LoRA adapter from {lora_path}')
model = PeftModel.from_pretrained(
base_model,
lora_path,
torch_dtype=torch.float16,
)

if save_target_model and target_model_path is not None:
print(f'Saving the target model to {target_model_path}')
model.save_pretrained(target_model_path)
base_tokenizer.save_pretrained(target_model_path)

return model, base_tokenizer


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--base-model-path', type=str, required=True)
parser.add_argument('--target-model-path', type=str, required=True)
parser.add_argument('--lora-path', type=str, required=True)
    parser.add_argument('--save-target-model', action='store_true')

args = parser.parse_args()

    apply_lora(args.base_model_path,
               args.lora_path,
               target_model_path=args.target_model_path,
               save_target_model=args.save_target_model)
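
Note: `apply_lora` returns a `PeftModel` with the adapter still attached, so calling `save_pretrained` on it stores only the adapter weights. If a standalone merged checkpoint is wanted instead, peft's `merge_and_unload()` can be folded in before saving. A minimal sketch, assuming a peft version that supports merging LoRA layers; the paths are placeholders and not part of this commit:

from chatgpt.models.apply_lora import apply_lora

# Attach the adapter, then fold its weight deltas into the base model so the
# saved directory is a plain checkpoint that loads without peft.
model, tokenizer = apply_lora('path/to/llama-7b', 'path/to/lora-adapter')
merged = model.merge_and_unload()
merged.save_pretrained('path/to/merged-model')
tokenizer.save_pretrained('path/to/merged-model')
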
70 changes: 70 additions & 0 deletions chatgpt/utils/callbacks.py
@@ -0,0 +1,70 @@
"""
Helpers to support streaming generate output.
Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
"""
import traceback
from queue import Queue
from threading import Thread

import transformers


class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func

def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False


class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        # Avoid a shared mutable default argument.
        self.kwargs = kwargs if kwargs is not None else {}
        self.stop_now = False

def _callback(val):
if self.stop_now:
raise ValueError
self.q.put(val)

        def gentask():
            # Run the wrapped function in a worker thread; _callback pushes each
            # intermediate result onto the queue for the consuming iterator.
            ret = None
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                # Raised by _callback when stop_now is set, to abort generation early.
                pass
            except Exception:
                traceback.print_exc()

            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

self.thread = Thread(target=gentask)
self.thread.start()

def __iter__(self):
return self

def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True
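
For context, a minimal sketch of how `Stream` and `Iteratorize` are typically wired together to stream tokens out of `model.generate`. The two helper functions below are illustrative, not part of this commit, and assume `model`, `tokenizer`, and `input_ids` are already in scope:

import transformers

from chatgpt.utils.callbacks import Iteratorize, Stream


def generate_with_callback(callback=None, **kwargs):
    # Stream invokes `callback` with the partial sequence after every generation step.
    kwargs.setdefault('stopping_criteria', transformers.StoppingCriteriaList())
    kwargs['stopping_criteria'].append(Stream(callback_func=callback))
    model.generate(**kwargs)


def generate_with_streaming(**kwargs):
    # Iteratorize runs generate_with_callback in a background thread and yields
    # each partial sequence as it arrives.
    return Iteratorize(generate_with_callback, kwargs)


with generate_with_streaming(input_ids=input_ids, max_new_tokens=128) as generator:
    for output in generator:
        print(tokenizer.decode(output), end='\r')
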
24 changes: 0 additions & 24 deletions docs/rlhf_dataset.md

This file was deleted.

File renamed without changes.
171 changes: 171 additions & 0 deletions examples/alpaca/generate.py
@@ -0,0 +1,171 @@
import argparse
import sys
from typing import Union

import torch

sys.path.append('../../')
from transformers import GenerationConfig

from chatgpt.models.apply_lora import apply_lora


class Prompter(object):
def __init__(self) -> None:
self.PROMPT_DICT = {
'prompt_input':
('Below is an instruction that describes a task, paired with an input that provides further context. '
'Write a response that appropriately completes the request.\n\n'
'### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:'
),
'prompt_no_input':
('Below is an instruction that describes a task. '
'Write a response that appropriately completes the request.\n\n'
'### Instruction:\n{instruction}\n\n### Response:'),
}
        self.response_split = '### Response:'

def generate_prompt(
self,
instruction: str,
input: Union[None, str] = None,
response: Union[None, str] = None,
):
prompt_input, prompt_no_input = self.PROMPT_DICT[
'prompt_input'], self.PROMPT_DICT['prompt_no_input']
if input is not None:
prompt_text = prompt_input.format(instruction=instruction,
input=input)
else:
prompt_text = prompt_no_input.format(instruction=instruction)

if response:
prompt_text = f'{prompt_text}{response}'
return prompt_text

def get_response(self, output: str) -> str:
        return output.split(self.response_split)[1].strip()


def args_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path',
default=None,
type=str,
required=True,
help='Path to pre-trained model or shortcut name')
parser.add_argument('--lora_model_name_or_path',
default=None,
type=str,
required=True,
                        help='Path to the LoRA adapter weights')
parser.add_argument('--stop_token',
type=str,
default=None,
help='Token at which text generation is stopped')
parser.add_argument(
'--temperature',
type=float,
default=1.0,
        help='1.0 has no effect; lower values tend toward greedy sampling')
parser.add_argument(
'--repetition_penalty',
type=float,
default=1.0,
help='primarily useful for CTRL model; in that case, use 1.2')
parser.add_argument('--top_k', type=int, default=0)
parser.add_argument('--top_p', type=float, default=0.9)
    parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('--prefix',
type=str,
default='',
help='Text added prior to input.')
parser.add_argument('--padding_text',
type=str,
default='',
help='Deprecated, the use of `--prefix` is preferred.')
parser.add_argument('--seed',
type=int,
default=42,
help='random seed for initialization')
parser.add_argument('--no_cuda',
action='store_true',
help='Avoid using CUDA when available')
parser.add_argument('--num_return_sequences',
type=int,
default=1,
help='The number of samples to generate.')
parser.add_argument(
'--load_8bit',
action='store_true',
        help='Whether to load the model in 8-bit instead of float16',
)
args = parser.parse_args()

args.device = torch.device(
'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
return args


def complete_prompts(model, tokenizer, generation_config, prompter,
prompt_text, device):
inputs = tokenizer(prompt_text, return_tensors='pt')
input_ids = inputs['input_ids'].to(device)

with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=128,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
return prompter.get_response(output)


def main(args):
generation_config = GenerationConfig(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
num_beams=args.num_beams,
do_sample=True,
no_repeat_ngram_size=6,
repetition_penalty=1.8,
num_return_sequences=args.num_return_sequences)

model, tokenizer = apply_lora(args.model_name_or_path,
args.lora_model_name_or_path,
load_8bit=args.load_8bit)
prompter = Prompter()

instruction_list = [
'Tell me about alpacas.',
'Tell me about the president of Mexico in 2019.',
'Tell me about the king of France in 2019.',
'List all Canadian provinces in alphabetical order.',
'Write a Python program that prints the first 10 Fibonacci numbers.',
"Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' \
instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples \
of both three and five print 'FizzBuzz'.",
"Tell me five words that rhyme with 'shock'.",
"Translate the sentence 'I have no mouth but I must scream' into Spanish.",
'Count up from 1 to 500.',
]
    # Test prompts for the README examples.
for instruction in instruction_list:
prompt_text = prompter.generate_prompt(instruction, input=None)
result = complete_prompts(model,
tokenizer,
generation_config,
prompter,
prompt_text,
device=args.device)
print(result)


if __name__ == '__main__':
args = args_parser()
main(args)
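
A typical invocation of this script, using only the flags defined above (the model and adapter paths are placeholders):

python examples/alpaca/generate.py \
    --model_name_or_path path/to/llama-7b \
    --lora_model_name_or_path path/to/alpaca-lora-adapter \
    --load_8bit
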
