RuntimeError: each element in list of batch should be of equal size #20229

Closed
loretoparisi opened this issue Aug 26, 2024 · 1 comment
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers)

loretoparisi commented Aug 26, 2024

Bug description

When running the Tensor Parallel example that trains Llama 3 from scratch, here:
https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel

using torchtune to build the DataLoader's dataset from unstructured text with a given batch size, I get an error indicating that all samples in a batch must be of equal size:

```
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 170, in collate
[rank4]:     raise RuntimeError('each element in list of batch should be of equal size')
[rank4]: RuntimeError: each element in list of batch should be of equal size
```

The only modification I made was to replace the dataset in the example's data.py script with one built via torchtune. Details follow.
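
For context, PyTorch's default_collate raises this error whenever the samples in a batch are Python sequences of unequal length. A minimal, Lightning-independent sketch (a toy dataset of mine, not from the example) that triggers the same failure:

```python
from torch.utils.data import DataLoader, Dataset

# Toy dataset whose items are Python lists of different lengths,
# which default_collate cannot stack into a single batch
class VariableLengthDataset(Dataset):
    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int):
        return list(range(idx + 1))  # lengths 1, 2, 3, 4

loader = DataLoader(VariableLengthDataset(), batch_size=2)
next(iter(loader))  # RuntimeError: each element in list of batch should be of equal size
```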

What version are you seeing the problem on?

v1.x

How to reproduce the bug

Install `torchtune` and download the text dataset:


```bash
curl https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt > t8.shakespeare.txt
```


Then create this function:

```python
from torchtune.datasets import text_completion_dataset
def load_dataset(seq_length=2048):
    #https://pytorch.org/torchtune/main/generated/torchtune.datasets.text_completion_dataset.html#torchtune.datasets.text_completion_dataset
    dataset = text_completion_dataset(
        tokenizer,
        source="text",
        column="text",
        data_files="t8.shakespeare.txt",
        split="train",
        max_seq_len=seq_length,
        packed=False
    )
    return dataset
```

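The snippet above assumes a `tokenizer` object is already in scope. With torchtune this could be, for example, a Llama 3 tokenizer; the tokenizer.model path below is a placeholder, not part of the original report:

```python
# Hypothetical tokenizer setup; the path is a placeholder
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer(path="/path/to/llama3/tokenizer.model")
```
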
Then add the following to the example file data.py:

```python
import torch
from torch.utils.data import Dataset


def get_random_tokens(vocab_size,size):
    # random tokens list
    tokens = torch.randint(
        vocab_size,
        size=size,
        # Set a seed to make this toy dataset the same on each rank
        # Fabric will add a `DistributedSampler` to shard the data correctly
        generator=torch.Generator().manual_seed(42),
    )
    return tokens

def get_text_completion_dataset_tokens(seq_length,batch_size):
    # https://pytorch.org/torchtune/main/generated/torchtune.utils.padded_collate.html#torchtune.utils.padded_collate
    from torchtune.utils import padded_collate
    dataset = load_dataset(seq_length=seq_length)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=False, collate_fn=padded_collate)
    tokens = []
    for batch in dataloader:
        # padded_collate pads each batch's "tokens" to that batch's max length,
        # so sequence lengths can still differ from one batch to the next
        for seq in batch["tokens"].tolist():
            tokens.append(seq)
    return tokens
                
class RandomTokenDataset(Dataset):
    def __init__(self, vocab_size: int, seq_length: int, batch_size:int):
        
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.batch_size = batch_size
        
        #self.tokens = get_random_tokens(self.vocab_size,(len(self), self.seq_length + 1))
        self.tokens = get_text_completion_dataset_tokens(seq_length,batch_size)
        
    def __len__(self) -> int:
        # Fixed length taken from the original example
        return 128

    def __getitem__(self, item: int):
        # Each item is a plain Python list of token ids; lengths can vary
        # across batches because padded_collate only pads within a batch
        return self.tokens[item]

```

Running the example then ends with the error above.
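
As a possible workaround (not part of the original report), a custom collate_fn on the training DataLoader can pad every sample to the longest sequence in its batch, which satisfies default collate's equal-size requirement. A minimal sketch, assuming pad id 0:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Pad variable-length token lists to the batch max before stacking.
# padding_value=0 is an assumption; use the tokenizer's actual pad id.
def pad_collate(batch):
    seqs = [torch.tensor(sample, dtype=torch.long) for sample in batch]
    return pad_sequence(seqs, batch_first=True, padding_value=0)

# e.g. DataLoader(RandomTokenDataset(...), batch_size=8, collate_fn=pad_collate)
```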



Error messages and logs

```
[rank4]: Traceback (most recent call last):
[rank4]:   File "train.py", line 233, in <module>
[rank4]:     train()
[rank4]:   File "train.py", line 222, in train
[rank4]:     trainer.fit(model)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
[rank4]:     call._call_and_handle_interrupt(
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank4]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank4]:     return function(*args, **kwargs)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
[rank4]:     self._run(model, ckpt_path=ckpt_path)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
[rank4]:     results = self._run_stage()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
[rank4]:     self.fit_loop.run()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank4]:     self.advance()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank4]:     self.epoch_loop.run(self._data_fetcher)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank4]:     self.advance(data_fetcher)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 212, in advance
[rank4]:     batch, _, __ = next(data_fetcher)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 133, in __next__
[rank4]:     batch = super().__next__()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 60, in __next__
[rank4]:     batch = next(self.iterator)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 341, in __next__
[rank4]:     out = next(self._iterator)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 78, in __next__
[rank4]:     out[i] = next(self.iterators[i])
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
[rank4]:     data = self._next_data()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
[rank4]:     return self._process_data(data)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
[rank4]:     data.reraise()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/_utils.py", line 706, in reraise
[rank4]:     raise exception
[rank4]: RuntimeError: Caught RuntimeError in DataLoader worker process 0.
[rank4]: Original Traceback (most recent call last):
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
[rank4]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
[rank4]:     return self.collate_fn(data)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 317, in default_collate
[rank4]:     return collate(batch, collate_fn_map=default_collate_fn_map)
```

Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:
                - NVIDIA A10G
                - NVIDIA A10G
                - NVIDIA A10G
                - NVIDIA A10G
        - available:         True
        - version:           12.1
* Lightning:
        - lightning:         2.3.3
        - lightning-utilities: 0.11.6
        - pytorch-lightning: 1.9.5
        - torch:             2.4.0
        - torchao:           0.3.1
        - torchmetrics:      1.4.1
        - torchtune:         0.2.1
        - torchvision:       0.19.0
* Packages:
        - aiofiles:          22.1.0
        - aiohappyeyeballs:  2.4.0
        - aiohttp:           3.10.5
        - aiosignal:         1.3.1
        - aiosqlite:         0.18.0
        - alembic:           1.10.2
        - antlr4-python3-runtime: 4.9.3
        - anyio:             3.6.2
        - appdirs:           1.4.4
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - argparse:          1.4.0
        - asttokens:         2.2.1
        - async-generator:   1.10
        - async-timeout:     4.0.3
        - attrs:             22.2.0
        - autofaiss:         2.15.8
        - babel:             2.12.1
        - backcall:          0.2.0
        - beautifulsoup4:    4.11.2
        - bleach:            6.0.0
        - blobfile:          2.1.1
        - boto3:             1.35.2
        - botocore:          1.35.2
        - certifi:           2022.12.7
        - certipy:           0.1.3
        - cffi:              1.15.1
        - charset-normalizer: 3.1.0
        - click:             8.1.7
        - coloredlogs:       15.0.1
        - comm:              0.1.2
        - cryptography:      39.0.2
        - datasets:          2.21.0
        - debugpy:           1.6.6
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.8
        - docker-pycreds:    0.4.0
        - embedding-reader:  1.7.0
        - executing:         1.2.0
        - fairscale:         0.4.6
        - faiss-cpu:         1.8.0.post1
        - fastjsonschema:    2.16.3
        - filelock:          3.15.4
        - fire:              0.4.0
        - flatbuffers:       24.3.25
        - frozenlist:        1.4.1
        - fsspec:            2024.6.1
        - fuzzywuzzy:        0.18.0
        - gitdb:             4.0.11
        - gitpython:         3.1.43
        - greenlet:          2.0.2
        - hnswlib:           0.7.0
        - huggingface-hub:   0.24.6
        - humanfriendly:     10.0
        - idna:              3.4
        - importlib-metadata: 6.1.0
        - importlib-resources: 5.12.0
        - ipykernel:         6.21.3
        - ipython:           8.11.0
        - ipython-genutils:  0.2.0
        - jedi:              0.18.2
        - jinja2:            3.1.2
        - jmespath:          1.0.1
        - joblib:            1.3.2
        - json5:             0.9.11
        - jsonschema:        4.17.3
        - jupyter-client:    8.0.3
        - jupyter-core:      5.3.0
        - jupyter-events:    0.6.3
        - jupyter-server:    2.5.0
        - jupyter-server-fileid: 0.8.0
        - jupyter-server-terminals: 0.4.4
        - jupyter-server-ydoc: 0.8.0
        - jupyter-telemetry: 0.1.0
        - jupyter-ydoc:      0.2.3
        - jupyterhub:        3.1.1
        - jupyterlab:        3.6.2
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-server: 2.20.0
        - levenshtein:       0.23.0
        - lightning:         2.3.3
        - lightning-utilities: 0.11.6
        - lxml:              4.9.4
        - mako:              1.2.4
        - markupsafe:        2.1.2
        - matplotlib-inline: 0.1.6
        - mistune:           2.0.5
        - mpmath:            1.3.0
        - multidict:         6.0.5
        - multiprocess:      0.70.16
        - nbclassic:         0.5.3
        - nbclient:          0.7.2
        - nbconvert:         7.2.10
        - nbformat:          5.7.3
        - nest-asyncio:      1.5.6
        - networkx:          3.1
        - nltk:              3.9.1
        - notebook:          6.5.3
        - notebook-shim:     0.2.2
        - numpy:             1.24.4
        - nvidia-cublas-cu12: 12.1.3.1
        - nvidia-cuda-cupti-cu12: 12.1.105
        - nvidia-cuda-nvrtc-cu12: 12.1.105
        - nvidia-cuda-runtime-cu12: 12.1.105
        - nvidia-cudnn-cu12: 9.1.0.70
        - nvidia-cufft-cu12: 11.0.2.54
        - nvidia-curand-cu12: 10.3.2.106
        - nvidia-cusolver-cu12: 11.4.5.107
        - nvidia-cusparse-cu12: 12.1.0.106
        - nvidia-nccl-cu12:  2.20.5
        - nvidia-nvjitlink-cu12: 12.6.20
        - nvidia-nvtx-cu12:  12.1.105
        - oauthlib:          3.2.2
        - omegaconf:         2.3.0
        - onnx:              1.14.1
        - onnxruntime-gpu:   1.14.1
        - optimum:           1.16.2
        - packaging:         23.0
        - pamela:            1.0.0
        - pandas:            1.3.5
        - pandocfilters:     1.5.0
        - parso:             0.8.3
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            10.4.0
        - pip:               20.0.2
        - pkgutil-resolve-name: 1.3.10
        - platformdirs:      3.1.1
        - prometheus-client: 0.16.0
        - prompt-toolkit:    3.0.38
        - protobuf:          5.27.3
        - psutil:            5.9.5
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyarrow:           15.0.2
        - pycparser:         2.21
        - pycryptodomex:     3.20.0
        - pygments:          2.14.0
        - pyopenssl:         23.0.0
        - pyrsistent:        0.19.3
        - python-dateutil:   2.8.2
        - python-json-logger: 2.0.7
        - python-levenshtein: 0.23.0
        - pytorch-lightning: 1.9.5
        - pytz:              2022.7.1
        - pyyaml:            6.0
        - pyzmq:             26.0.3
        - rapidfuzz:         3.4.0
        - regex:             2024.7.24
        - requests:          2.28.2
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - ruamel.yaml:       0.17.21
        - ruamel.yaml.clib:  0.2.7
        - s3transfer:        0.10.2
        - safetensors:       0.4.4
        - scikit-learn:      1.3.2
        - scipy:             1.10.1
        - send2trash:        1.8.0
        - sentence-transformers: 2.2.2
        - sentencepiece:     0.2.0
        - sentry-sdk:        2.13.0
        - setproctitle:      1.3.3
        - setuptools:        45.2.0
        - six:               1.16.0
        - smmap:             5.0.1
        - sniffio:           1.3.0
        - soupsieve:         2.4
        - sqlalchemy:        2.0.7
        - stack-data:        0.6.2
        - sympy:             1.13.2
        - termcolor:         2.4.0
        - terminado:         0.17.1
        - threadpoolctl:     3.5.0
        - tiktoken:          0.7.0
        - tinycss2:          1.2.1
        - tokenizers:        0.19.1
        - tomli:             2.0.1
        - torch:             2.4.0
        - torchao:           0.3.1
        - torchmetrics:      1.4.1
        - torchtune:         0.2.1
        - torchvision:       0.19.0
        - tornado:           6.2
        - tqdm:              4.66.1
        - traitlets:         5.9.0
        - transformers:      4.44.0
        - triton:            3.0.0
        - typing-extensions: 4.5.0
        - urllib3:           1.26.15
        - wandb:             0.17.7
        - wcwidth:           0.2.6
        - webencodings:      0.5.1
        - websocket-client:  1.5.1
        - wheel:             0.34.2
        - xxhash:            3.5.0
        - y-py:              0.5.9
        - yarl:              1.9.4
        - ypy-websocket:     0.8.4
        - zipp:              3.15.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.10
        - release:           5.10.219-208.866.amzn2.x86_64
        - version:           #1 SMP Tue Jun 18 14:00:06 UTC 2024

</details>

More info

The reason for this modification is that the provided example is too abstract to test a meaningful, complete training run, i.e. one that uses a real tokenizer on a real text dataset, so I added torchtune, which provides both. It would be a good addition to adapt the example in this way.

loretoparisi added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Aug 26, 2024
loretoparisi (Author) commented:

This issue has been solved by pytorch/torchtune#1416
