Commit 9ff32fd

fix parameters are not frozen

1 parent a097c42 commit 9ff32fd
1 file changed: networks/dylora.py (+53 −53 lines)
@@ -50,15 +50,17 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_
             kernel_size = org_module.kernel_size
             self.stride = org_module.stride
             self.padding = org_module.padding
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
         else:
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])

         # same as microsoft's
-        torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_B)
+        for lora in self.lora_A:
+            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
+        for lora in self.lora_B:
+            torch.nn.init.zeros_(lora)

         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
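Why the switch from nn.Parameter to nn.ParameterList matters for the commit's goal: a slice of a single Parameter is a non-leaf view, so its requires_grad flag cannot be changed, and the optimizer always sees one fully trainable tensor. Registering each rank as its own Parameter makes per-rank freezing possible. A minimal standalone sketch (toy shapes, not code from this repository) illustrating the difference:

import torch
import torch.nn as nn

lora_dim, in_dim = 4, 8

# One big Parameter: slices are views, not leaves, so they cannot be frozen individually.
single = nn.Parameter(torch.zeros(lora_dim, in_dim))
print(single[0].is_leaf)           # False
# single[0].requires_grad = False  # would raise: requires_grad can only be changed on leaf tensors

# One Parameter per rank: each entry is its own leaf and can be frozen on its own.
per_rank = nn.ParameterList([nn.Parameter(torch.zeros(1, in_dim)) for _ in range(lora_dim)])
per_rank[0].requires_grad = False
print([p.requires_grad for p in per_rank])  # [False, True, True, True]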
@@ -76,38 +78,18 @@ def forward(self, x):
         trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit

         # freeze some of the parameters and train only the rest
-
-        # make lora_A
-        if trainable_rank > 0:
-            lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
-        else:
-            lora_A_nt1 = []
-
-        lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]
-
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
-        else:
-            lora_A_nt2 = []
-
-        lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)
-
-        # make lora_B
-        if trainable_rank > 0:
-            lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
-        else:
-            lora_B_nt1 = []
-
-        lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]
-
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
-        else:
-            lora_B_nt2 = []
-
-        lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)
-
-        # print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())
+        for i in range(0, trainable_rank):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+        for i in range(trainable_rank, trainable_rank + self.unit):
+            self.lora_A[i].requires_grad = True
+            self.lora_B[i].requires_grad = True
+        for i in range(trainable_rank + self.unit, self.lora_dim):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+
+        lora_A = torch.cat(tuple(self.lora_A), dim=0)
+        lora_B = torch.cat(tuple(self.lora_B), dim=1)

         # calculate with lora_A and lora_B
         if self.is_conv2d_3x3:
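The rewritten forward() freezes ranks by flipping requires_grad on the individual ParameterList entries before concatenating them, so autograd produces gradients only for the active unit. A standalone sketch with assumed toy sizes (lora_dim=4, unit=1, trainable_rank=2), showing that only the selected rank receives a gradient:

import torch
import torch.nn as nn

in_dim, lora_dim, unit, trainable_rank = 8, 4, 1, 2

lora_A = nn.ParameterList([nn.Parameter(torch.randn(1, in_dim)) for _ in range(lora_dim)])
for i, p in enumerate(lora_A):
    p.requires_grad = trainable_rank <= i < trainable_rank + unit  # freeze all but the active unit

x = torch.randn(2, in_dim)
A = torch.cat(tuple(lora_A), dim=0)          # (lora_dim, in_dim), as in the new forward()
loss = torch.nn.functional.linear(x, A).sum()
loss.backward()

print([p.grad is not None for p in lora_A])  # [False, False, True, False]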
@@ -116,13 +98,13 @@ def forward(self, x):
         else:
             ab = x
             if self.is_conv2d:
-                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)
+                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)

             ab = torch.nn.functional.linear(ab, lora_A)
             ab = torch.nn.functional.linear(ab, lora_B)

             if self.is_conv2d:
-                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])
+                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])  # (N, H*W, C) -> (N, C, H, W)

         # the last factor is scaling that boosts lower ranks (probably)
         result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
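The comments added in this hunk describe the reshape used on the 1x1-conv path: flattening the spatial dimensions lets F.linear apply the low-rank matrices at every spatial position, which is equivalent to a 1x1 convolution with the composed weight. A self-contained sketch (arbitrary toy shapes, not repository code) of that equivalence:

import torch
import torch.nn.functional as F

N, C, H, W, rank, C_out = 2, 8, 4, 4, 3, 16
x = torch.randn(N, C, H, W)
lora_A = torch.randn(rank, C)       # "down" projection
lora_B = torch.randn(C_out, rank)   # "up" projection

ab = x.reshape(N, C, -1).transpose(1, 2)         # (N, C, H, W) -> (N, H*W, C)
ab = F.linear(F.linear(ab, lora_A), lora_B)      # (N, H*W, C) -> (N, H*W, C_out)
ab = ab.transpose(1, 2).reshape(N, C_out, H, W)  # (N, H*W, C_out) -> (N, C_out, H, W)

# Same result as a 1x1 convolution with the composed weight B @ A:
ref = F.conv2d(x, (lora_B @ lora_A).unsqueeze(-1).unsqueeze(-1))
print(torch.allclose(ab, ref, atol=1e-4))        # True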
@@ -131,34 +113,52 @@ def forward(self, x):
         return result

     def state_dict(self, destination=None, prefix="", keep_vars=False):
-        # make the state dict the same as a normal LoRA's
-        state_dict = super().state_dict(destination, prefix, keep_vars)
+        # make the state dict the same as a normal LoRA's:
+        # nn.ParameterList gives names like `.lora_A.0`, so cat them as in forward and swap the keys
+        sd = super().state_dict(destination, prefix, keep_vars)

-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
+        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight

-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
+        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight

-        return state_dict
+        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
+        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
+
+        i = 0
+        while True:
+            key_a = f"{self.lora_name}.lora_A.{i}"
+            key_b = f"{self.lora_name}.lora_B.{i}"
+            if key_a in sd:
+                sd.pop(key_a)
+                sd.pop(key_b)
+            else:
+                break
+            i += 1
+        return sd

     def load_state_dict(self, state_dict, strict=True):
         # allow loading the same state dict as a normal LoRA
         state_dict = state_dict.copy()

-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
-        if self.is_conv2d and not self.is_conv2d_3x3:
-            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_A"] = lora_A_weight
+        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
+        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)

-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
+        if lora_A_weight is None or lora_B_weight is None:
+            if strict:
+                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
+            else:
+                return
+
         if self.is_conv2d and not self.is_conv2d_3x3:
+            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
             lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_B"] = lora_B_weight
+
+        state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
+        state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})

         super().load_state_dict(state_dict, strict=strict)

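The state_dict()/load_state_dict() changes are needed because a ParameterList is serialized under indexed keys such as `.lora_A.0`, `.lora_A.1`, ..., while ordinary LoRA checkpoints expect a single `lora_down.weight` / `lora_up.weight` pair per module. A toy sketch (hypothetical module, names chosen for illustration; the real class additionally prefixes keys with lora_name) of what the remapping converts between:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, lora_dim=2, in_dim=4, out_dim=4):
        super().__init__()
        self.lora_A = nn.ParameterList([nn.Parameter(torch.zeros(1, in_dim)) for _ in range(lora_dim)])
        self.lora_B = nn.ParameterList([nn.Parameter(torch.zeros(out_dim, 1)) for _ in range(lora_dim)])

toy = Toy()
print(list(toy.state_dict().keys()))
# ['lora_A.0', 'lora_A.1', 'lora_B.0', 'lora_B.1']  <- what nn.ParameterList produces

# The commit converts these to the conventional LoRA layout by concatenation,
# mirroring the torch.cat calls in state_dict() above:
sd = {
    "lora_down.weight": torch.cat(tuple(toy.lora_A), dim=0).detach(),  # (lora_dim, in_dim)
    "lora_up.weight": torch.cat(tuple(toy.lora_B), dim=1).detach(),    # (out_dim, lora_dim)
}
print({k: tuple(v.shape) for k, v in sd.items()})
# {'lora_down.weight': (2, 4), 'lora_up.weight': (4, 2)}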