Commit

percentage adjustment
hepengfe committed Mar 28, 2023
1 parent ae9c744 commit 7cab11c
Showing 1 changed file with 32 additions and 17 deletions.
peft_trainer.py (49 changes: 32 additions & 17 deletions)
@@ -93,7 +93,7 @@ def __init__(self, arguments):

)
cur_trainable_params_percentage = self.convert_to_peft(config)
while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.000001:
while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.00001:
if cur_trainable_params_percentage > self.arguments.trainable_params_percentage:
cur_prompt_len -= 1
else:
@@ -116,17 +116,21 @@ def __init__(self, arguments):
self.arguments.run_name += "_prompt_len_{}".format(cur_prompt_len)
elif arguments.mode == "prefix_tuning":
from transformers.adapters import PrefixTuningConfig
cur_prefix_len = 10 if self.arguments.prefix_len is None else self.arguments.prefix_len
from peft import PrefixTuningConfig

cur_prefix_len = 100 if self.arguments.prefix_len is None else self.arguments.prefix_len
assert self.arguments.trainable_params_percentage is not None or self.arguments.prefix_len > 0, "either prefix_len or trainable_params_percentage should be set"
if self.arguments.trainable_params_percentage is not None:
config = PrefixTuningConfig(flat=True, prefix_length=cur_prefix_len)
# config = PrefixTuningConfig(prefix_length=cur_prefix_len, bottleneck_size=512)
config = PrefixTuningConfig(task_type=task_type, num_virtual_tokens=cur_prefix_len)
cur_trainable_params_percentage = self.convert_to_peft(config)
while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.001:
while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.00005:
if cur_trainable_params_percentage > self.arguments.trainable_params_percentage:
cur_prefix_len -= 1
else:
cur_prefix_len += 1
config = PrefixTuningConfig(flat=True, prefix_length=cur_prefix_len)
# config = PrefixTuningConfig(prefix_length=cur_prefix_len, bottleneck_size=512)
config = PrefixTuningConfig(task_type=task_type, num_virtual_tokens=cur_prefix_len)
cur_trainable_params_percentage = self.convert_to_peft(config, reset_peft=True)
print("prefix length is {}".format(cur_prefix_len))
self.arguments.run_name += "_prefix_len_{}".format(cur_prefix_len)
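For context on the switch above from the adapter-transformers PrefixTuningConfig to the peft one, here is a minimal sketch of how a peft-style PrefixTuningConfig ties num_virtual_tokens to the trainable-parameter share; the checkpoint name and token count are placeholders, not values from this repository:

from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")  # placeholder checkpoint
# more virtual tokens -> more trainable prefix parameters, which is what the
# cur_prefix_len loop above searches over
config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=100)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()  # reports the trainable ratio the loop compares against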
@@ -144,7 +148,7 @@ def __init__(self, arguments):
)
cur_trainable_params_percentage = self.convert_to_peft(config, reset_peft=True)

while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.001:
while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.0001:
if cur_trainable_params_percentage > self.arguments.trainable_params_percentage:
cur_lora_r -= 1
else:
@@ -160,7 +164,7 @@ def __init__(self, arguments):
print("cur_lora_r", cur_lora_r, "cur_trainable_params_percentage", cur_trainable_params_percentage)
self.arguments.lora_r = cur_lora_r
self.arguments.run_name += "_lora_r_" + str(cur_lora_r)
elif arguments.mode in ["adapter", "bitfit", "compactor"]:
elif arguments.mode in ["adapter", "compactor"]:
cur_reduction_factor = self.arguments.reduction_factor if self.arguments.reduction_factor is not None else 32
assert self.arguments.trainable_params_percentage is not None or self.arguments.reduction_factor > 0, "either reduction_factor or trainable_params_percentage should be set"

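The lora_r search a few lines up follows the same pattern; a hedged sketch of a peft LoraConfig where r is the searched knob (the rank, alpha, dropout, and target module names here are assumptions, not this file's exact settings):

from peft import LoraConfig, TaskType

# LoRA adds roughly r * (d_in + d_out) parameters per targeted projection,
# so the trainable share grows about linearly with r.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,                        # assumed rank; the loop above steps this by 1
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q", "v"],  # assumed T5-style attention projections
)
# model = get_peft_model(model, lora_config)  # applied as in convert_to_peft below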
@@ -171,7 +175,7 @@ def __init__(self, arguments):
# config = AdapterConfig()
config = HoulsbyConfig(reduction_factor=cur_reduction_factor)
cur_trainable_params_percentage = self.convert_to_peft(config)
while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.001:
while abs(cur_trainable_params_percentage - self.arguments.trainable_params_percentage) > 0.00002:
if cur_trainable_params_percentage > self.arguments.trainable_params_percentage:
cur_reduction_factor += 1
else:
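Note that this loop steps cur_reduction_factor upward when the trainable share is too high, the opposite direction from the prompt/prefix/LoRA branches: in a Houlsby adapter the bottleneck width is hidden_size / reduction_factor, so a larger reduction factor means a smaller adapter. A hedged sketch with adapter-transformers (the example values and the commented calls only approximate this file):

from transformers.adapters import HoulsbyConfig

# e.g. hidden_size 768 with reduction_factor 32 gives a 24-unit bottleneck
config = HoulsbyConfig(reduction_factor=32)
# self.model.add_adapter("sst-2", config=config, overwrite_ok=True)
# self.model.train_adapter("sst-2")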
@@ -185,6 +189,8 @@ def __init__(self, arguments):
if self.num_soft_tokens > 0:
self.num_soft_tokens = 0
print("num_soft_tokens is set to 0 for embedding tuning mode")
elif arguments.mode == "bitfit":
self.convert_to_peft()
elif arguments.mode == "fine_tuning":
pass
elif arguments.mode == "layer_tuning":
@@ -277,7 +283,7 @@ def set_up_hf_trainer(self):
else:
dataset_dependent_data_collator = default_data_collator

if self.arguments.mode in ["adapter", "prefix_tuning", "compactor"]:
if self.arguments.mode in ["adapter", "compactor"]: # "prefix_tuning",
self.trainer = Seq2SeqAdapterTrainer(
model = self.model,
tokenizer = self.tokenizer,
@@ -418,6 +424,7 @@ def load_model(self):

# m.encoder.block[0].layer[0].SelfAttention.q.weight.bias
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_names_or_path, cache_dir=self.arguments.cache_dir,)

# m = T5PT.from_pretrained(
# self.model_names_or_path,
# # from_tf=bool(".ckpt" in self.model_names_or_path),
@@ -439,7 +446,7 @@ def load_model(self):
raise NotImplementedError("Model not supported: " + self.model_names_or_path)
# Wrap model in adapter package
# NOTE: temp implementation
if self.arguments.mode in ["adapter", "prefix_tuning", "compactor"] :
if self.arguments.mode in ["adapter", "compactor"] : # "prefix_tuning",
self.model = AutoAdapterModel.from_pretrained(self.model_names_or_path, cache_dir=self.arguments.cache_dir,)

if self.tokenizer.pad_token is None:
@@ -733,20 +740,22 @@ def convert_to_peft(self, peft_config=None, reset_peft=False):
peft_config (_type_): _description_
"""

if self.arguments.mode in ["adapter", "prefix_tuning", "compactor"]:
if self.arguments.mode in ["adapter", "compactor"]: # prefix_tuning

# add and activate adapter
self.model.add_adapter("sst-2", config = peft_config, overwrite_ok=reset_peft)
self.model.train_adapter("sst-2")
if self.arguments.model_arch == "encoder":
self.model.add_classification_head("classification-head-sst-2", num_labels=2, overwrite_ok=reset_peft)
elif self.arguments.model_arch == "encoder-decoder":
self.model.add_seq2seq_lm_head("seq2seq-head-sst-2", overwrite_ok=reset_peft)
# self.model.add_seq2seq_lm_head("seq2seq-head-sst-2", overwrite_ok=reset_peft)
pass
else:
raise NotImplementedError(
f"Not implemented for model arch: {self.arguments.model_arch}"
)
self.model.set_active_adapters("sst-2")
# self.model.freeze_model(True)
elif self.arguments.mode == "bitfit":
# if self.arguments.model_arch == "encoder":
# # deactivate gradients except for bias terms
@@ -785,7 +794,7 @@ def convert_to_peft(self, peft_config=None, reset_peft=False):

# assert is_bias_init == True, "bias should be initialized"

raise NotImplementedError("bitfit is not computed for trainable parameters yet")
# raise NotImplementedError("bitfit is not computed for trainable parameters yet")
# components = ["intermediate", "key", "query", "value", "output", "output_layernorm", "attention_layernorm", "all"]
# trainable_components = convert_to_actual_components(components)
# self._deactivate_relevant_gradients(trainable_components)
@@ -830,9 +839,10 @@ def convert_to_peft(self, peft_config=None, reset_peft=False):
if not module.bias.requires_grad:
print("activate gradient for ", name)
module.bias.requires_grad = True
for name, module in self.model.named_modules():
if "lm_head" in name:
module.weight.requires_grad = True
# lm head takes almost 10% of the parameters, so we don't want to train it
# for name, module in self.model.named_modules():
# if "lm_head" in name:
# module.weight.requires_grad = True

else:
# NOTE: prompt tuning
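As a standalone illustration of the bias-only (BitFit-style) activation shown just above, here is a compact sketch, not this file's exact code; it keeps lm_head frozen for the reason given in the comment above:

import torch.nn as nn

def apply_bitfit(model: nn.Module) -> None:
    """Freeze all weights, then re-enable gradients only for bias terms."""
    for param in model.parameters():
        param.requires_grad = False
    for module in model.modules():
        bias = getattr(module, "bias", None)
        if isinstance(bias, nn.Parameter):
            bias.requires_grad = True
    # lm_head weights stay frozen: per the comment above, they hold roughly
    # 10% of the parameters, which would defeat bias-only tuning.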
@@ -846,19 +856,24 @@ def convert_to_peft(self, peft_config=None, reset_peft=False):
self.model = deepcopy(self.model_cache)
# add tokens in models and tokenizers + freeze model
self.model.enable_input_require_grads()

self.model = get_peft_model(self.model, peft_config)


return self.check_trainable_parameters()



def check_trainable_parameters(self, print_params_required_grad = False):
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
# print_params_required_grad = True
if print_params_required_grad:
for n, p in self.model.named_parameters():
if p.requires_grad:
print(n)
print(p.data.shape, n)
print(f"Total Params: {total_params}, Trainable Params: {trainable_params}, Trainable Ratio: {trainable_params/total_params}")

return trainable_params/total_params


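All of the tolerances this commit adjusts (1e-5 for prompt length, 5e-5 for prefix length, 1e-4 for lora_r, 2e-5 for the adapter reduction factor) feed the same search pattern; a minimal sketch, with trainable_ratio standing in for the convert_to_peft / check_trainable_parameters round trip:

def match_target_ratio(target, start, trainable_ratio, tol):
    """Step an integer setting (prompt length, prefix length, lora_r, ...) by 1
    until the trainable-parameter ratio is within tol of the target."""
    value = start
    ratio = trainable_ratio(value)
    while abs(ratio - target) > tol:
        value += -1 if ratio > target else 1
        ratio = trainable_ratio(value)
    return value

The loop only terminates if some integer setting actually lands inside tol, which is presumably why each method gets its own threshold here rather than the 0.001 most branches used before.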
