Llama inference instability in fp16 producing inf in the middle of the model #27179

@fxmarty

System Info

  • transformers version: 4.35.0.dev0
  • Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
  • Python version: 3.9.16
  • Huggingface_hub version: 0.17.3
  • Safetensors version: 0.3.1
  • Accelerate version: 0.25.0.dev0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu118 (True)
  • Using GPU in script?: A100

Who can help?

@ydshieh @fxmarty @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi, I am encountering inference instability with Llama running in fp16 when left padding is used, especially when full rows are masked out in the 4D attention mask.

At some point in the forward pass, inf values may appear in the intermediate activations, ultimately leading to tensors filled with nan and raising the error:

Traceback (most recent call last):
  File "debug.py", line 38, in <module>
    outputs = model.generate(
  File "/fsx/felix/condaenvs/fx/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 1704, in generate
    return self.sample(
  File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 2822, in sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

Note that the inf values specifically appear at a padding position.
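For context, here is a rough sketch (pure torch, not the exact transformers mask-building code; the tensor names are made up) of how left padding combined with the causal mask produces fully masked rows in the additive 4D mask:

import torch

# Two left-padding tokens followed by three real tokens.
seq_len = 5
pad_mask = torch.tensor([0, 0, 1, 1, 1])
causal = torch.tril(torch.ones(seq_len, seq_len))
combined = causal * pad_mask  # broadcasts the padding mask over the key dimension
min_val = torch.finfo(torch.float16).min
additive = torch.full((seq_len, seq_len), min_val).masked_fill(combined.bool(), 0.0).to(torch.float16)
print(additive)  # rows 0 and 1 (the padding query positions) contain only min_val, i.e. are fully masked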

Reproduction:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
token = "[specify your token]"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, token=token)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, token=token)

sentence = "Felix Marty is a French"

# Alternatively, the issue can be reproduced with:
# sentence = "Elon Musk is a South"
# max_length=9

inp = tokenizer(sentence, return_tensors='pt', padding="max_length", max_length=9).to("cuda")

print("inp", inp["input_ids"].shape)
print("inp", inp)
torch.set_printoptions(threshold=10000000)

print("\n\n*** Generate:")
with torch.no_grad():
    outputs = model.generate(
        **inp,
        max_new_tokens=10,
        do_sample=True,
        top_p=0.9,
        temperature=float(0.01),
        top_k=40
    )

print(tokenizer.batch_decode(outputs))

Printing torch.all(torch.isfinite(...)) at a few points in the model, it appears the inf values start to appear in the MLP at self.act_fn(self.gate_proj(x)) * self.up_proj(x), and everything blows up from there.
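For reference, this kind of check can be automated with forward hooks; a minimal sketch (add_finite_checks is a made-up helper, not a transformers API):

import torch

# Register a forward hook on every submodule and report the ones whose output
# contains non-finite values, to locate where the inf first shows up.
def add_finite_checks(model):
    def make_hook(name):
        def hook(module, args, output):
            out = output[0] if isinstance(output, (tuple, list)) else output
            if torch.is_tensor(out) and not torch.isfinite(out).all():
                print(f"non-finite values in the output of: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

# add_finite_checks(model)  # then run model.generate(**inp, ...) as above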

What's interesting is that manually editing the two fully masked rows of the 4D attention mask (the ones corresponding to the two left-padding tokens) so that they are no longer fully masked (before/after screenshots omitted here) solves the issue.
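For reference, a minimal standalone sketch of what softmax sees on such a fully masked row in fp16: with every entry at the dtype minimum, the padding position ends up attending roughly uniformly to all tokens instead of to nothing, which is presumably how the padding hidden states grow large enough to overflow later in the MLP:

import torch

# Hypothetical attention scores for one fully masked (padding) query position.
min_val = torch.finfo(torch.float16).min
scores = torch.randn(1, 9, dtype=torch.float16)
mask_row = torch.full((1, 9), min_val, dtype=torch.float16)
probs = torch.softmax(scores + mask_row, dim=-1)
print(probs)  # roughly 1/9 everywhere, not zeros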

It makes me think that the fix implemented for SDPA to avoid fully masked rows in the attention mask (#26572) may actually be required in other cases such as this one, but it is unclear why it relates to the overflow here.
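For illustration, the idea of that workaround could look roughly like this (a sketch only; unmask_fully_masked_rows is a made-up name, not the actual transformers implementation):

import torch

# Detect rows of the additive 4D mask that are entirely at the dtype minimum and
# unmask them, so softmax never runs on an all-masked row.
def unmask_fully_masked_rows(expanded_mask: torch.Tensor, min_value: float) -> torch.Tensor:
    # expanded_mask: [batch, 1, q_len, kv_len], with min_value at masked positions
    fully_masked = (expanded_mask == min_value).all(dim=-1, keepdim=True)
    return expanded_mask.masked_fill(fully_masked, 0.0)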

WDYT @gante @ydshieh? Is this something you have ever observed?

Expected behavior

No inf values appearing in the middle of inference with an fp16 model.
