File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed
keras_nlp/src/utils/transformers Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -38,6 +38,20 @@ def convert_weights(backbone, loader):
38
38
keras_variable = backbone .token_embedding .embeddings ,
39
39
hf_weight_key = "shared.weight" ,
40
40
)
41
+ loader .port_weight (
42
+ keras_variable = backbone .encoder_position_embedding .position_embeddings ,
43
+ hf_weight_key = "encoder.embed_positions.weight" ,
44
+ hook_fn = lambda hf_tensor , keras_shape : np .reshape (
45
+ hf_tensor [2 :, :], keras_shape
46
+ ),
47
+ )
48
+ loader .port_weight (
49
+ keras_variable = backbone .decoder_position_embedding .position_embeddings ,
50
+ hf_weight_key = "decoder.embed_positions.weight" ,
51
+ hook_fn = lambda hf_tensor , keras_shape : np .reshape (
52
+ hf_tensor [2 :, :], keras_shape
53
+ ),
54
+ )
41
55
42
56
# Encoder blocks
43
57
for index in range (backbone .num_layers ):
You can’t perform that action at this time.
0 commit comments