Commit 230e03f

mark_lora_as_trainable (PaddlePaddle#5241)
1 parent 586755b commit 230e03f

3 files changed (+38 −6 lines)

paddlenlp/layers/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -18,6 +18,6 @@
     GPLinkerForEventExtraction,
     GPLinkerForRelationExtraction,
 )
-from .lora import LoRAConfig, LoRALinear, get_lora_model
+from .lora import *
 from .sequence import sequence_mask
 from .tcn import TCN, TemporalBlock

paddlenlp/layers/lora.py

Lines changed: 19 additions & 0 deletions

@@ -23,6 +23,14 @@
 import paddle.nn.functional as F

 from ..utils.env import LORA_CONFIG_NAME
+from ..utils.log import logger
+
+__all__ = [
+    "LoRAConfig",
+    "LoRALinear",
+    "get_lora_model",
+    "mark_only_lora_as_trainable",
+]


 class LoRALinear(nn.Linear):
@@ -116,6 +124,17 @@ def _find_and_replace_module(model, module_name, lora_config):
     setattr(parent_module, attribute_chain[-1], lora_module)


+def mark_only_lora_as_trainable(model: nn.Layer) -> None:
+    freeze_numel, trainable_numel = 0, 0
+    for name, weight in model.state_dict().items():
+        if "lora" not in name:
+            weight.stop_gradient = True
+            freeze_numel += weight.numel().numpy()[0]
+        else:
+            trainable_numel += weight.numel().numpy()[0]
+    logger.info(f"{freeze_numel:.2e} parameters are frozen, {trainable_numel:.2e} LoRA parameters are trainable")
+
+
 @dataclass
 class LoRAConfig:
     """

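For reference, a minimal usage sketch of the new helper together with the existing LoRA API. This is an illustrative sketch, not part of the commit: the LoRAConfig arguments (r, lora_alpha) and the target_modules regexes are assumed placeholders; only get_lora_model(model, lora_config) and mark_only_lora_as_trainable(lora_model) are taken from this diff.

# Hypothetical usage sketch (not from this commit).
# Assumptions: LoRAConfig accepts r, lora_alpha, and target_modules; the
# regex patterns below are illustrative placeholders.
from paddlenlp.layers import LoRAConfig, get_lora_model, mark_only_lora_as_trainable
from paddlenlp.transformers import AutoModel

model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert")
lora_config = LoRAConfig(r=4, lora_alpha=8, target_modules=[".*q_proj.*", ".*v_proj.*"])
lora_model = get_lora_model(model, lora_config)

# Freezes every weight whose name does not contain "lora" (stop_gradient=True)
# and logs the frozen vs. trainable parameter counts.
mark_only_lora_as_trainable(lora_model)

# Sanity check: only LoRA tensors should remain trainable after the call.
trainable = [name for name, w in lora_model.state_dict().items() if not w.stop_gradient]
assert all("lora" in name for name in trainable)
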
tests/layers/test_lora.py

Lines changed: 18 additions & 5 deletions

@@ -21,7 +21,12 @@
 import numpy as np
 import paddle

-from paddlenlp.layers import LoRAConfig, LoRALinear, get_lora_model
+from paddlenlp.layers import (
+    LoRAConfig,
+    LoRALinear,
+    get_lora_model,
+    mark_only_lora_as_trainable,
+)
 from paddlenlp.transformers import AutoModel


@@ -88,14 +93,22 @@ def test_get_lora_model(self):
             "__internal_testing__/tiny-random-bert", hidden_dropout_prob=0, attention_probs_dropout_prob=0
         )
         lora_model = get_lora_model(model, lora_config)
+        mark_only_lora_as_trainable(lora_model)
         state_dict = lora_model.state_dict()
         for weight_name in state_dict:
+            is_target_module = False
             for target_module in lora_config.target_modules:
                 if re.fullmatch(target_module, weight_name):
-                    if "lora" in weight_name:
-                        self.assertFalse(state_dict[weight_name].stop_gradient)
-                    else:
-                        self.assertTrue(state_dict[weight_name].stop_gradient)
+                    is_target_module = True
+            # if this is a target module, lora weights are trainable, non-lora weights are not
+            if is_target_module:
+                if "lora" in weight_name:
+                    self.assertFalse(state_dict[weight_name].stop_gradient)
+                else:
+                    self.assertTrue(state_dict[weight_name].stop_gradient)
+            # if this is not a target module, all weights are not trainable
+            else:
+                self.assertTrue(state_dict[weight_name].stop_gradient)
         input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20]))
         model.train()
         train_forward_results = model(input_ids)
