modularize peft configure and load
hepengfe committed Apr 7, 2023
1 parent 9e1b0cd commit eb1d450
Showing 1 changed file with 58 additions and 53 deletions.
peft_trainer.py: 111 changes (58 additions, 53 deletions)
@@ -80,12 +80,45 @@ def __init__(self, arguments):



+ self.configure_n_convert_peft()


+ if self.arguments.trainable_params_percentage and not self.arguments.mode in ["compactor","fine_tuning"]:
+ # not check compactor
+ assert abs(self.arguments.trainable_params_percentage - cur_trainable_params_percentage) < 0.002, f"trainable_params_percentage {self.arguments.trainable_params_percentage} is not matched with cur_trainable_params_percentage {cur_trainable_params_percentage}"

+ # deactivate


+ # NOTE: set lm head trainable again
+ # if hasattr(self.model, "lm_head"): # peft model
+ # self.model.lm_head.weight.requires_grad = True
+ # else: # adapter model
+ # self.model.heads["seq2seq-head-sst-2"][0].weight.requires_grad
+ # self.arguments.run_name += f"lm_head_trainable"

+ time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ # self.model = self.model_cache
+ if self.arguments.trainable_params_percentage:
+ self.arguments.run_name += f"_trainable_params_percentage_{self.arguments.trainable_params_percentage}"
+ # prepare runname before passing to trainer


+ if self.arguments.num_training_tasks:
+ self.arguments.run_name += f"_num_training_tasks_{self.arguments.num_training_tasks}"
+ self.arguments.run_name += f"_{time}"

+ # import pdb; pdb.set_trace()
+ # print("check run_name", self.arguments.run_name)
+ self.set_up_hf_trainer()
+ self.tokenizer = self.tokenizer


+ def configure_n_convert_peft(self):
+ # model loading procedure:
+ # 1. load model from model_names_or_path (self.load_model())
+ # 2. not satisfied with peft, load model from self.model_cache and convert again. self.model = deepcopy(self.model_cache)
if arguments.mode == "prompt_tuning":
if self.arguments.mode == "prompt_tuning":
cur_prompt_len = 1
assert self.arguments.trainable_params_percentage is not None or self.arguments.prompt_len > 0, "either prompt_len or trainable_params_percentage should be set"
config = PromptTuningConfig(
@@ -118,7 +151,7 @@ def __init__(self, arguments):
print("trainable params percentage is {}".format(cur_trainable_params_percentage))
cur_prompt_len += 1
self.arguments.run_name += "_prompt_len_{}".format(cur_prompt_len-1)
elif arguments.mode == "prefix_tuning":
elif self.arguments.mode == "prefix_tuning":
from transformers.adapters import PrefixTuningConfig
# from peft import PrefixTuningConfig

@@ -143,7 +176,7 @@ def __init__(self, arguments):



elif arguments.mode == "lora":
elif self.arguments.mode == "lora":
# peft package
cur_lora_r = 15 if self.arguments.lora_r is None else self.arguments.lora_r
assert self.arguments.trainable_params_percentage is not None or self.arguments.lora_r > 0, "either lora_r or trainable_params_percentage should be set"
@@ -200,7 +233,7 @@ def __init__(self, arguments):
self.arguments.run_name += "_lora_r_" + str(cur_lora_r)
if self.arguments.lora_modules:
self.arguments.run_name += "_lora_modules_" + self.arguments.lora_modules
elif arguments.mode == "ia3":
elif self.arguments.mode == "ia3":
from transformers.adapters import IA3Config
cur_lora_r = 15 if self.arguments.lora_r is None else self.arguments.lora_r
assert self.arguments.trainable_params_percentage is not None or self.arguments.lora_r > 0, "either lora_r or trainable_params_percentage should be set"
@@ -220,14 +253,14 @@ def __init__(self, arguments):
self.arguments.run_name += "_lora_r_" + str(cur_lora_r)


elif arguments.mode in ["adapter", "compactor"]:
elif self.arguments.mode in ["adapter", "compactor"]:
cur_reduction_factor = 64 if self.arguments.reduction_factor is None else self.arguments.reduction_factor
assert self.arguments.trainable_params_percentage is not None or self.arguments.reduction_factor > 0, "either reduction_factor or trainable_params_percentage should be set"

from transformers.adapters import AdapterConfig, HoulsbyConfig, CompacterConfig
# check existing adapter and remove them
# config = AdapterConfig()
if arguments.mode == "adapter":
if self.arguments.mode == "adapter":
config = HoulsbyConfig(reduction_factor=cur_reduction_factor)
else:
config = CompacterConfig(reduction_factor=cur_reduction_factor,
@@ -236,7 +269,7 @@ def __init__(self, arguments):
cur_trainable_params_percentage = self.convert_to_peft(config)
while self.arguments.trainable_params_percentage and cur_trainable_params_percentage < self.arguments.trainable_params_percentage:
cur_reduction_factor /=1.01
if arguments.mode == "adapter":
if self.arguments.mode == "adapter":
config = HoulsbyConfig(reduction_factor=cur_reduction_factor)
else:
config = CompacterConfig(reduction_factor=cur_reduction_factor,
@@ -245,66 +278,66 @@ def __init__(self, arguments):
print(f"cur_trainable_params_percentage: {cur_trainable_params_percentage}, cur_reduction_factor: {cur_reduction_factor}")
# only keep 4 digits for reduction factor
self.arguments.run_name += f"_reduction_factor_{cur_reduction_factor:.4f}"
if arguments.mode == "compactor":
if self.arguments.mode == "compactor":
self.arguments.run_name += f"_phm_dim_{self.arguments.phm_dimension}"
elif arguments.mode == "parallel_adapter":
elif self.arguments.mode == "parallel_adapter":
from transformers.adapters import ParallelConfig
config = ParallelConfig(reduction_factor= self.arguments.reduction_factor)
cur_trainable_params_percentage = self.convert_to_peft(config)
print(f"cur_trainable_params_percentage: {cur_trainable_params_percentage}")
self.arguments.run_name += f"_reduction_factor_{self.arguments.reduction_factor:.4f}"
elif arguments.mode == "embedding_tuning":
elif self.arguments.mode == "embedding_tuning":
self.convert_to_embedding_tuning()
if self.num_soft_tokens > 0:
self.num_soft_tokens = 0
print("num_soft_tokens is set to 0 for embedding tuning mode")
elif arguments.mode == "bitfit":
elif self.arguments.mode == "bitfit":
self.convert_to_peft()
elif arguments.mode == "fine_tuning":
elif self.arguments.mode == "fine_tuning":
# no converting needed
pass
elif arguments.mode == "lm_head_tuning":
elif self.arguments.mode == "lm_head_tuning":
for param in self.model.parameters():
param.requires_grad = False
# lm head takes almost 10% of the parameters
for name, module in self.model.named_modules():
if "lm_head" in name:
module.weight.requires_grad = True
elif arguments.mode == "layer_tuning":
elif self.arguments.mode == "layer_tuning":

for param in self.model.parameters():
param.requires_grad = False

layers = []
# NOTE: we only fine-tune attention weights for now
if arguments.layer_name == "first_encoder_layer":
if self.arguments.layer_name == "first_encoder_layer":
layers.append(self.model.encoder.block[0].layer[0])
elif arguments.layer_name == "last_encoder_layer":
elif self.arguments.layer_name == "last_encoder_layer":
layers.append(self.model.encoder.block[-1].layer[0])
elif arguments.layer_name == "first_decoder_layer":
elif self.arguments.layer_name == "first_decoder_layer":
layers.append(self.model.decoder.block[0].layer[0])
elif arguments.layer_name == "last_decoder_layer":
elif self.arguments.layer_name == "last_decoder_layer":
layers.append(self.model.decoder.block[-1].layer[0])
elif arguments.layer_name == "custom":
elif self.arguments.layer_name == "custom":
# all decoder layer
modules = self.model.decoder.block
for m in modules:
layers.append(m.layer[0])
elif arguments.layer_name == "custom2":
elif self.arguments.layer_name == "custom2":
# all decoder layer
modules = self.model.decoder.block
for m in modules:
layers.append(m.layer[0])
else:
raise NotImplementedError(f"layer_name {arguments.layer_name} is not implemented")
raise NotImplementedError(f"layer_name {self.arguments.layer_name} is not implemented")

for l in layers:
for name, module in l.named_modules():
# if "selfattention" in name.lower():
if hasattr(module, "weight"):
module.weight.requires_grad = True
print("activate gradient for ", name)
elif arguments.mode == "lora+adapter":
elif self.arguments.mode == "lora+adapter":
from transformers.adapters import AdapterConfig, HoulsbyConfig, CompacterConfig


@@ -374,7 +407,7 @@ def __init__(self, arguments):


# invalid
elif arguments.mode == "unipelt":
elif self.arguments.mode == "unipelt":
from transformers.adapters import UniPELTConfig, PrefixTuningConfig, PfeifferConfig, LoRAConfig, HoulsbyConfig
gating = False
reset_peft=False
@@ -404,37 +437,9 @@ def __init__(self, arguments):


else:
raise NotImplementedError(f"mode {arguments.mode} is not implemented")
if self.arguments.trainable_params_percentage and not self.arguments.mode in ["compactor","fine_tuning"]:
# not check compactor
assert abs(self.arguments.trainable_params_percentage - cur_trainable_params_percentage) < 0.002, f"trainable_params_percentage {self.arguments.trainable_params_percentage} is not matched with cur_trainable_params_percentage {cur_trainable_params_percentage}"

# deactivate
raise NotImplementedError(f"mode {self.arguments.mode} is not implemented")



- # NOTE: set lm head trainable again
- # if hasattr(self.model, "lm_head"): # peft model
- # self.model.lm_head.weight.requires_grad = True
- # else: # adapter model
- # self.model.heads["seq2seq-head-sst-2"][0].weight.requires_grad
- # self.arguments.run_name += f"lm_head_trainable"

- time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
- # self.model = self.model_cache
- if self.arguments.trainable_params_percentage:
- self.arguments.run_name += f"_trainable_params_percentage_{self.arguments.trainable_params_percentage}"
- # prepare runname before passing to trainer


- if self.arguments.num_training_tasks:
- self.arguments.run_name += f"_num_training_tasks_{self.arguments.num_training_tasks}"
- self.arguments.run_name += f"_{time}"

- # import pdb; pdb.set_trace()
- # print("check run_name", self.arguments.run_name)
- self.set_up_hf_trainer()
- self.tokenizer = self.tokenizer


def set_up_hf_trainer(self):
del self.model_cache
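The pattern that configure_n_convert_peft now encapsulates, visible across the hunks above, is a small linear search: pick the capacity knob for the chosen mode (prompt_len, lora_r, or reduction_factor), convert the model, and keep nudging the knob until the measured trainable-parameter fraction reaches arguments.trainable_params_percentage, then assert the two match within 0.002 (compactor and fine_tuning are exempt). Below is a minimal, self-contained sketch of that search under stated assumptions: ToyLoRAModel, count_trainable_percentage, and search_rank are illustrative stand-ins, not names from this repository, which instead re-runs self.convert_to_peft(config) on the real model.

```python
# Sketch of the budget-matching loop from the diff, on a toy model.
# Everything here (ToyLoRAModel, count_trainable_percentage, search_rank)
# is hypothetical scaffolding for illustration only.
import torch.nn as nn


def count_trainable_percentage(model: nn.Module) -> float:
    # Fraction of parameters with requires_grad=True, as the trainer measures it.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable / total


class ToyLoRAModel(nn.Module):
    """Frozen backbone plus a rank-r adapter pair, standing in for a PEFT-converted model."""

    def __init__(self, rank: int, hidden: int = 512):
        super().__init__()
        self.backbone = nn.Linear(hidden, hidden)  # frozen "pretrained" weights
        self.backbone.weight.requires_grad_(False)
        self.backbone.bias.requires_grad_(False)
        self.lora_a = nn.Linear(hidden, rank, bias=False)  # trainable low-rank factors
        self.lora_b = nn.Linear(rank, hidden, bias=False)


def search_rank(target: float, max_rank: int = 256) -> tuple:
    # Same shape as the while-loops in the diff: bump the knob,
    # re-convert, re-measure, stop once the budget is reached.
    rank = 1
    cur = count_trainable_percentage(ToyLoRAModel(rank))
    while cur < target and rank < max_rank:
        rank += 1
        cur = count_trainable_percentage(ToyLoRAModel(rank))
    return rank, cur


if __name__ == "__main__":
    target = 0.01  # ask for ~1% trainable parameters
    rank, cur = search_rank(target)
    print(f"rank={rank}, trainable={cur:.4%}")
    # The trainer then asserts the achieved percentage is within 0.002 of the
    # target (skipped for the compactor and fine_tuning modes in the diff).
    assert abs(cur - target) < 0.002
```

The same loop shape covers the prompt-length and reduction-factor searches in the diff; only the knob and its update direction change (cur_reduction_factor is divided by 1.01 each step so the adapter grows until the budget is met).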
