Commit 1c25cf5

Add Target Modules flag to allow passing in the target modules to the lora fine tuning (GoogleCloudPlatform#3441)
probably-not authored Sep 5, 2024
1 parent 66e7e4e commit 1c25cf5
Showing 4 changed files with 25 additions and 7 deletions.
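
In short, the change threads a caller-supplied target_modules list from a new --target_modules flag in main.py through to peft's LoraConfig, replacing module lists that were previously hard-coded inside the two fine-tuning helpers. A minimal sketch of the resulting pattern, using only values that appear in this diff (build_lora_config is an illustrative helper, not a function in the repo):

from peft import LoraConfig

def build_lora_config(target_modules):
    # The modules to adapt now arrive from the caller (ultimately the
    # --target_modules flag) instead of being hard-coded here.
    return LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )

# Default for the causal-language-modeling task, per constants.py below:
config = build_lora_config(["q_proj", "v_proj"])
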
@@ -12,6 +12,7 @@
from transformers import AutoTokenizer
from transformers import BitsAndBytesConfig
from transformers import TrainingArguments
+from typing import List
from util import constants


@@ -23,6 +24,7 @@ def finetune_causal_language_modeling(
lora_rank: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.05,
+target_modules: List[str] = constants.CAUSAL_LANGUAGE_MODELING_LORA_TARGET_MODULES,
warmup_steps: int = 10,
max_steps: int = 10,
learning_rate: float = 2e-4,
@@ -101,7 +103,7 @@ def forward(self, x):
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=["q_proj", "v_proj"],
target_modules=target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
@@ -9,6 +9,8 @@
from transformers import BitsAndBytesConfig
from transformers import TrainingArguments
from trl import SFTTrainer
+from typing import List
+from util import constants


def finetune_instruct(
@@ -18,6 +20,7 @@ def finetune_instruct(
lora_rank: int = 64,
lora_alpha: int = 16,
lora_dropout: float = 0.1,
+target_modules: List[str] = constants.INSTRUCT_LORA_TARGET_MODULES,
warmup_ratio: int = 0.03,
max_steps: int = 10,
max_seq_length: int = 512,
@@ -50,12 +53,7 @@ def finetune_instruct(
r=lora_rank,
bias="none",
task_type="CAUSAL_LM",
-target_modules=[
-    "query_key_value",
-    "dense",
-    "dense_h_to_4h",
-    "dense_4h_to_h",
-],
+target_modules=target_modules,
)

per_device_train_batch_size = 4
community-content/vertex_model_garden/model_oss/peft/main.py: 8 additions, 0 deletions
@@ -69,6 +69,12 @@
' https://huggingface.co/docs/peft/task_guides/token-classification-lora.',
)

+_TARGET_MODULES = flags.DEFINE_list(
+'target_modules',
+constants.CAUSAL_LANGUAGE_MODELING_LORA_TARGET_MODULES,
+'The comma-separated list of target modules for LoRA training.',
+)

_WARMUP_STEPS = flags.DEFINE_integer(
'warmup_steps',
10,
@@ -151,6 +157,7 @@ def main(_) -> None:
lora_rank=_LORA_RANK.value,
lora_alpha=_LORA_ALPHA.value,
lora_dropout=_LORA_DROPOUT.value,
+target_modules=_TARGET_MODULES.value,
warmup_steps=_WARMUP_STEPS.value,
max_steps=_MAX_STEPS.value,
learning_rate=_LEARNING_RATE.value,
@@ -164,6 +171,7 @@ def main(_) -> None:
lora_rank=_LORA_RANK.value,
lora_alpha=_LORA_ALPHA.value,
lora_dropout=_LORA_DROPOUT.value,
+target_modules=_TARGET_MODULES.value,
warmup_ratio=_WARMUP_RATIO.value,
max_steps=_MAX_STEPS.value,
max_seq_length=_MAX_SEQ_LENGTH.value,
community-content/vertex_model_garden/model_oss/util/constants.py: 10 additions, 0 deletions
@@ -77,6 +77,16 @@
SEQUENCE_CLASSIFICATION_LORA = 'sequence-classification-lora'
CAUSAL_LANGUAGE_MODELING_LORA = 'causal-language-modeling-lora'
INSTRUCT_LORA = 'instruct-lora'
+CAUSAL_LANGUAGE_MODELING_LORA_TARGET_MODULES = [
+    "q_proj",
+    "v_proj",
+]
+INSTRUCT_LORA_TARGET_MODULES = [
+    "query_key_value",
+    "dense",
+    "dense_h_to_4h",
+    "dense_4h_to_h",
+]

# Precision modes for loading model weights.
PRECISION_MODE_4 = '4bit'
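
For context on the flag added in main.py: absl's flags.DEFINE_list accepts either a Python list or a comma-separated string and always exposes a list via .value, which is why _TARGET_MODULES.value can be passed straight into both fine-tuning functions. A small standalone sketch (the flag mirrors the diff; the explicit FLAGS(...) call and example module names are only for demonstration):

from absl import flags

_TARGET_MODULES = flags.DEFINE_list(
    'target_modules',
    ['q_proj', 'v_proj'],
    'The comma-separated list of target modules for LoRA training.',
)

# Parsing a command line such as --target_modules=q_proj,k_proj,v_proj
# yields a plain Python list.
flags.FLAGS(['prog', '--target_modules=q_proj,k_proj,v_proj'])
print(_TARGET_MODULES.value)  # ['q_proj', 'k_proj', 'v_proj']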
