
Commit

1. Fix the Baichuan quantization parameter bug
2. Fix the import typo introduced in the previous version
3. Add web.py (API) with POST-method support for calling Baichuan
shiweijiezero committed Jul 19, 2023
1 parent b417400 commit 8299418
Showing 3 changed files with 167 additions and 39 deletions.
30 changes: 20 additions & 10 deletions uniform_finetune.py
@@ -11,7 +11,7 @@
from transformers import (
LlamaForCausalLM, LlamaTokenizer,
AutoModel, AutoTokenizer, AutoModelForCausalLM,
BloomForCausalLM, BloomTokenizerFast, AutoConfig, BitsAndBytesConfig, )
BloomForCausalLM, BloomTokenizerFast, AutoConfig, BitsAndBytesConfig, GenerationConfig)
from transformers.utils.versions import require_version

from peft import (
@@ -184,9 +184,9 @@ def get_peft_class(peft_type):
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=args.compute_dtype,
bnb_4bit_use_double_quant=args.double_quantization,
bnb_4bit_quant_type=args.quantization_type
bnb_4bit_compute_dtype=None,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)

config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
@@ -201,6 +201,7 @@ def get_peft_class(peft_type):
trust_remote_code=True,
**config_kwargs
)
model.generation_config = GenerationConfig.from_pretrained(args.model_name_or_path)

# Register auto class to save the custom code files.
if hasattr(baichuan_config, "auto_map") and "AutoConfig" in baichuan_config.auto_map:
@@ -410,18 +411,23 @@ def generate_and_tokenize_prompt(data_point):

return tokenized_with_response

model_name = args.model_name_or_path.split('/')[-1]
data_name = "+".join([d.split("/")[-1].strip(".json") for d in args.data])
lr_str = str(args.learning_rate)
output_dir = f"saved_models/{model_name}_{data_name}_{lr_str}/{args.peft_type}"
if args.output_dir == "none":
model_name = args.model_name_or_path.split('/')[-1]
data_name = "+".join([d.split("/")[-1].strip(".json") for d in args.data])
lr_str = str(args.learning_rate)
output_dir = f"saved_models/{model_name}_{data_name}_{lr_str}/{args.peft_type}"
logging_name = f"{model_name}_{data_name}_{lr_str}_{args.peft_type}"
else:
output_dir = args.output_dir
logging_name = f"{output_dir}_{args.peft_type}"

# control logging
if args.report_to == "wandb":
import wandb
wandb.init(
project="Alpaca-CoT",
config=args,
name=f"{model_name}_{data_name}_{lr_str}_{args.peft_type}"
name=logging_name
)

# 2. split dataset
@@ -533,8 +539,12 @@ def generate_and_tokenize_prompt(data_point):
help='The list/str of integrations to report the results and logs to')
parser.add_argument('--quantization_bit', default=None, type=int, help="The number of bits to quantize the model.")
parser.add_argument('--compute_dtype', default="fp16", type=str)
parser.add_argument('--output_dir', default="none", type=str)


args, _ = parser.parse_known_args()
print(args)
# print arguments
for k, v in sorted(vars(args).items()):
print(k, '=', v)

train(args)
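
For context, the fixed 4-bit configuration above boils down to the following standalone load. This is only a sketch: the checkpoint id, compute dtype, and device_map="auto" are illustrative assumptions, not values taken from this commit (training maps devices by LOCAL_RANK instead).

# Standalone sketch of the 4-bit quantized load produced by the fixed config.
# The checkpoint id and compute dtype are assumptions for illustration only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # the commit passes None and relies on the library default
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Chat",  # assumed Baichuan checkpoint
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map="auto",  # single-GPU convenience; training uses a per-rank device map
)
tokenizer = AutoTokenizer.from_pretrained(
    "baichuan-inc/Baichuan-13B-Chat", trust_remote_code=True, use_fast=False
)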
127 changes: 104 additions & 23 deletions utils/tools.py
@@ -1,26 +1,25 @@
import copy
import json
import logging
from typing import List, Optional

import torch
from datasets import load_dataset
from transformers import GenerationConfig, AutoConfig, BitsAndBytesConfig, PreTrainedModel

from peft import (
prepare_model_for_int8_training,
LoraConfig,
get_peft_model,
PeftModel,
TaskType
)
from transformers import GenerationConfig, AutoConfig, BitAndBytesConfig, PreTrainedModel
from transformers.trainer import TRAINER_STATE_NAME
from .config import *
from .device import get_device_map

from transformers.utils.versions import require_version

logging = logging.getLogger(__name__)


def generate_prompt(data_point):
prompt_ = PROMPT_DICT['prompt_input'] if data_point["input"] else PROMPT_DICT['prompt_no_input']
return prompt_.format_map(data_point)
@@ -180,17 +179,27 @@ def get_predict_data(args):


def get_fine_tuned_model(args):
def _get_model_class(llm_type, model_path):
if llm_type not in AVAILABLE_MODEL:
llm_type = "Auto"
return MODEL_CLASSES[llm_type], model_path
else:
load_path = llm_type + "_" + model_path
if llm_type in ['moss']:
load_path = llm_type
return MODEL_CLASSES[llm_type], COMMON_PATH + MODEL_PATHS[load_path]
if (args.model_type == "baichuan"):
def _get_model_class(llm_type):
if llm_type not in AVAILABLE_MODEL:
llm_type = "Auto"
return MODEL_CLASSES[llm_type]

model_path = args.model_path
model_class = _get_model_class(args.model_type)
else:
def _get_model_class(llm_type, model_path):
if llm_type not in AVAILABLE_MODEL:
llm_type = "Auto"
return MODEL_CLASSES[llm_type], model_path
else:
load_path = llm_type + "_" + model_path
if llm_type in ['moss']:
load_path = llm_type
return MODEL_CLASSES[llm_type], COMMON_PATH + MODEL_PATHS[load_path]

model_class, model_path = _get_model_class(args.model_type, args.size)

model_class, model_path = _get_model_class(args.model_type, args.size)
if args.model_type == "chatglm":
model = model_class.model.from_pretrained(model_path,
trust_remote_code=True,
@@ -216,15 +225,83 @@ def _get_model_class(llm_type, model_path):
trust_remote_code=True,
load_in_8bit=False,
torch_dtype=torch.float16,
device_map= get_device_map(model_type="moss", load_in_8bit=True))
device_map=get_device_map(model_type="moss", load_in_8bit=True))

tokenizer = model_class.tokenizer.from_pretrained(model_path,trust_remote_code=True)
tokenizer = model_class.tokenizer.from_pretrained(model_path, trust_remote_code=True)
if args.lora_dir != 'none':
model = PeftModel.from_pretrained(
model,
args.lora_dir,
device_map={"": DEVICE_TYPE}
)
elif args.model_type == "baichuan":
baichuan_config = AutoConfig.from_pretrained(model_path,
trust_remote_code=True, )
tokenizer = model_class.tokenizer.from_pretrained(
model_path,
trust_remote_code=True,
use_fast=False)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 0 # set as the <unk> token

config_kwargs = {}
# Quantization configurations by bitsandbytes
if args.quantization_bit is not None:
if args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0
)

elif args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=None,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)

config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
print("Quantizing model to {} bit.".format(args.quantization_bit))

# `device_map=auto` should be used for inference only
config_kwargs["device_map"] = "auto"

# Load and prepare pretrained models (without valuehead).
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
config=baichuan_config,
torch_dtype=torch.bfloat16 if args.compute_dtype == "bf16" else torch.float16,
low_cpu_mem_usage=True,
trust_remote_code=True,
**config_kwargs
)
model.generation_config = GenerationConfig.from_pretrained(args.model_name_or_path)

# Register auto class to save the custom code files.
if hasattr(baichuan_config, "auto_map") and "AutoConfig" in baichuan_config.auto_map:
baichuan_config.__class__.register_for_auto_class()
if hasattr(baichuan_config, "auto_map") and "AutoTokenizer" in baichuan_config.auto_map:
tokenizer.__class__.register_for_auto_class()
if hasattr(baichuan_config, "auto_map") and "AutoModelForCausalLM" in baichuan_config.auto_map:
model.__class__.register_for_auto_class()

if args.lora_dir != "none":
print("loading LoRA weight")
model = PeftModel.from_pretrained(
model,
args.lora_dir,
device_map={"": DEVICE_TYPE}
)
model.requires_grad_(False)

else:
model = model_class.model.from_pretrained(model_path,
load_in_8bit=False,
@@ -238,7 +315,7 @@ def _get_model_class(llm_type, model_path):
args.lora_dir,
device_map={"": DEVICE_TYPE}
)
model.half()
model.half() if args.quantization_bit is None else model
return model, tokenizer


@@ -272,7 +349,8 @@ def _get_model_class(llm_type, model_path):
else:
lora_model = None

if 'q_proj' in MODEL_LORA_TARGET_MODULES[args.model_type] and 'v_proj' in MODEL_LORA_TARGET_MODULES[args.model_type]:
if 'q_proj' in MODEL_LORA_TARGET_MODULES[args.model_type] and 'v_proj' in MODEL_LORA_TARGET_MODULES[
args.model_type]:
lora_type = 'q_v_proj'
elif 'query_key_value' in MODEL_LORA_TARGET_MODULES[args.model_type]:
lora_type = 'query_key_value'
@@ -289,11 +367,13 @@ def generate_service_prompt(instruction, llm, lora):
return PROMPT_DICT['prompt_format_before'] + instruction + PROMPT_DICT['prompt_format_after']
else:
if llm in ['moss']:
return META_INSTRUCTION.get('moss',"") + PROMPT_DICT['prompt_format_before'] + instruction + PROMPT_DICT['prompt_format_after']
return META_INSTRUCTION.get('moss', "") + PROMPT_DICT['prompt_format_before'] + instruction + PROMPT_DICT[
'prompt_format_after']
return PROMPT_DICT['prompt_format_before'] + instruction + PROMPT_DICT['prompt_format_after']


def get_generation_config(llm):

generation_configs = GenerationConfig(
temperature=GENERATE_CONFIG['temperature'],
top_p=GENERATE_CONFIG['top_p'],
@@ -320,9 +400,9 @@ def prepare_model_for_training(
model: PreTrainedModel,
output_embedding_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"]
# for LLaMA, BLOOM and Falcon settings
) -> PreTrainedModel:

for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
@@ -333,10 +413,11 @@ def prepare_model_for_training(
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
model.config.use_cache = False # turn off when gradient checkpointing is enabled

if hasattr(model, output_embedding_layer_name):
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
@@ -349,4 +430,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

return model
return model
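
For readers calling the new baichuan branch of get_fine_tuned_model() from their own scripts: the diff above reads several attributes off args, including both model_path and model_name_or_path. A minimal sketch of the expected namespace follows; every concrete value is an illustrative assumption.

# Sketch of the argument namespace consumed by the baichuan branch above.
# Field names come from the diff; the values are assumptions for illustration.
from argparse import Namespace

args = Namespace(
    model_type="baichuan",
    model_path="baichuan-inc/Baichuan-13B-Chat",          # used to load the config and tokenizer
    model_name_or_path="baichuan-inc/Baichuan-13B-Chat",  # used to load the model weights
    compute_dtype="fp16",       # "bf16" switches the load to torch.bfloat16
    quantization_bit=4,         # 8, 4, or None to disable bitsandbytes quantization
    lora_dir="none",            # or a directory containing LoRA weights
)
# model, tokenizer = get_fine_tuned_model(args)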
49 changes: 43 additions & 6 deletions web.py
@@ -1,4 +1,6 @@
from fastapi import FastAPI
import datetime

from fastapi import FastAPI, Request
import uvicorn
import time
import json
@@ -12,7 +14,7 @@


parser = argparse.ArgumentParser(description='Process some llm info.')
parser.add_argument('--llm', type=str, default="chatglm", choices=AVAILABLE_MODEL,
parser.add_argument('--model_type', type=str, default="chatglm", choices=AVAILABLE_MODEL,
help='the base structure (not the model) used for model or fine-tuned model')
parser.add_argument('--model_path', type=str, default="7b",
help='the type for base model or the absolute path for fine-tuned model')
@@ -22,12 +24,18 @@
parser.add_argument('--lora_alpha', default=16, type=int)
parser.add_argument('--lora_dropout', default=0.05, type=float)
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed serving')
parser.add_argument('--quantization_bit', default=None, type=int, help="The number of bits to quantize the model.")
parser.add_argument('--compute_dtype', default="fp16", type=str)
args = parser.parse_args()

# GPU count
NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else None
device = torch.device("cuda") if NUM_GPUS>0 else torch.device("cpu")

# load model & tokenizer
model, tokenizer = get_fine_tuned_model(args)
model = model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
if torch.__version__ >= "2" and sys.platform != "win32" and sys.version_info < (3, 11):
model = torch.compile(model)


@@ -36,20 +44,27 @@

def server(instruction):
# 1. generate input
prompt = generate_service_prompt(instruction, args.llm, args.lora_dir)
prompt = generate_service_prompt(instruction, args.model_type, args.lora_dir)
# 2. encoder
generation_config = get_generation_config(args.llm)
generation_config = get_generation_config(args.model_type)
inputs_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE_TYPE)
# 3. generate & decoder
outputs = model.generate(
input_ids=inputs_ids,
generation_config=generation_config
)
res = tokenizer.decode(outputs[0], skip_special_tokens=True)
output = generate_service_output(res, prompt, args.llm, args.lora_dir)
output = generate_service_output(res, prompt, args.model_type, args.lora_dir)
return output


# garbage collection
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

@app.get("/")
def read_root():
return {"Hello": "World"}
@@ -63,5 +78,27 @@ def read_item(query: str):
print(json.dumps(res, ensure_ascii=False))
return res

@app.post("/")
async def create_item(request: Request):
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get("prompt")
messages = []
messages.append({"role": "user", "content": prompt})
response = model.chat(tokenizer, messages)

now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")

answer = {
"response": response,
"status": 200,
"time": time
}
log = "["+ time +"]" + '",prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
return answer

uvicorn.run(app, host="0.0.0.0", port=8410)
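
A minimal client sketch for the POST endpoint added above. It assumes web.py is running locally on port 8410 (as in the uvicorn.run call); the prompt text is just an example.

# Client sketch for the new POST route; the server returns JSON with
# "response", "status", and "time" keys, as defined in create_item() above.
import requests

resp = requests.post(
    "http://127.0.0.1:8410/",
    json={"prompt": "Briefly introduce yourself."},
    timeout=300,
)
data = resp.json()
print(data["response"])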

