
Commit

Update dpo_llama3.py
marcopoli authored May 10, 2024
1 parent d5f3993 commit fec06d1
Showing 1 changed file with 3 additions and 3 deletions.
model_adaptation/dpo_llama3.py: 6 changes (3 additions, 3 deletions)
@@ -14,7 +14,7 @@
 max_seq_length = 8192
 dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
 load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
-model_name = "Meta-Llama-3-8B-Instruct"
+model_name = "swap-uniba/LLaMAntino-3-ANITA-8B-Inst-DPO-ITA"

 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name = model_name,
@@ -106,7 +106,7 @@ def apply_dpo_template(example):
 dpo_trainer.train()

 ### SAVE the NEW MODEL ###
-new_model = model_name+"_adapters"
+new_model = model_name+"_DPO_adapters"
 dpo_trainer.save_model(new_model)

 # Reload model in FP16 and merge it with LoRA weights
@@ -115,7 +115,7 @@ def apply_dpo_template(example):
     low_cpu_mem_usage=True,
     return_dict=True,
     torch_dtype=torch.bfloat16,
-    device_map="balanced"
+    device_map="auto"
 )
 model = PeftModel.from_pretrained(base_model, new_model)
 model = model.merge_and_unload()
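Below the diff, a minimal sketch of the reload-and-merge step as it stands after this commit. It assumes the call elided above is transformers' AutoModelForCausalLM.from_pretrained (only its keyword arguments appear in the hunk); the model id and adapter suffix are taken from the diff, and the rest is illustrative rather than the repository's exact code.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model_name = "swap-uniba/LLaMAntino-3-ANITA-8B-Inst-DPO-ITA"
new_model = model_name + "_DPO_adapters"  # adapter directory written by dpo_trainer.save_model

# Reload the base weights; device_map="auto" lets accelerate place layers on the
# available devices instead of the previous "balanced" strategy.
base_model = AutoModelForCausalLM.from_pretrained(  # assumed loader, not shown in the hunk
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Attach the DPO LoRA adapters saved above and fold them into the base weights.
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()

Switching from device_map="balanced" to "auto" defers layer placement to accelerate's default strategy, which also covers single-GPU and CPU-offload setups without manual tuning.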
