forked from FLock-io/testnet-training-node-quickstart
demo.py
import os
from dataclasses import dataclass

import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig

from dataset import SFTDataCollator, SFTDataset
from merge import merge_lora_to_base_model
from utils.constants import model2template
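
# demo.py: fine-tune a causal language model with QLoRA (4-bit quantized base
# model plus LoRA adapters) on a local JSONL dataset via TRL's SFTTrainer,
# then save the resulting adapter to the "outputs" directory.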

@dataclass
class LoraTrainingArguments:
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    num_train_epochs: int
    lora_rank: int
    lora_alpha: int
    lora_dropout: float  # dropout probability, e.g. 0.1

def train_lora(
    model_id: str, context_length: int, training_args: LoraTrainingArguments
):
    assert model_id in model2template, f"model_id {model_id} not supported"

    # Apply LoRA only to the attention query and value projection matrices
    lora_config = LoraConfig(
        r=training_args.lora_rank,
        target_modules=[
            "q_proj",
            "v_proj",
        ],
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
        task_type="CAUSAL_LM",
    )

    # Load the base model in 4-bit (NF4) so fine-tuning runs as QLoRA
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Training configuration; batch size, accumulation, and epochs come from
    # the LoraTrainingArguments passed in
    sft_config = SFTConfig(
        per_device_train_batch_size=training_args.per_device_train_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        warmup_steps=100,
        learning_rate=2e-4,
        bf16=True,
        logging_steps=20,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        remove_unused_columns=False,
        num_train_epochs=training_args.num_train_epochs,
        max_seq_length=context_length,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map={"": 0},
        token=os.environ["HF_TOKEN"],
    )

    # Load dataset
    dataset = SFTDataset(
        file="demo_data.jsonl",
        tokenizer=tokenizer,
        max_seq_length=context_length,
        template=model2template[model_id],
    )

    # Define trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=sft_config,
        peft_config=lora_config,
        data_collator=SFTDataCollator(tokenizer, max_seq_length=context_length),
    )

    # Train model
    trainer.train()

    # Save model
    trainer.save_model("outputs")

    # Remove checkpoint folders
    os.system("rm -rf outputs/checkpoint-*")

    # Upload LoRA weights and tokenizer
    print("Training Completed.")