Correct LoRA weights merging (keras-team#1784)
Correct the code that merges the model's original layer weights with the LoRA layer weights.
This respects the LoRA principle of disposing of the LoRA layers once we no longer plan to train them and, more importantly, allows us to save and load the model as a ".keras" file.
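
With the original dense layers put back in place, the model is built only from standard Keras layers again, so the usual serialization round-trip applies. A minimal sketch, assuming the fine-tuned model is held in a variable called `lora_model` (the variable and file names here are hypothetical):

```python
import keras

# Save and reload the merged model; this is what restoring the original
# layers enables, since the custom LoRA wrappers are no longer in the graph.
lora_model.save("gpt2_lora_merged.keras")
restored_model = keras.models.load_model("gpt2_lora_merged.keras")
```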
BastienHot authored Mar 21, 2024
1 parent 9f287fd commit 02c9bae
Showing 1 changed file with 4 additions and 0 deletions.
@@ -587,6 +587,10 @@ def call(self, inputs):
B_weights = value_lora_layer.B.kernel # (1, 12, 64) (b, c, d)
increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
value_lora_layer.original_layer.kernel.assign_add(increment_weights)

# Put back in place the original layers with updated weights
self_attention_layer._query_dense = query_lora_layer.original_layer
self_attention_layer._value_dense = value_lora_layer.original_layer

"""
We are now all set to generate text with our LoRA model :).
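For context, here is a minimal sketch of the merge-and-restore loop that this hunk completes. It reuses the names visible in the diff (`self_attention_layer`, `query_lora_layer`, `value_lora_layer`, `ALPHA`, `RANK`); the outer loop over transformer blocks and the `lora_model` handle and accessors are assumptions, not code from the file. Only the last two assignments are what this commit adds.

```python
import tensorflow as tf

# Fold each LoRA update (A @ B, scaled by alpha / rank) into the frozen kernel,
# then swap the plain dense layers back in so the LoRA wrappers can be dropped.
for layer_idx in range(lora_model.backbone.num_layers):  # assumed model handle
    self_attention_layer = lora_model.backbone.get_layer(
        f"transformer_layer_{layer_idx}"
    )._self_attention_layer  # assumed accessor into each transformer block

    query_lora_layer = self_attention_layer._query_dense
    value_lora_layer = self_attention_layer._value_dense

    for lora_layer in (query_lora_layer, value_lora_layer):
        A_weights = lora_layer.A.kernel  # (768, 1)    (a, b)
        B_weights = lora_layer.B.kernel  # (1, 12, 64) (b, c, d)
        increment_weights = tf.einsum(
            "ab,bcd->acd", A_weights, B_weights
        ) * (ALPHA / RANK)
        lora_layer.original_layer.kernel.assign_add(increment_weights)

    # The lines added by this commit: put the original layers with updated
    # weights back in place, discarding the LoRA wrappers.
    self_attention_layer._query_dense = query_lora_layer.original_layer
    self_attention_layer._value_dense = value_lora_layer.original_layer
```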
