### Bug description
When running the Tensor Parallel example that trains Llama 3 from scratch
(https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel)
with the DataLoader replaced by one that reads an unstructured text dataset through torchtune at a given batch size, I get an error indicating that all samples in every batch must have the same size:

```
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 170, in collate
[rank4]: raise RuntimeError('each element in list of batch should be of equal size')
[rank4]: RuntimeError: each element in list of batch should be of equal size
```

The only modification I made was to replace the DataLoader in the example's data.py with one based on torchtune. Details follow.
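For context, the error comes from PyTorch's `default_collate`, which refuses to collate samples of unequal length. A minimal sketch (mine, not part of the repro) that triggers the same error:

```python
from torch.utils.data import default_collate

# Equal-length samples collate fine: default_collate transposes a batch of
# lists into a list of per-position tensors.
default_collate([[1, 2, 3], [4, 5, 6]])

# Ragged samples raise exactly the error from the traceback:
# RuntimeError: each element in list of batch should be of equal size
default_collate([[1, 2, 3], [4, 5]])
```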
### What version are you seeing the problem on?

v1.x
### How to reproduce the bug

Install `torchtune` and download the text dataset:

```bash
curl https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt > t8.shakespeare.txt
```

Then create this function:
```python
from torchtune.datasets import text_completion_dataset


# `tokenizer` is assumed to be constructed elsewhere (see the sketch below).
def load_dataset(seq_length=2048):
    # https://pytorch.org/torchtune/main/generated/torchtune.datasets.text_completion_dataset.html
    dataset = text_completion_dataset(
        tokenizer,
        source="text",
        column="text",
        data_files="t8.shakespeare.txt",
        split="train",
        max_seq_len=seq_length,
        packed=False,
    )
    return dataset
```
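The snippet above assumes a `tokenizer` object already exists. A hedged sketch of one way to build it with torchtune's Llama 3 tokenizer (the model path is hypothetical, and the import location may differ between torchtune versions):

```python
from torchtune.models.llama3 import llama3_tokenizer

# Hypothetical path to a downloaded Llama 3 tokenizer.model file.
tokenizer = llama3_tokenizer(path="/path/to/Meta-Llama-3-8B/original/tokenizer.model")
```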
Then add the following to the example's data.py:
```python
import torch
from torch.utils.data import Dataset


def get_random_tokens(vocab_size, size):
    # Random token tensor (the example's original toy data)
    tokens = torch.randint(
        vocab_size,
        size=size,
        # Set a seed to make this toy dataset the same on each rank.
        # Fabric will add a `DistributedSampler` to shard the data correctly.
        generator=torch.Generator().manual_seed(42),
    )
    return tokens


def get_text_completion_dataset_tokens(seq_length, batch_size):
    # https://pytorch.org/torchtune/main/generated/torchtune.utils.padded_collate.html
    from torchtune.utils import padded_collate

    dataset = load_dataset(seq_length=seq_length)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=0,
        shuffle=False,
        collate_fn=padded_collate,
    )
    tokens = []
    for batch in dataloader:
        for sample in batch["tokens"].tolist():
            tokens.append(sample)
    return tokens


class RandomTokenDataset(Dataset):
    def __init__(self, vocab_size: int, seq_length: int, batch_size: int):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.batch_size = batch_size
        # self.tokens = get_random_tokens(self.vocab_size, (len(self), self.seq_length + 1))
        self.tokens = get_text_completion_dataset_tokens(seq_length, batch_size)

    def __len__(self) -> int:
        return 128

    def __getitem__(self, item: int):
        return self.tokens[item]
```
Running the training script with this dataset ends with the error above.
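A workaround I would try (an assumption on my part, not a verified fix) is to force every sample to one fixed length before Lightning's DataLoader applies `default_collate`, for example by truncating and padding in `__getitem__`:

```python
import torch

PAD_ID = 0  # hypothetical pad token id; substitute your tokenizer's pad id


def fix_length(sample: list, seq_length: int, pad_id: int = PAD_ID) -> torch.Tensor:
    """Truncate or right-pad a token list so all samples share one size."""
    sample = sample[:seq_length]
    return torch.tensor(sample + [pad_id] * (seq_length - len(sample)))


# e.g. in RandomTokenDataset.__getitem__:
#     return fix_length(self.tokens[item], self.seq_length)
```

Alternatively, if `__getitem__` returned the raw torchtune samples, the DataLoader built in train.py could be given `torchtune.utils.padded_collate` as its `collate_fn`.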
### Error messages and logs
```
[rank4]: Traceback (most recent call last):
[rank4]: File "train.py", line 233, in <module>
[rank4]: train()
[rank4]: File "train.py", line 222, in train
[rank4]: trainer.fit(model)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
[rank4]: call._call_and_handle_interrupt(
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank4]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank4]: return function(*args, **kwargs)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
[rank4]: self._run(model, ckpt_path=ckpt_path)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
[rank4]: results = self._run_stage()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
[rank4]: self.fit_loop.run()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank4]: self.advance()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank4]: self.epoch_loop.run(self._data_fetcher)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank4]: self.advance(data_fetcher)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 212, in advance
[rank4]: batch, _, __ = next(data_fetcher)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 133, in __next__
[rank4]: batch = super().__next__()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 60, in __next__
[rank4]: batch = next(self.iterator)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 341, in __next__
[rank4]: out = next(self._iterator)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 78, in __next__
[rank4]: out[i] = next(self.iterators[i])
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
[rank4]: data = self._next_data()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
[rank4]: return self._process_data(data)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
[rank4]: data.reraise()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/_utils.py", line 706, in reraise
[rank4]: raise exception
[rank4]: RuntimeError: Caught RuntimeError in DataLoader worker process 0.
[rank4]: Original Traceback (most recent call last):
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
[rank4]: data = fetcher.fetch(index) # type: ignore[possibly-undefined]
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
[rank4]: return self.collate_fn(data)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 317, in default_collate
[rank4]: return collate(batch, collate_fn_map=default_collate_fn_map)
```
### Environment
<details>
<summary>Current environment</summary>
* CUDA:
- GPU:
- NVIDIA A10G
- NVIDIA A10G
- NVIDIA A10G
- NVIDIA A10G
- available: True
- version: 12.1
* Lightning:
- lightning: 2.3.3
- lightning-utilities: 0.11.6
- pytorch-lightning: 1.9.5
- torch: 2.4.0
- torchao: 0.3.1
- torchmetrics: 1.4.1
- torchtune: 0.2.1
- torchvision: 0.19.0
* Packages:
- aiofiles: 22.1.0
- aiohappyeyeballs: 2.4.0
- aiohttp: 3.10.5
- aiosignal: 1.3.1
- aiosqlite: 0.18.0
- alembic: 1.10.2
- antlr4-python3-runtime: 4.9.3
- anyio: 3.6.2
- appdirs: 1.4.4
- argon2-cffi: 21.3.0
- argon2-cffi-bindings: 21.2.0
- argparse: 1.4.0
- asttokens: 2.2.1
- async-generator: 1.10
- async-timeout: 4.0.3
- attrs: 22.2.0
- autofaiss: 2.15.8
- babel: 2.12.1
- backcall: 0.2.0
- beautifulsoup4: 4.11.2
- bleach: 6.0.0
- blobfile: 2.1.1
- boto3: 1.35.2
- botocore: 1.35.2
- certifi: 2022.12.7
- certipy: 0.1.3
- cffi: 1.15.1
- charset-normalizer: 3.1.0
- click: 8.1.7
- coloredlogs: 15.0.1
- comm: 0.1.2
- cryptography: 39.0.2
- datasets: 2.21.0
- debugpy: 1.6.6
- decorator: 5.1.1
- defusedxml: 0.7.1
- dill: 0.3.8
- docker-pycreds: 0.4.0
- embedding-reader: 1.7.0
- executing: 1.2.0
- fairscale: 0.4.6
- faiss-cpu: 1.8.0.post1
- fastjsonschema: 2.16.3
- filelock: 3.15.4
- fire: 0.4.0
- flatbuffers: 24.3.25
- frozenlist: 1.4.1
- fsspec: 2024.6.1
- fuzzywuzzy: 0.18.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- greenlet: 2.0.2
- hnswlib: 0.7.0
- huggingface-hub: 0.24.6
- humanfriendly: 10.0
- idna: 3.4
- importlib-metadata: 6.1.0
- importlib-resources: 5.12.0
- ipykernel: 6.21.3
- ipython: 8.11.0
- ipython-genutils: 0.2.0
- jedi: 0.18.2
- jinja2: 3.1.2
- jmespath: 1.0.1
- joblib: 1.3.2
- json5: 0.9.11
- jsonschema: 4.17.3
- jupyter-client: 8.0.3
- jupyter-core: 5.3.0
- jupyter-events: 0.6.3
- jupyter-server: 2.5.0
- jupyter-server-fileid: 0.8.0
- jupyter-server-terminals: 0.4.4
- jupyter-server-ydoc: 0.8.0
- jupyter-telemetry: 0.1.0
- jupyter-ydoc: 0.2.3
- jupyterhub: 3.1.1
- jupyterlab: 3.6.2
- jupyterlab-pygments: 0.2.2
- jupyterlab-server: 2.20.0
- levenshtein: 0.23.0
- lightning: 2.3.3
- lightning-utilities: 0.11.6
- lxml: 4.9.4
- mako: 1.2.4
- markupsafe: 2.1.2
- matplotlib-inline: 0.1.6
- mistune: 2.0.5
- mpmath: 1.3.0
- multidict: 6.0.5
- multiprocess: 0.70.16
- nbclassic: 0.5.3
- nbclient: 0.7.2
- nbconvert: 7.2.10
- nbformat: 5.7.3
- nest-asyncio: 1.5.6
- networkx: 3.1
- nltk: 3.9.1
- notebook: 6.5.3
- notebook-shim: 0.2.2
- numpy: 1.24.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.6.20
- nvidia-nvtx-cu12: 12.1.105
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- onnx: 1.14.1
- onnxruntime-gpu: 1.14.1
- optimum: 1.16.2
- packaging: 23.0
- pamela: 1.0.0
- pandas: 1.3.5
- pandocfilters: 1.5.0
- parso: 0.8.3
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 10.4.0
- pip: 20.0.2
- pkgutil-resolve-name: 1.3.10
- platformdirs: 3.1.1
- prometheus-client: 0.16.0
- prompt-toolkit: 3.0.38
- protobuf: 5.27.3
- psutil: 5.9.5
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pyarrow: 15.0.2
- pycparser: 2.21
- pycryptodomex: 3.20.0
- pygments: 2.14.0
- pyopenssl: 23.0.0
- pyrsistent: 0.19.3
- python-dateutil: 2.8.2
- python-json-logger: 2.0.7
- python-levenshtein: 0.23.0
- pytorch-lightning: 1.9.5
- pytz: 2022.7.1
- pyyaml: 6.0
- pyzmq: 26.0.3
- rapidfuzz: 3.4.0
- regex: 2024.7.24
- requests: 2.28.2
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- ruamel.yaml: 0.17.21
- ruamel.yaml.clib: 0.2.7
- s3transfer: 0.10.2
- safetensors: 0.4.4
- scikit-learn: 1.3.2
- scipy: 1.10.1
- send2trash: 1.8.0
- sentence-transformers: 2.2.2
- sentencepiece: 0.2.0
- sentry-sdk: 2.13.0
- setproctitle: 1.3.3
- setuptools: 45.2.0
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.0
- soupsieve: 2.4
- sqlalchemy: 2.0.7
- stack-data: 0.6.2
- sympy: 1.13.2
- termcolor: 2.4.0
- terminado: 0.17.1
- threadpoolctl: 3.5.0
- tiktoken: 0.7.0
- tinycss2: 1.2.1
- tokenizers: 0.19.1
- tomli: 2.0.1
- torch: 2.4.0
- torchao: 0.3.1
- torchmetrics: 1.4.1
- torchtune: 0.2.1
- torchvision: 0.19.0
- tornado: 6.2
- tqdm: 4.66.1
- traitlets: 5.9.0
- transformers: 4.44.0
- triton: 3.0.0
- typing-extensions: 4.5.0
- urllib3: 1.26.15
- wandb: 0.17.7
- wcwidth: 0.2.6
- webencodings: 0.5.1
- websocket-client: 1.5.1
- wheel: 0.34.2
- xxhash: 3.5.0
- y-py: 0.5.9
- yarl: 1.9.4
- ypy-websocket: 0.8.4
- zipp: 3.15.0
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.10
- release: 5.10.219-208.866.amzn2.x86_64
- version: #1 SMP Tue Jun 18 14:00:06 UTC 2024
</details>
### More info
The reason for this modification is that the provided example is too abstract to test a meaningful, complete training run, i.e. one with a real tokenizer and a real dataset, so I added torchtune, which provides both. It would be a good addition to have the example adapted in this way.