
Commit 482cb28

Authored by Isotr0py
Fix tie_word_embeddings handling for GGUF models (#35085)
* fix tie_word_embeddings

  Signed-off-by: Isotr0py <2037008807@qq.com>

* fix

  Signed-off-by: Isotr0py <2037008807@qq.com>

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent 3544705 commit 482cb28
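
For context, this change affects models loaded directly from GGUF files through transformers. A minimal usage sketch follows; the repo id and file name are illustrative placeholders, not taken from this commit. After the fix, config.tie_word_embeddings is derived from the checkpoint's tensors instead of being hard-coded for some architectures.

    from transformers import AutoModelForCausalLM

    # Illustrative placeholders: any GGUF checkpoint on the Hub works the same way.
    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
    gguf_file = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"

    # The GGUF tensors are parsed and converted during loading; tie_word_embeddings
    # is now inferred from whether the file contains a separate output head tensor.
    model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=gguf_file)
    print(model.config.tie_word_embeddings)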

File tree

1 file changed: 6 additions, 1 deletion


src/transformers/modeling_gguf_pytorch_utils.py

Lines changed: 6 additions & 1 deletion
@@ -291,7 +291,6 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     # FIXME: Currnetly this implementation is only for flan-t5 architecture.
     # It needs to be developed for supporting legacy t5.
     elif "t5" in architecture or "t5encoder" in architecture:
-        parsed_parameters["config"]["tie_word_embeddings"] = False
         parsed_parameters["config"]["is_gated_act"] = True
         updated_architecture = "t5"
     else:
@@ -326,6 +325,12 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"Architecture {architecture + model_size} not supported")

+    # Handle tie_word_embeddings, if lm_head.weight is not present in tensors,
+    # tie_word_embeddings is true otherwise false
+    parsed_parameters["config"]["tie_word_embeddings"] = all(
+        "output.weight" != tensor.name for tensor in reader.tensors
+    )
+
     # List all key-value pairs in a columnized format
     for gguf_key, field in reader.fields.items():
         gguf_key = gguf_key.replace(architecture, updated_architecture)
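
The added check leans on GGUF naming conventions: an untied LM head is stored as a standalone output.weight tensor (which transformers maps to lm_head.weight), so its absence means the embeddings are tied. A minimal sketch of the same check outside this utility, assuming the gguf package is installed and using a placeholder file path:

    from gguf import GGUFReader

    # Placeholder path; any local GGUF file can be inspected this way.
    reader = GGUFReader("/path/to/model.gguf")

    # Mirrors the logic added above: embeddings are treated as tied when no
    # standalone "output.weight" (lm_head) tensor exists in the file.
    tie_word_embeddings = all(tensor.name != "output.weight" for tensor in reader.tensors)
    print(f"tie_word_embeddings: {tie_word_embeddings}")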

0 commit comments
