@@ -125,7 +125,9 @@ def compute_mask(self, inputs, mask=None):
     def compute_output_shape(self, input_shape):
         return input_shape + (self.output_dim,)

-    def enable_lora(self, rank):
+    def enable_lora(
+        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+    ):
         if self.embeddings_constraint:
             raise ValueError(
                 "Lora is incompatible with embedding constraints. "
@@ -145,13 +147,13 @@ def enable_lora(self, rank):
         self.lora_embeddings_a = self.add_weight(
             name="lora_embeddings_a",
             shape=(self.embeddings.shape[0], rank),
-            initializer="zeros",
+            initializer=initializers.get(a_initializer),
             regularizer=self.embeddings_regularizer,
         )
         self.lora_embeddings_b = self.add_weight(
             name="lora_embeddings_b",
             shape=(rank, self.embeddings.shape[1]),
-            initializer="zeros",
+            initializer=initializers.get(b_initializer),
             regularizer=self.embeddings_regularizer,
         )
         self.embeddings.trainable = False
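
A minimal usage sketch, assuming this diff targets keras.layers.Embedding in Keras 3 (the layer sizes and initializer choice below are illustrative, not from the diff). The likely motivation for the new default: with the old behavior, both LoRA factors started at zero, so the delta A @ B and the gradients of both factors were zero and the adaptation could never train; initializing A with "he_uniform" while keeping B at "zeros" preserves a zero initial delta but breaks that deadlock.

import numpy as np
import keras

# Build the layer first; enable_lora() needs the base weights to exist.
layer = keras.layers.Embedding(input_dim=1000, output_dim=64)
layer(np.array([[1, 2, 3]]))

# The new arguments let callers override the LoRA factor initializers;
# omitting them uses the new defaults ("he_uniform" for A, "zeros" for B).
layer.enable_lora(rank=4, a_initializer="glorot_uniform", b_initializer="zeros")

# The base embeddings are frozen; only the rank-4 factors remain trainable.
assert layer.embeddings.trainable is False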