feat: add a demo training script
Nick committed May 8, 2024
1 parent ede54fe commit 927f7f4
Showing 6 changed files with 312 additions and 0 deletions.
128 changes: 128 additions & 0 deletions dataset.py
@@ -0,0 +1,128 @@
import json
from typing import Any, Dict, List

import torch
from loguru import logger
from torch.utils.data import Dataset


class UnifiedSFTDataset(Dataset):
    def __init__(self, file, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.system_format = "<bos>"
        self.user_format = (
            "<start_of_turn>user\n{content}<end_of_turn>\n<start_of_turn>model\n"
        )
        # Gemma closes each model turn with <end_of_turn>; <|eot_id|> is Llama-3's
        # end-of-turn marker and is not a special token for the Gemma tokenizer.
        self.assistant_format = "{content}<end_of_turn>\n"
        self.system = None

        self.max_seq_length = max_seq_length
        logger.info("Loading data: {}".format(file))
        with open(file, "r", encoding="utf8") as f:
            data_list = f.readlines()
        logger.info("There are {} examples in the dataset".format(len(data_list)))
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        data = self.data_list[index]
        data = json.loads(data)
        input_ids, target_mask = [], []

        # setting system information
        if self.system_format is not None:
            system = data["system"].strip() if "system" in data.keys() else self.system

            if system is not None:
                system_text = self.system_format.format(content=system)
                input_ids = self.tokenizer.encode(system_text, add_special_tokens=False)
                target_mask = [0] * len(input_ids)

        conversations = data["conversations"]

        for i in range(0, len(conversations) - 1, 2):
            if (
                conversations[i]["role"] != "user"
                or conversations[i + 1]["role"] != "assistant"
            ):
                raise ValueError("The role order of the conversation is not correct")
            human = conversations[i]["content"].strip()
            assistant = conversations[i + 1]["content"].strip()

            human = self.user_format.format(
                content=human, stop_token=self.tokenizer.eos_token
            )
            assistant = self.assistant_format.format(
                content=assistant, stop_token=self.tokenizer.eos_token
            )

            input_tokens = self.tokenizer.encode(human, add_special_tokens=False)
            output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False)

            input_ids += input_tokens + output_tokens
            target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)

        assert len(input_ids) == len(target_mask)

        input_ids = input_ids[: self.max_seq_length]
        target_mask = target_mask[: self.max_seq_length]
        attention_mask = [1] * len(input_ids)
        assert len(input_ids) == len(target_mask) == len(attention_mask)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "target_mask": target_mask,
        }
        return inputs


class SFTDataCollator(object):
    def __init__(self, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        # find the maximum sequence length in the batch
        lengths = [len(x["input_ids"]) for x in batch if x["input_ids"] is not None]
        # pad to the longest sequence in the batch, capped at max_seq_length
        batch_max_len = min(max(lengths), self.max_seq_length)
        # batch_max_len = self.max_seq_length

        input_ids_batch, attention_mask_batch, target_mask_batch = [], [], []
        # pad every example to batch_max_len and truncate to max_seq_length
        for x in batch:
            input_ids = x["input_ids"]
            attention_mask = x["attention_mask"]
            target_mask = x["target_mask"]
            if input_ids is None:
                logger.info("skipping an example whose input_ids is None")
                continue
            padding_len = batch_max_len - len(input_ids)
            # padding
            input_ids = input_ids + [self.pad_token_id] * padding_len
            attention_mask = attention_mask + [0] * padding_len
            target_mask = target_mask + [0] * padding_len
            # truncate
            input_ids = input_ids[: self.max_seq_length]
            attention_mask = attention_mask[: self.max_seq_length]
            target_mask = target_mask[: self.max_seq_length]

            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
            target_mask_batch.append(target_mask)

        # convert the lists to tensors to obtain the final model inputs
        input_ids_batch = torch.tensor(input_ids_batch, dtype=torch.long)
        attention_mask_batch = torch.tensor(attention_mask_batch, dtype=torch.long)
        target_mask_batch = torch.tensor(target_mask_batch, dtype=torch.long)

        labels = torch.where(target_mask_batch == 1, input_ids_batch, -100)
        inputs = {
            "input_ids": input_ids_batch,
            "attention_mask": attention_mask_batch,
            "labels": labels,
        }
        return inputs
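
Note: UnifiedSFTDataset reads one JSON object per line and expects an optional "system" field plus a "conversations" list of alternating user/assistant turns. The commit lists demo_data.jsonl among the changed files but its contents are not shown here, so the record below is only an illustrative sketch of a line that would parse correctly:

import json

# Hypothetical example record -- the real demo_data.jsonl is not shown in this diff.
record = {
    "system": "You are a helpful assistant.",
    "conversations": [
        {"role": "user", "content": "What is LoRA?"},
        {"role": "assistant", "content": "LoRA fine-tunes a model by training small low-rank adapter matrices."},
    ],
}
with open("demo_data.jsonl", "w", encoding="utf8") as f:
    f.write(json.dumps(record) + "\n")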
86 changes: 86 additions & 0 deletions demo.py
@@ -0,0 +1,86 @@
import os

import torch
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, TrainingArguments)
from trl import SFTTrainer

from dataset import SFTDataCollator, UnifiedSFTDataset
from merge import merge_lora_to_base_model
from utils import load_model, load_tokenizer

lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    task_type="CAUSAL_LM",
)


model_id = "google/gemma-2b"
# Load the model in 4-bit to do QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    warmup_steps=2,
    max_steps=10,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=1,
    output_dir="outputs",
    optim="paged_adamw_8bit",
    remove_unused_columns=False,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    token=os.environ["HF_TOKEN"],
)

# Load dataset
dataset = UnifiedSFTDataset(
    file="demo_data.jsonl",
    tokenizer=tokenizer,
    max_seq_length=512,
)

# Define trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    peft_config=lora_config,
    # the dataset is already tokenized turn by turn and padded by SFTDataCollator,
    # so example packing is not used here
    packing=False,
    data_collator=SFTDataCollator(tokenizer, max_seq_length=512),
    max_seq_length=512,
)

# Train model
trainer.train()

# save model
trainer.save_model("outputs")

# merge lora to base model
merge_lora_to_base_model(
    model_name_or_path="google/gemma-2b",
    adapter_name_or_path="outputs",
    save_path="merged_model",
)
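
As a quick sanity check after training, the merged model can be loaded back with transformers and prompted in the same Gemma turn format that UnifiedSFTDataset builds. This is a minimal sketch, assuming merge_lora_to_base_model wrote both the merged weights and the tokenizer files to merged_model/ (if it saves only the weights, load the tokenizer from google/gemma-2b instead):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: "merged_model" contains both the merged weights and the tokenizer files.
tokenizer = AutoTokenizer.from_pretrained("merged_model")
model = AutoModelForCausalLM.from_pretrained(
    "merged_model", torch_dtype=torch.bfloat16, device_map={"": 0}
)

# Same turn format as UnifiedSFTDataset: <bos>, a user turn, then an open model turn.
prompt = "<bos><start_of_turn>user\nHello, who are you?<end_of_turn>\n<start_of_turn>model\n"
inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))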
