Skip to content

Commit 12c898e

Browse files
cosmo3769pkgoogle
authored andcommitted
added missing port (#1789)
1 parent 46aa723 commit 12c898e

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

keras_nlp/src/utils/transformers/convert_bart.py

+14
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@ def convert_weights(backbone, loader):
3838
keras_variable=backbone.token_embedding.embeddings,
3939
hf_weight_key="shared.weight",
4040
)
41+
loader.port_weight(
42+
keras_variable=backbone.encoder_position_embedding.position_embeddings,
43+
hf_weight_key="encoder.embed_positions.weight",
44+
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
45+
hf_tensor[2:, :], keras_shape
46+
),
47+
)
48+
loader.port_weight(
49+
keras_variable=backbone.decoder_position_embedding.position_embeddings,
50+
hf_weight_key="decoder.embed_positions.weight",
51+
hook_fn=lambda hf_tensor, keras_shape: np.reshape(
52+
hf_tensor[2:, :], keras_shape
53+
),
54+
)
4155

4256
# Encoder blocks
4357
for index in range(backbone.num_layers):

0 commit comments

Comments
 (0)