Update and rename finetune_glm.py to uniform_finetune.py
PhoebusSi authored Mar 29, 2023
1 parent 21677eb commit acf23ac
Showing 1 changed file with 61 additions and 30 deletions.
finetune_glm.py → uniform_finetune.py (91 changes: 61 additions, 30 deletions)
@@ -16,7 +16,7 @@
                           LlamaForCausalLM, LlamaTokenizer,
                           AutoModel, AutoTokenizer,
                           BloomForCausalLM, BloomTokenizerFast)
-from model_chatglm import ChatGLMForConditionalGeneration, ChatGLMTokenizer
+
 
 from peft import (
     prepare_model_for_int8_training,
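The hunk above ends inside the peft import; prepare_model_for_int8_training and its companions are the building blocks of the 8-bit + LoRA setup this script relies on. A minimal sketch of how these imports are typically combined; the rank, alpha, and target_modules below are illustrative assumptions, not the values the script actually uses:

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

def add_lora_adapters(model):
    # Prepare the int8 base model so gradients flow only through the adapters.
    model = prepare_model_for_int8_training(model)
    lora_config = LoraConfig(
        r=8,                                  # assumed rank, illustration only
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],  # assumption: LLaMA-style module names
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, lora_config)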
@@ -42,8 +42,8 @@
 
     }),
     "chatglm": ModelClass(**{
-        "tokenizer": ChatGLMTokenizer,
-        "model": ChatGLMForConditionalGeneration,
+        "tokenizer": AutoTokenizer, #ChatGLMTokenizer,
+        "model": AutoModel, #ChatGLMForConditionalGeneration,
     }),
     "bloom": ModelClass(**{
         "tokenizer": BloomTokenizerFast,
@@ -57,12 +57,13 @@
 
 # add the custom dataset
 DATA_PATH = {
-    "alpaca": "data/alpaca_data_cleaned.json",
-    "belle": "data/belle_data_cn.json",
-    "alpaca-belle": "data/alpaca_plus_belle_data.json",
-    "cot": "data/CoT_data.json",
-    "alpaca-cot": "data/alcapa_plus_cot.json",
-    "alpaca-belle-cot": "data/alcapa_plus_belle_plus_cot.json"
+    "alpaca": "alpaca_data_cleaned.json",
+    "belle": "/mnt/bn/qingyi-bn-lq/llama/belle-0.5M-cn/belle_data_cn.json",
+    "alpaca-belle": "/mnt/bn/qingyi-bn-lq/llama/belle-0.5M-cn/alpaca_plus_belle_data.json",
+    "cot": "/mnt/bn/qingyi-bn-lq/llama/all_formatted_data/CoT_data.json",
+    "alpaca-cot": "/mnt/bn/qingyi-bn-lq/llama/all_formatted_data/alcapa_plus_cot.json",
+    "alpaca-belle-cot": "/mnt/bn/qingyi-bn-lq/llama/all_formatted_data/alcapa_plus_belle_plus_cot.json",
+    "belle1.5m": "/mnt/bn/qingyi-bn-lq/llama/all_formatted_data/belle_data1.5M_cn.json.json"
 }
 
 PROMPT_DICT = {
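All of the datasets registered in DATA_PATH are assumed to share the Alpaca-style instruction/input/output schema: later code reads data_point["output"] directly and builds the prompt from the remaining fields via PROMPT_DICT. A hypothetical record, purely for illustration:

# Hypothetical record showing the assumed schema of the JSON files in DATA_PATH.
example_record = {
    "instruction": "Classify the sentiment of the sentence.",
    "input": "I really enjoyed this movie.",
    "output": "positive",
}
# Each file is then expected to be a JSON array of such records.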
@@ -104,10 +105,18 @@ def get_model_class(model_type):
 
     model_class = get_model_class(args.model_type)
 
-    model = model_class.model.from_pretrained(args.model_name_or_path,
-                                              load_in_8bit=True,
-                                              device_map=device_map)
-    tokenizer = model_class.tokenizer.from_pretrained(args.model_name_or_path)   # default add_eos_token=False
+    if args.model_type == "chatglm":
+        # chatglm can not set load_in_8bit=True: ChatGLMForConditionalGeneration does not support gradient checkpointing.
+        model = model_class.model.from_pretrained(args.model_name_or_path,
+                                                  trust_remote_code=True,
+                                                  device_map=device_map)
+        tokenizer = model_class.tokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)   # default add_eos_token=False
+    else:
+        model = model_class.model.from_pretrained(args.model_name_or_path,
+                                                  load_in_8bit=True,
+                                                  device_map=device_map)
+
+        tokenizer = model_class.tokenizer.from_pretrained(args.model_name_or_path)   # default add_eos_token=False
 
     # llama has no pad_id, maybe copy the stanford_alpaca's handling ?
     if args.model_type == 'llama':
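The new branch above loads ChatGLM through the Auto classes with trust_remote_code=True, because the ChatGLM modeling code ships inside the Hub checkpoint rather than with this repository, and it skips load_in_8bit for the reason given in the inline comment. A standalone sketch of that loading path; "THUDM/chatglm-6b" and device_map="auto" are assumptions for illustration:

from transformers import AutoModel, AutoTokenizer

model_name_or_path = "THUDM/chatglm-6b"   # assumed checkpoint name

# trust_remote_code=True lets transformers run the custom ChatGLM classes
# published with the checkpoint instead of classes bundled in this repo.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path,
                                  trust_remote_code=True,
                                  device_map="auto")   # no load_in_8bit here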
@@ -136,32 +145,55 @@ def train(args):
     # 1. load data & model_class
     data, model, tokenizer = get_data_model(args)
 
-    def tokenize(prompt):
-        result = tokenizer(prompt,
-                           truncation=True,
-                           max_length=args.cutoff_len,
-                           # padding="max_length",
-                           padding=False,
-                           )
+    if "chatglm" in args.model_type:
+        def prompt_tokenize(prompt):
+            input_ids = tokenizer.encode(prompt)
+            return {
+                "input_ids": input_ids,
+                "labels": copy.deepcopy(input_ids)
+            }
+        def completion_tokenize(completion):
+            if completion[-4:] == '</s>':
+                input_ids = tokenizer.encode(completion[:-4]) #, add_special_tokens=False)
+            else:
+                input_ids = tokenizer.encode(completion) #, add_special_tokens=False)
+            return {
+                "input_ids": input_ids,
+                "labels": copy.deepcopy(input_ids)
+            }
+    else:
+        def tokenize(prompt):
+            result = tokenizer(prompt,
+                               truncation=True,
+                               max_length=args.cutoff_len,
+                               # padding="max_length",
+                               padding=False,
+                               )
 
-        return {
-            "input_ids": result["input_ids"],
-            "attention_mask": result["attention_mask"],
-            "labels": copy.deepcopy(result["input_ids"])
-        }
+            return {
+                "input_ids": result["input_ids"],
+                "attention_mask": result["attention_mask"],
+                "labels": copy.deepcopy(result["input_ids"])
+            }
 
 
     def generate_and_tokenize_prompt(data_point):
         prompt_no_resp = generate_prompt(data_point)
-        tokenized_result = tokenize(prompt_no_resp)
+        if "chatglm" in args.model_type:
+            tokenized_result = prompt_tokenize(prompt_no_resp)
+        else:
+            tokenized_result = tokenize(prompt_no_resp)
         source_len = len(tokenized_result['input_ids'])
 
         prompt_with_response = prompt_no_resp + " " + data_point["output"]
 
-        if "llama" in args.model_type:
-            prompt_with_response += " " + tokenizer.eos_token
+        # if "llama" in args.model_type:
+        prompt_with_response += " " + tokenizer.eos_token
 
-        tokenized_with_response = tokenize(prompt_with_response)
+        if "chatglm" in args.model_type:
+            tokenized_with_response = completion_tokenize(prompt_with_response)
+        else:
+            tokenized_with_response = tokenize(prompt_with_response)
 
         tokenized_with_response["labels"] = [IGNORE_INDEX] * source_len + tokenized_with_response["labels"][source_len:]
 
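The last context line above is the core of the loss masking: labels for the prompt tokens are overwritten with IGNORE_INDEX so that only the response tokens contribute to the loss. A toy illustration with made-up token ids (IGNORE_INDEX is assumed to be -100, the value ignored by PyTorch's cross-entropy loss):

IGNORE_INDEX = -100   # assumed to match the constant defined elsewhere in the script

prompt_ids = [5, 8, 13, 21]            # tokenized prompt without the response
full_ids   = [5, 8, 13, 21, 34, 55]    # tokenized prompt + " " + response (+ eos)

source_len = len(prompt_ids)
labels = [IGNORE_INDEX] * source_len + full_ids[source_len:]
print(labels)   # [-100, -100, -100, -100, 34, 55] -> loss only on the response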
@@ -258,4 +290,3 @@ def generate_and_tokenize_prompt(data_point):
     print(args)
 
     train(args)
-