System Info
- transformers version: 4.57.1
- Platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.35
- Python version: 3.10.18
- Huggingface_hub version: 0.35.3
- Safetensors version: 0.6.2
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA H100 PCIe
Who can help?
@xadupre (original author of code in question), @zucchini-nlp
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Steps to Reproduce
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load any causal LM
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.eval()

# Create a single embedding (simulating a prefix like a style token)
single_embedding = torch.randn(1, 1, 768)  # [batch=1, seq_len=1, hidden_dim=768]

# Generate with length-1 inputs_embeds
with torch.no_grad():
    outputs = model.generate(
        inputs_embeds=single_embedding,
        max_length=20,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode and observe gibberish
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
# Output will be incoherent repetitive tokens like "the the the I if the..."
```

Comparison with working case (length ≥ 2):
```python
# Add a second embedding (e.g., BOS token embedding)
bos_embedding = model.get_input_embeddings()(torch.tensor([[tokenizer.bos_token_id]]))
two_embeddings = torch.cat([single_embedding, bos_embedding], dim=1)  # [1, 2, 768]

# Generate with length-2 inputs_embeds
with torch.no_grad():
    outputs = model.generate(
        inputs_embeds=two_embeddings,
        max_length=20,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
# Output will be coherent
```
Expected Behavior
generate() should produce coherent text regardless of whether inputs_embeds has length 1 or length > 1, as long as the embeddings are valid.
Actual Behavior
With length-1 inputs_embeds, generate() produces incoherent, repetitive gibberish that appears to consist of high-frequency tokens emitted without proper conditioning on the preceding context.
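Until this is fixed, one possible user-side workaround, derived from the working length-2 case above (not an official recommendation), is to pad a length-1 prefix to length ≥ 2 before calling generate(). A minimal sketch, reusing model, tokenizer, and single_embedding from the reproduction:

```python
# Workaround sketch (assumes padding with the BOS embedding is acceptable for
# the downstream use case); mirrors the working length-2 comparison above.
if single_embedding.shape[1] == 1:
    bos_ids = torch.tensor([[tokenizer.bos_token_id]])
    bos_embedding = model.get_input_embeddings()(bos_ids)
    prefix_embeds = torch.cat([single_embedding, bos_embedding], dim=1)  # [1, 2, 768]
else:
    prefix_embeds = single_embedding

with torch.no_grad():
    outputs = model.generate(
        inputs_embeds=prefix_embeds,
        max_length=20,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )
```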
Suggested Fix
The _cache_dependant_input_preparation method needs to properly handle the transition from inputs_embeds mode to input_ids mode after the first generation step. Specifically:
- After the first token is generated from inputs_embeds, set inputs_embeds = None for subsequent iterations (see the sketch below)
- Or, maintain proper bookkeeping so that the embeddings prefix is correctly tracked throughout the autoregressive loop
- Or, ensure the Exception 4 logic properly handles the case where we've transitioned from embeddings to token IDs
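As a rough illustration of the first option (not the actual transformers code, and written as a standalone greedy decoding loop rather than a patch to _cache_dependant_input_preparation): the embeddings prefix is consumed only on the first forward pass, after which inputs_embeds is effectively dropped and generation continues from token IDs plus the KV cache.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")

prefix_embeds = torch.randn(1, 1, 768)  # length-1 prefix, as in the reproduction

with torch.no_grad():
    # First step: condition on the embeddings prefix and build the KV cache
    out = model(inputs_embeds=prefix_embeds, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = [next_token]

    # Subsequent steps: the prefix lives only in the cache; inputs_embeds is
    # effectively set to None and only the newly generated token IDs are fed
    for _ in range(19):
        out = model(input_ids=next_token, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token)

print(tokenizer.decode(torch.cat(generated, dim=1)[0], skip_special_tokens=True))
```

The key point is the bookkeeping rather than the greedy sampling: after the first step, nothing in the loop refers to inputs_embeds anymore.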