Commit 6e6bb07

Make lora initializers configurable.

1 parent 096b848 · commit 6e6bb07

3 files changed: +15 -9 lines

keras/layers/core/dense.py (+5 -3)

@@ -146,7 +146,9 @@ def compute_output_shape(self, input_shape):
         output_shape[-1] = self.units
         return tuple(output_shape)
 
-    def enable_lora(self, rank):
+    def enable_lora(
+        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+    ):
         if self.kernel_constraint:
             raise ValueError(
                 "Lora is incompatible with kernel constraints. "
@@ -166,11 +168,11 @@ def enable_lora(self, rank):
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
             shape=(self.kernel.shape[0], rank),
-            initializer="zeros",
+            initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
         self.lora_kernel_b = self.add_weight(
-            name="lora_kernel_b",
+            name=initializers.get(b_initializer),
             shape=(rank, self.kernel.shape[1]),
             initializer="zeros",
             regularizer=self.kernel_regularizer,
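Note: in the dense.py hunk above, b_initializer ends up passed to the name argument of lora_kernel_b, while its initializer stays hardcoded to "zeros". That looks like an oversight in this commit; the presumably intended lines (a sketch, not part of the commit) would be:

    self.lora_kernel_b = self.add_weight(
        name="lora_kernel_b",
        shape=(rank, self.kernel.shape[1]),
        initializer=initializers.get(b_initializer),
        regularizer=self.kernel_regularizer,
    )

The einsum_dense.py and embedding.py hunks below wire b_initializer into the initializer argument as expected.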

keras/layers/core/einsum_dense.py (+5 -3)

@@ -222,7 +222,9 @@ def call(self, inputs):
         x = self.activation(x)
         return x
 
-    def enable_lora(self, rank):
+    def enable_lora(
+        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+    ):
         if self.kernel_constraint:
             raise ValueError(
                 "Lora is incompatible with kernel constraints. "
@@ -243,13 +245,13 @@ def enable_lora(self, rank):
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
             shape=(self.kernel.shape[:-1] + (rank,)),
-            initializer="zeros",
+            initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
         self.lora_kernel_b = self.add_weight(
             name="lora_kernel_b",
             shape=(rank, self.kernel.shape[-1]),
-            initializer="zeros",
+            initializer=initializers.get(b_initializer),
             regularizer=self.kernel_regularizer,
         )
         self.kernel.trainable = False
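For EinsumDense the factor shapes generalize the 2-D case: lora_kernel_a keeps every kernel axis except the last and appends rank, while lora_kernel_b maps rank back onto the last axis, so their product has the kernel's shape. A quick NumPy sketch of that shape algebra (the concrete numbers are illustrative):

    import numpy as np

    rank = 4
    kernel_shape = (8, 16, 32)  # illustrative EinsumDense kernel shape

    lora_a = np.zeros(kernel_shape[:-1] + (rank,))  # (8, 16, 4)
    lora_b = np.zeros((rank, kernel_shape[-1]))     # (4, 32)

    # matmul broadcasts over the leading axes, so the low-rank
    # update lora_a @ lora_b matches the kernel's shape exactly.
    assert (lora_a @ lora_b).shape == kernel_shape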

keras/layers/core/embedding.py (+5 -3)

@@ -125,7 +125,9 @@ def compute_mask(self, inputs, mask=None):
     def compute_output_shape(self, input_shape):
         return input_shape + (self.output_dim,)
 
-    def enable_lora(self, rank):
+    def enable_lora(
+        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+    ):
         if self.embeddings_constraint:
             raise ValueError(
                 "Lora is incompatible with embedding constraints. "
@@ -145,13 +147,13 @@ def enable_lora(self, rank):
         self.lora_embeddings_a = self.add_weight(
             name="lora_embeddings_a",
             shape=(self.embeddings.shape[0], rank),
-            initializer="zeros",
+            initializer=initializers.get(a_initializer),
             regularizer=self.embeddings_regularizer,
         )
        self.lora_embeddings_b = self.add_weight(
             name="lora_embeddings_b",
             shape=(rank, self.embeddings.shape[1]),
-            initializer="zeros",
+            initializer=initializers.get(b_initializer),
             regularizer=self.embeddings_regularizer,
         )
         self.embeddings.trainable = False
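Taken together, the three layers now accept per-factor initializers while keeping the usual LoRA defaults: the A factor gets a non-trivial initializer ("he_uniform") and the B factor starts at zero, so the initial low-rank delta A @ B is zero and the layer's behavior is unchanged until training updates the factors. A minimal usage sketch against the new signature (layer sizes and initializer choices are illustrative):

    from keras import layers, initializers

    layer = layers.Dense(units=64)
    layer.build((None, 128))  # the kernel must exist before enabling LoRA

    # Both factor initializers are now configurable; initializers.get()
    # accepts identifier strings as well as Initializer instances.
    layer.enable_lora(
        rank=4,
        a_initializer=initializers.RandomNormal(stddev=0.01),
        b_initializer="zeros",
    )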
