System Info
- transformers version: 4.35.0.dev0
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.9.16
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.3.1
- Accelerate version: 0.25.0.dev0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.0+cu118 (True)
- Using GPU in script?: A100
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Hi, I am encountering inference instability with Llama running in fp16 when left padding is used, especially when full rows are masked out in the 4D attention mask.
At some point in the forward pass, inf values may appear in the intermediate activations, ultimately leading to tensors filled with nan and raising the error:
Traceback (most recent call last):
  File "debug.py", line 38, in <module>
    outputs = model.generate(
  File "/fsx/felix/condaenvs/fx/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 1704, in generate
    return self.sample(
  File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 2822, in sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
Note that the inf values specifically appear at a padding position.
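For context, here is a minimal sketch (my own illustration, not taken from the actual model code) of why fully masked rows are suspicious in fp16: in the additive 4D mask, the row belonging to a left-padding query position is entirely filled with the fp16 minimum, so the softmax over that row degenerates to a uniform distribution instead of zeroing it out, and the padding position keeps attending to every token.

import torch

# Hypothetical illustration: an attention-score row for a padding query position,
# fully masked with the fp16 minimum as the additive 4D mask would do.
min_val = torch.finfo(torch.float16).min
scores = torch.full((1, 9), min_val, dtype=torch.float16)

# Softmax subtracts the row maximum, so a constant (fully masked) row becomes a
# uniform distribution rather than zeros: the padding position still attends
# with weight 1/9 to every token, including the other padding tokens.
probs = torch.softmax(scores.float(), dim=-1)
print(probs)  # tensor([[0.1111, 0.1111, ..., 0.1111]])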
Reproduction:
from transformers import AutoTokenizer, pipeline, logging, AutoModelForCausalLM
import torch
model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
token = "[specify your token]"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, token=token)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, token=token)
sentence = "Felix Marty is a French"
# Alternatively, the issue can be reproduced with:
# sentence = "Elon Musk is a South"
# max_length=9
inp = tokenizer(sentence, return_tensors='pt', padding="max_length", max_length=9).to("cuda")
print("inp", inp["input_ids"].shape)
print("inp", inp)
torch.set_printoptions(threshold=10000000)
print("\n\n*** Generate:")
with torch.no_grad():
    outputs = model.generate(
        **inp,
        max_new_tokens=10,
        do_sample=True,
        top_p=0.9,
        temperature=float(0.01),
        top_k=40
    )
print(tokenizer.batch_decode(outputs))

Printing torch.all(torch.isfinite(...)) at some points in the model, it appears the inf values start to appear in the MLP at self.act_fn(self.gate_proj(x)) * self.up_proj(x), and things go crazy from there.
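As a side note, here is a hedged sketch of how that kind of check can be automated (a hypothetical helper, not something from the original debugging session): register a forward hook on every submodule and report any module whose output contains non-finite values.

import torch

def report_non_finite_outputs(model):
    # Hypothetical debugging helper: print the name of any module whose output
    # tensor contains inf or nan during the forward pass.
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.all(torch.isfinite(output)):
                print(f"Non-finite values in the output of: {name}")
        return hook

    handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
    return handles  # call handle.remove() on each handle when done

# Usage (assuming the model and inputs from the reproduction above):
# handles = report_non_finite_outputs(model)
# model.generate(**inp, max_new_tokens=10)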
What's interesting is that, for example, manually fixing the 4D attention mask rows for the two left padding tokens from

[original mask tensor not shown]

to

[modified mask tensor not shown]

solves the issue.
It makes me think that the solution implemented for SDPA to avoid fully masked rows in the attention mask (#26572) may actually be required in other cases such as this one, but it is unclear why it relates to the overflow here.
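For reference, a rough sketch of the idea behind that SDPA workaround (my paraphrase of the approach discussed in #26572, not the actual implementation): if a row of the expanded additive mask is entirely at the dtype minimum, reset it so that the softmax never runs over a fully masked row.

import torch

def unmask_fully_masked_rows(expanded_mask: torch.Tensor, min_value: float) -> torch.Tensor:
    # Sketch only: expanded_mask is an additive 4D mask of shape
    # [batch, 1, query_len, key_len], with 0 for "attend" and min_value for "masked".
    fully_masked = torch.all(expanded_mask == min_value, dim=-1, keepdim=True)
    # Multiplying by the negated boolean leaves normal rows untouched and turns
    # fully masked (padding) rows into all-zero rows, which attend everywhere but
    # whose outputs are discarded anyway.
    return expanded_mask * ~fully_masked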
WDYT @gante @ydshieh? Is this something you have ever observed?
Expected behavior
No inf values appearing in the middle of inference with the fp16 model.
