Llama 3 - RuntimeError: shape '[-1, 0]' is invalid for input of size 41041920 #32170

Closed
jacob-morrison opened this issue Jul 23, 2024 · 11 comments · Fixed by #32192

jacob-morrison commented Jul 23, 2024

System Info

transformers version 4.43.1, other package versions here: https://github.com/allenai/open-instruct/blob/main/requirements.txt

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running the following command using open-instruct:

unset CUDA_LAUNCH_BLOCKING && accelerate launch \
  --mixed_precision bf16 \
  --num_machines 2 \
  --num_processes 16 \
  --machine_rank $BEAKER_REPLICA_RANK \
  --main_process_ip $BEAKER_LEADER_REPLICA_HOSTNAME \
  --main_process_port 29400 \
  --use_deepspeed \
  --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf \
  --deepspeed_multinode_launcher standard \
  open_instruct/finetune.py \
  --model_name_or_path meta-llama/Meta-Llama-3.1-8B \
  --tokenizer_name meta-llama/Meta-Llama-3.1-8B \
  --use_slow_tokenizer \
  --dataset_name allenai/tulu-v2-sft-mixture \
  --use_flash_attn \
  --max_seq_length 4096 \
  --preprocessing_num_workers 16 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --learning_rate 5e-6 \
  --lr_scheduler_type linear \
  --warmup_ratio 0.03 \
  --weight_decay 0. \
  --num_train_epochs 2 \
  --output_dir /output/ \
  --with_tracking \
  --report_to tensorboard \
  --logging_steps 1 \
  --reduce_loss sum

We encounter this error on the first step of finetuning (after updating to transformers 4.43.1 to support Llama 3.1 finetuning):

2024-07-23T21:19:48.544516135Z /opt/miniconda3/lib/python3.10/site-packages/transformers/data/data_collator.py:656: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:274.)
2024-07-23T21:19:48.544518524Z batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
2024-07-23T21:19:49.155378393Z [rank2]: Traceback (most recent call last):
2024-07-23T21:19:49.155406373Z [rank2]: File "/stage/open_instruct/finetune.py", line 683, in <module>
2024-07-23T21:19:49.155409168Z [rank2]: main()
2024-07-23T21:19:49.155410556Z [rank2]: File "/stage/open_instruct/finetune.py", line 602, in main
2024-07-23T21:19:49.155412476Z [rank2]: outputs = model(**batch, use_cache=False)
2024-07-23T21:19:49.155413980Z [rank2]: File "/opt/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
2024-07-23T21:19:49.155415839Z [rank2]: return self._call_impl(*args, **kwargs)
2024-07-23T21:19:49.155417058Z [rank2]: File "/opt/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
2024-07-23T21:19:49.155418501Z [rank2]: return forward_call(*args, **kwargs)
2024-07-23T21:19:49.155419655Z [rank2]: File "/opt/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
2024-07-23T21:19:49.155421076Z [rank2]: ret_val = func(*args, **kwargs)
2024-07-23T21:19:49.155422228Z [rank2]: File "/opt/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
2024-07-23T21:19:49.155423640Z [rank2]: loss = self.module(*inputs, **kwargs)
2024-07-23T21:19:49.155424827Z [rank2]: File "/opt/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
2024-07-23T21:19:49.155440561Z [rank2]: return self._call_impl(*args, **kwargs)
2024-07-23T21:19:49.155441869Z [rank2]: File "/opt/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
2024-07-23T21:19:49.155443280Z [rank2]: result = forward_call(*args, **kwargs)
2024-07-23T21:19:49.155444498Z [rank2]: File "/opt/miniconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1168, in forward
2024-07-23T21:19:49.155446074Z [rank2]: shift_logits = shift_logits.view(-1, self.config.vocab_size)
2024-07-23T21:19:49.155447329Z [rank2]: RuntimeError: shape '[-1, 0]' is invalid for input of size 41041920

Any idea what's going on? We're not sure whether other packages need to be updated, whether this is a known issue, or whether it's something else.

Expected behavior

Llama 3.1 finetuning to run successfully

@ArthurZucker
Collaborator

What's weird is that we did not change the Llama code per se, but we did change the code behind --use_flash_attn with #31629, which I think could be causing this. However, this looks like a config error to me: self.config.vocab_size seems to be set to 0.
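
For reference, a quick toy sketch (made-up shapes, not the actual model code) of why a zero vocab_size produces exactly this error:

import torch

# Toy stand-in for the causal-LM loss computation: logits are shifted by one
# position and flattened to (-1, vocab_size) before cross-entropy. A target
# shape of (-1, 0) cannot hold any elements, hence the RuntimeError above.
batch, seq_len, vocab_size = 2, 10, 128256
logits = torch.randn(batch, seq_len, vocab_size)
shift_logits = logits[..., :-1, :].contiguous()

shift_logits.view(-1, vocab_size)  # fine
shift_logits.view(-1, 0)           # RuntimeError: shape '[-1, 0]' is invalid for input of size 2308608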

@ArthurZucker
Collaborator

Arf, could you try without the commit I linked?

@ArthurZucker
Collaborator

Apart from that one, #32135 is the only other potential culprit I see. I don't have access to the script, so if any of you can isolate a small reproducer it would help a lot!

@seokhyunan

I encountered the same issue. I tried running the script to fine-tune Llama 3.0, Llama 3.1, and Mistral 7B v0.3 with transformers 4.43.0 and 4.43.1 and hit the same error in every case. However, version 4.42.4 works fine for all of these base models. @ArthurZucker

@seokhyunan

seokhyunan commented Jul 24, 2024

Reverting the commit (#31629) did not resolve the issue.

@ArthurZucker
Collaborator

Ouch, maybe #31446 then, if you can revert it.
Could you print self.config.vocab_size to make sure it's not 0? (Or maybe the error is not traced properly and we might need CUDA_LAUNCH_BLOCKING to get a proper traceback.)
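
For reference, since the launch command above unsets CUDA_LAUNCH_BLOCKING, a small sketch of re-enabling it from inside the training script (it has to be set before torch initializes CUDA, e.g. at the very top of finetune.py):

import os

# Synchronous kernel launches make the Python traceback point at the op that
# actually failed instead of a later, unrelated call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # import / touch CUDA only after the variable is set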

@seokhyunan

@ArthurZucker I confirmed that self.config.vocab_size is set to zero. However, I'm hitting conflicts when trying to revert #31446. Additionally, you might want to check the fine-tuning scripts we use: finetune.py and finetune_lora_with_accelerate.sh.

In python3.11/site-packages/transformers/models/mistral/modeling_mistral.py:

print(f"#### SELF.CONFIG.VOCAB_SIZE: {self.config.vocab_size}")
shift_logits = shift_logits.view(-1, self.config.vocab_size)

Output:

#### SELF.CONFIG.VOCAB_SIZE: 0
[rank3]: Traceback (most recent call last):
[...]
[rank3]:   File "[...]/python3.11/site-packages/transformers/models/mistral/modeling_mistral.py", line 1053, in forward
[rank3]:     shift_logits = shift_logits.view(-1, self.config.vocab_size)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: RuntimeError: shape '[-1, 0]' is invalid for input of size 67076096

@ArthurZucker
Collaborator

ArthurZucker commented Jul 24, 2024

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = "cuda"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(ckpt, attn_implementation="flash_attention_2", torch_dtype=torch.float16)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(ckpt)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token; padding=True needs one

prompt = ["Explain the three body problem", "What is this?"]
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)
outputs = model(inputs["input_ids"], labels=inputs["input_ids"])

I ran something like this and it worked for me, so I don't really know what's going on here 😓

@seokhyunan

seokhyunan commented Jul 24, 2024

@ArthurZucker I found that the model.resize_token_embeddings call in finetune.py (lines 366-368) sets vocab_size to zero, but only with transformers 4.43.0 and 4.43.1. Do you have any ideas on how to resolve this issue?

print(f"### Model Config 1: {model.config}")
# resize does its own gather
if len(tokenizer) > embedding_size:
    # pad to multiple for tensor cores.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
print(f"### Length of tokenizer: {len(tokenizer)}")
print(f"### Model Config 2: {model.config}")
exit(1)

Output:

### Model Config 1: LlamaConfig {
  "_name_or_path": "models/base/Meta-Llama-3-8B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.1",
  "use_cache": true,
  "vocab_size": 128256
}

### Length of tokenizer: 128257
### Model Config 2: LlamaConfig {
  "_name_or_path": "models/base/Meta-Llama-3-8B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.1",
  "use_cache": true,
  "vocab_size": 0
}

@seokhyunan

@ArthurZucker Reverting the commit (#31979) resolved the issue.
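
For anyone who cannot revert, a hypothetical stop-gap (my assumption, not the upstream fix) would be to repair config.vocab_size by hand after the resize in finetune.py, reading the new size from the lm_head, whose out_features stays a plain Python int even when DeepSpeed ZeRO-3 partitions the weights:

# Hypothetical workaround sketch: same resize call as in finetune.py, then
# restore config.vocab_size if the bug has zeroed it out.
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
if model.config.vocab_size == 0:
    model.config.vocab_size = model.get_output_embeddings().out_features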

@zucchini-nlp
Member

Hey! Saw your comment under the linked PR. I just tried the snippet below on the current main branch and didn't get 0 for model.vocab_size. Resizing the embeddings to more or fewer tokens also didn't change anything. I guess it only happens when deepspeed is used; I will open a PR soon to fix it.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
vocab_size = model.vocab_size
model.resize_token_embeddings(vocab_size, pad_to_multiple_of=8)

assert model.vocab_size != 0
assert model.config.vocab_size == vocab_size
assert model.vocab_size == vocab_size
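
For context, my assumption about the mechanism (consistent with it only reproducing under deepspeed): ZeRO-3 replaces partitioned parameters with empty placeholders, so their .shape reads as 0 unless they are gathered first, and a resize path that infers the new vocab size from the weight shape would then write 0 into the config. A rough sketch of reading the true shape under ZeRO-3:

import deepspeed

# Under ZeRO-3 a partitioned parameter's local .shape is torch.Size([0]);
# gathering it temporarily restores the full tensor so the shape can be read.
embeddings = model.get_input_embeddings()
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
    full_vocab_size = embeddings.weight.shape[0]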
