Commit

Fix/trl-dependency-and-training-args (FLock-io#8)
* chore: upgrade trl

* fix: replace TrainingArguments with trl's SFTConfig
nickcom007 authored Jun 6, 2024
1 parent 534d606 commit 50c99ac
Showing 2 changed files with 5 additions and 7 deletions.
demo.py: 10 changes (4 additions, 6 deletions)
@@ -3,9 +3,8 @@

import torch
from peft import LoraConfig
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig, TrainingArguments)
-from trl import SFTTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from trl import SFTTrainer, SFTConfig

from dataset import SFTDataCollator, SFTDataset
from merge import merge_lora_to_base_model
@@ -44,7 +43,7 @@ def train_and_merge(
bnb_4bit_compute_dtype=torch.bfloat16,
)

-training_args = TrainingArguments(
+training_args = SFTConfig(
per_device_train_batch_size=training_args.per_device_train_batch_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
warmup_steps=100,
@@ -55,6 +54,7 @@ def train_and_merge(
optim="paged_adamw_8bit",
remove_unused_columns=False,
num_train_epochs=training_args.num_train_epochs,
+max_seq_length=context_length,
)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
@@ -81,9 +81,7 @@ def train_and_merge(
train_dataset=dataset,
args=training_args,
peft_config=lora_config,
-packing=True,
data_collator=SFTDataCollator(tokenizer, max_seq_length=context_length),
-max_seq_length=context_length,
)

# Train model
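For reference, a minimal sketch of the migrated pattern under trl >= 0.9, where SFT-specific options such as max_seq_length move from the SFTTrainer constructor into SFTConfig (which extends transformers.TrainingArguments). The model and dataset names below are illustrative placeholders, not part of this repository:

# Minimal sketch, assuming trl >= 0.9.3; model/dataset choices are illustrative only.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "gpt2"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Placeholder dataset with a "text" column
dataset = load_dataset("imdb", split="train[:1%]")

# SFT-specific options now live in SFTConfig rather than being passed
# directly to SFTTrainer.
training_args = SFTConfig(
    output_dir="outputs",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    max_seq_length=512,
    dataset_text_field="text",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()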
requirements.txt: 2 changes (1 addition, 1 deletion)
@@ -2,6 +2,6 @@ torch>=1.13.1
transformers>=4.37.2
peft>=0.10.0
loguru
-trl>=0.8.1
+trl>=0.9.3
bitsandbytes
pyyaml
