generate() produces incoherent output when inputs_embeds has length 1 #41863

Description

@tyarkoni

System Info

  • transformers version: 4.57.1
  • Platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.35
  • Python version: 3.10.18
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.6.2
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 PCIe

Who can help?

@xadupre (original author of code in question), @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to Reproduce

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load any causal LM
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.eval()

# Create a single embedding (simulating a prefix like a style token)
single_embedding = torch.randn(1, 1, 768)  # [batch=1, seq_len=1, hidden_dim=768]

# Generate with length-1 inputs_embeds
with torch.no_grad():
    outputs = model.generate(
        inputs_embeds=single_embedding,
        max_length=20,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode and observe gibberish
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
# Output will be incoherent repetitive tokens like "the the the I if the..."

Comparison with working case (length ≥ 2):

# Add a second embedding (e.g., BOS token embedding)
bos_embedding = model.get_input_embeddings()(torch.tensor([[tokenizer.bos_token_id]]))
two_embeddings = torch.cat([single_embedding, bos_embedding], dim=1)  # [1, 2, 768]

# Generate with length-2 inputs_embeds
with torch.no_grad():
    outputs = model.generate(
        inputs_embeds=two_embeddings,
        max_length=20,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
# Output will be coherent
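
As a stopgap, any soft prefix can be padded with the BOS embedding so that inputs_embeds never has length 1. The helper below is just a sketch continuing the script above (the helper name is illustrative, batch size 1 assumed); it mirrors the length-2 case by concatenating the BOS embedding after the prefix:

def pad_prefix_with_bos(model, tokenizer, prefix_embeds):
    # prefix_embeds: [1, prefix_len, hidden]
    bos_id = torch.tensor([[tokenizer.bos_token_id]], device=prefix_embeds.device)
    bos_embeds = model.get_input_embeddings()(bos_id)  # [1, 1, hidden]
    # Append the BOS embedding so the prefix is always at least length 2
    return torch.cat([prefix_embeds, bos_embeds], dim=1)  # [1, prefix_len + 1, hidden]

safe_embeds = pad_prefix_with_bos(model, tokenizer, single_embedding)  # [1, 2, 768]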

Expected behavior

generate() should produce coherent text regardless of whether inputs_embeds has length 1 or length > 1, as long as the embeddings are valid.

Actual Behavior

With length-1 inputs_embeds, generate() produces incoherent, repetitive gibberish, mostly high-frequency tokens, as if the model were not conditioning on the prefix embedding or on the previously generated tokens.

Suggested Fix

The _cache_dependant_input_preparation method needs to handle the transition from inputs_embeds mode to input_ids mode correctly after the first generation step. Specifically, one of the following (a rough sketch of option 1 follows this list):

  1. After the first token is generated from inputs_embeds, set inputs_embeds = None for all subsequent iterations
  2. Or, maintain bookkeeping so that the embeddings prefix is tracked correctly throughout the autoregressive loop
  3. Or, ensure the Exception 4 logic handles the case where generation has transitioned from embeddings to token IDs
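
A rough sketch of option 1, illustrative only: the function and variable names below are hypothetical, and this is not the actual transformers control flow.

# Hypothetical sketch of option 1 -- not the real transformers internals.
# past_length stands in for however the generation loop tracks how many
# positions are already in the KV cache.
def select_step_inputs(input_ids, inputs_embeds, past_length):
    if inputs_embeds is not None and past_length == 0:
        # First forward pass: condition on the full embeddings prefix,
        # even when it covers only a single position.
        return {"inputs_embeds": inputs_embeds}
    # Later passes: the prefix is already in the cache, so feed only the
    # most recently sampled token id (equivalent to setting
    # inputs_embeds = None after the first step).
    return {"input_ids": input_ids[:, -1:]}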
