@@ -50,15 +50,17 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_
             kernel_size = org_module.kernel_size
             self.stride = org_module.stride
             self.padding = org_module.padding
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
         else:
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])

         # same as microsoft's
-        torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_B)
+        for lora in self.lora_A:
+            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
+        for lora in self.lora_B:
+            torch.nn.init.zeros_(lora)

         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
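Note on the new layout: splitting `lora_A`/`lora_B` into one single-rank parameter per entry keeps the concatenated shapes identical to the previous `(lora_dim, in_dim)` and `(out_dim, lora_dim)` matrices, while allowing `requires_grad` to be toggled per rank. A minimal shape check, with toy dimensions (not the module's real `in_dim`/`out_dim`) and each piece wrapped explicitly in `nn.Parameter`:

```python
import torch
import torch.nn as nn

in_dim, out_dim, lora_dim = 8, 16, 4  # toy sizes, illustration only

# per-rank layout: one row of A and one column of B per rank
lora_A = nn.ParameterList([nn.Parameter(torch.zeros(1, in_dim)) for _ in range(lora_dim)])
lora_B = nn.ParameterList([nn.Parameter(torch.zeros(out_dim, 1)) for _ in range(lora_dim)])

# concatenating the pieces restores the usual LoRA matrix shapes
assert torch.cat(tuple(lora_A), dim=0).shape == (lora_dim, in_dim)
assert torch.cat(tuple(lora_B), dim=1).shape == (out_dim, lora_dim)
```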
@@ -76,38 +78,18 @@ def forward(self, x):
         trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit

         # freeze some of the parameters and train the rest
-
-        # make lora_A
-        if trainable_rank > 0:
-            lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
-        else:
-            lora_A_nt1 = []
-
-        lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]
-
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
-        else:
-            lora_A_nt2 = []
-
-        lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)
-
-        # make lora_B
-        if trainable_rank > 0:
-            lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
-        else:
-            lora_B_nt1 = []
-
-        lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]
-
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
-        else:
-            lora_B_nt2 = []
-
-        lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)
-
-        # print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())
+        for i in range(0, trainable_rank):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+        for i in range(trainable_rank, trainable_rank + self.unit):
+            self.lora_A[i].requires_grad = True
+            self.lora_B[i].requires_grad = True
+        for i in range(trainable_rank + self.unit, self.lora_dim):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+
+        lora_A = torch.cat(tuple(self.lora_A), dim=0)
+        lora_B = torch.cat(tuple(self.lora_B), dim=1)

         # calculate with lora_A and lora_B
         if self.is_conv2d_3x3:
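For context, the intended effect of the per-rank `requires_grad` toggling plus `torch.cat` above is that every rank still contributes to the forward pass, but only the currently selected window of `unit` ranks receives gradients. A small standalone sketch of that behavior (toy shapes and a hard-coded `trainable_rank`, purely illustrative):

```python
import torch
import torch.nn as nn

lora_dim, unit, in_dim = 4, 1, 8
ranks = nn.ParameterList([nn.Parameter(torch.randn(1, in_dim)) for _ in range(lora_dim)])

trainable_rank = 2  # pretend this index was sampled for the current step
for i, p in enumerate(ranks):
    p.requires_grad = trainable_rank <= i < trainable_rank + unit

full = torch.cat(tuple(ranks), dim=0)  # (lora_dim, in_dim): all ranks are used
full.sum().backward()

print([p.grad is not None for p in ranks])  # only the selected window gets a gradient
# e.g. [False, False, True, False]
```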
@@ -116,13 +98,13 @@ def forward(self, x):
         else:
             ab = x
             if self.is_conv2d:
-                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)
+                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)

             ab = torch.nn.functional.linear(ab, lora_A)
             ab = torch.nn.functional.linear(ab, lora_B)

             if self.is_conv2d:
-                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])
+                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])  # (N, H*W, C) -> (N, C, H, W)

         # the last term is a scaling that makes lower ranks larger (I think)
         result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
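To make that extra factor concrete (example numbers only, not the defaults): with `lora_dim = 16` and `unit = 4`, `math.sqrt(self.lora_dim / (trainable_rank + self.unit))` is `sqrt(16 / 4) = 2.0` at `trainable_rank = 0` and falls to `sqrt(16 / 16) = 1.0` at `trainable_rank = 12`, so steps that effectively use fewer ranks are scaled up, as the comment suggests:

```python
import math

lora_dim, unit = 16, 4  # example values
for trainable_rank in (0, 4, 8, 12):
    print(trainable_rank, round(math.sqrt(lora_dim / (trainable_rank + unit)), 3))
# 0 2.0, 4 1.414, 8 1.155, 12 1.0
```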
@@ -131,34 +113,52 @@ def forward(self, x):
         return result

     def state_dict(self, destination=None, prefix="", keep_vars=False):
-        # make the state dict the same as a normal LoRA's
-        state_dict = super().state_dict(destination, prefix, keep_vars)
+        # make the state dict the same as a normal LoRA's:
+        # nn.ParameterList produces keys like `.lora_A.0`, so cat the pieces as in forward and swap the keys
+        sd = super().state_dict(destination, prefix, keep_vars)

-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
+        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight

-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
+        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight

-        return state_dict
+        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
+        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
+
+        i = 0
+        while True:
+            key_a = f"{self.lora_name}.lora_A.{i}"
+            key_b = f"{self.lora_name}.lora_B.{i}"
+            if key_a in sd:
+                sd.pop(key_a)
+                sd.pop(key_b)
+            else:
+                break
+            i += 1
+        return sd

     def load_state_dict(self, state_dict, strict=True):
         # accept the same state dict as a normal LoRA
         state_dict = state_dict.copy()

-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
-        if self.is_conv2d and not self.is_conv2d_3x3:
-            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_A"] = lora_A_weight
+        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
+        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)

-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
+        if lora_A_weight is None or lora_B_weight is None:
+            if strict:
+                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
+            else:
+                return
+
         if self.is_conv2d and not self.is_conv2d_3x3:
+            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
             lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_B"] = lora_B_weight
+
+        state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
+        state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})

         super().load_state_dict(state_dict, strict=strict)
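Taken together, the two overrides keep checkpoints in the standard LoRA key layout (`.lora_down.weight` / `.lora_up.weight`) even though the module now stores one parameter per rank under `.lora_A.{i}` / `.lora_B.{i}`. A rough sketch of the key mapping being performed, with hypothetical names and toy shapes only:

```python
import torch

lora_name, lora_dim, in_dim, out_dim = "lora_unet_example", 4, 8, 16  # hypothetical values

# internal layout: one tensor per rank
internal = {f"{lora_name}.lora_A.{i}": torch.zeros(1, in_dim) for i in range(lora_dim)}
internal.update({f"{lora_name}.lora_B.{i}": torch.zeros(out_dim, 1) for i in range(lora_dim)})

# state_dict(): concatenate back to the usual LoRA matrices
saved = {
    f"{lora_name}.lora_down.weight": torch.cat([internal[f"{lora_name}.lora_A.{i}"] for i in range(lora_dim)], dim=0),
    f"{lora_name}.lora_up.weight": torch.cat([internal[f"{lora_name}.lora_B.{i}"] for i in range(lora_dim)], dim=1),
}
assert saved[f"{lora_name}.lora_down.weight"].shape == (lora_dim, in_dim)
assert saved[f"{lora_name}.lora_up.weight"].shape == (out_dim, lora_dim)

# load_state_dict(): split the full matrices back into per-rank pieces
restored_A = {f"{lora_name}.lora_A.{i}": saved[f"{lora_name}.lora_down.weight"][i] for i in range(lora_dim)}
restored_B = {f"{lora_name}.lora_B.{i}": saved[f"{lora_name}.lora_up.weight"][:, i] for i in range(lora_dim)}
```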