We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 46aa723 commit 12c898eCopy full SHA for 12c898e
keras_nlp/src/utils/transformers/convert_bart.py
@@ -38,6 +38,20 @@ def convert_weights(backbone, loader):
38
keras_variable=backbone.token_embedding.embeddings,
39
hf_weight_key="shared.weight",
40
)
41
+ loader.port_weight(
42
+ keras_variable=backbone.encoder_position_embedding.position_embeddings,
43
+ hf_weight_key="encoder.embed_positions.weight",
44
+ hook_fn=lambda hf_tensor, keras_shape: np.reshape(
45
+ hf_tensor[2:, :], keras_shape
46
+ ),
47
+ )
48
49
+ keras_variable=backbone.decoder_position_embedding.position_embeddings,
50
+ hf_weight_key="decoder.embed_positions.weight",
51
52
53
54
55
56
# Encoder blocks
57
for index in range(backbone.num_layers):
0 commit comments