RuntimeError: each element in list of batch should be of equal size #20229

Closed
@loretoparisi

Bug description

When running the Tensor Parallel example that trains Llama 3 from scratch,
https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel

with a torchtune-based DataLoader over an unstructured text dataset and a given batch size, I get an error indicating that all samples in a batch must have the same size:

[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 170, in collate
[rank4]:     raise RuntimeError('each element in list of batch should be of equal size')
[rank4]: RuntimeError: each element in list of batch should be of equal size
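For context, this message comes from PyTorch's `default_collate` (the DataLoader's default `collate_fn`), which raises it when the samples in a batch are sequences, such as Python lists, of different lengths. A minimal sketch with made-up token IDs, using only `torch`:

```python
import torch
from torch.utils.data import default_collate

# Two samples stored as plain Python lists of token IDs with different lengths
ragged = [[1, 2, 3, 4], [5, 6]]

try:
    default_collate(ragged)
except RuntimeError as err:
    print(err)  # each element in list of batch should be of equal size

# Equal-length lists do not raise, but they are transposed: the result is a
# list with one tensor per position, not a single (batch, seq_len) tensor.
same_length = [[1, 2, 3], [4, 5, 6]]
collated = default_collate(same_length)
print(type(collated), len(collated))  # <class 'list'> 3
print(collated[0])                    # tensor([1, 4])

# Returning tensors from __getitem__ gives a stacked (batch, seq_len) tensor instead
stacked = default_collate([torch.tensor(row) for row in same_length])
print(stacked.shape)  # torch.Size([2, 3])
```

So even with equal lengths, lists of token IDs do not collate into the `(batch, seq_len)` shape a language model expects; returning tensors from `__getitem__` avoids both problems.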

The only change I made was to load the data with torchtune in the example's `data.py` script; details below.

What version are you seeing the problem on?

v1.x

How to reproduce the bug

Install `torchtune` and download the text dataset:

```shell
curl https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt > t8.shakespeare.txt
```

Then create this function:

```python
from torchtune.datasets import text_completion_dataset


def load_dataset(seq_length=2048):
    # https://pytorch.org/torchtune/main/generated/torchtune.datasets.text_completion_dataset.html#torchtune.datasets.text_completion_dataset
    # `tokenizer` must already be defined in this module (e.g. a torchtune
    # Llama 3 tokenizer); its construction is not shown in this report.
    dataset = text_completion_dataset(
        tokenizer,
        source="text",                    # Hugging Face "text" loader: one sample per line
        column="text",
        data_files="t8.shakespeare.txt",
        split="train",
        max_seq_len=seq_length,           # samples are truncated to this length, not padded
        packed=False,
    )
    return dataset
```

Then modify the example's `data.py` as follows:

```python
import torch
from torch.utils.data import Dataset


def get_random_tokens(vocab_size, size):
    # Random token IDs, as in the original example
    tokens = torch.randint(
        vocab_size,
        size=size,
        # Set a seed to make this toy dataset the same on each rank
        # Fabric will add a `DistributedSampler` to shard the data correctly
        generator=torch.Generator().manual_seed(42),
    )
    return tokens


def get_text_completion_dataset_tokens(seq_length, batch_size):
    # https://pytorch.org/torchtune/main/generated/torchtune.utils.padded_collate.html#torchtune.utils.padded_collate
    from torchtune.utils import padded_collate

    dataset = load_dataset(seq_length=seq_length)
    # padded_collate pads each batch only to the longest sample in that batch
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=0, shuffle=False, collate_fn=padded_collate
    )
    tokens = []
    for batch in dataloader:
        # Each row is a plain Python list of token IDs
        for row in batch["tokens"].tolist():
            tokens.append(row)
    return tokens


class RandomTokenDataset(Dataset):
    def __init__(self, vocab_size: int, seq_length: int, batch_size: int):
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.batch_size = batch_size

        # Original example: random tokens of shape (len(self), seq_length + 1)
        # self.tokens = get_random_tokens(self.vocab_size, (len(self), self.seq_length + 1))
        self.tokens = get_text_completion_dataset_tokens(seq_length, batch_size)

    def __len__(self) -> int:
        return 128

    def __getitem__(self, item: int):
        return self.tokens[item]
```

Running the example with these changes ends with the error above.
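The ragged rows can be reproduced without torchtune. torchtune documents `padded_collate` as padding a batch to the longest sequence in that batch; the sketch below uses `torch.nn.utils.rnn.pad_sequence` as a stand-in for that behavior (an assumption about its internals), with made-up token IDs and `batch_size=2`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical token IDs for four text lines of different lengths, each already
# truncated to at most a hypothetical max_seq_len of 8 (as with packed=False).
samples = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5, 6, 7, 8]),
    torch.tensor([9, 10]),
    torch.tensor([11, 12, 13, 14, 15, 16, 17]),
]

# Padding batch by batch only pads to the longest sample in each batch,
# so the two batches end up with different widths.
batch_a = pad_sequence(samples[:2], batch_first=True, padding_value=0)
batch_b = pad_sequence(samples[2:], batch_first=True, padding_value=0)
print(batch_a.shape, batch_b.shape)  # torch.Size([2, 5]) torch.Size([2, 7])

# Flattening the rows back into one list, as get_text_completion_dataset_tokens
# does, yields rows of different lengths, none of them equal to max_seq_len.
rows = batch_a.tolist() + batch_b.tolist()
print([len(row) for row in rows])  # [5, 5, 7, 7]
```

When the training DataLoader later groups rows coming from different torchtune batches, the samples it receives are lists of unequal length.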



Error messages and logs

[rank4]: Traceback (most recent call last):
[rank4]:   File "train.py", line 233, in <module>
[rank4]:     train()
[rank4]:   File "train.py", line 222, in train
[rank4]:     trainer.fit(model)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
[rank4]:     call._call_and_handle_interrupt(
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank4]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank4]:     return function(*args, **kwargs)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
[rank4]:     self._run(model, ckpt_path=ckpt_path)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
[rank4]:     results = self._run_stage()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
[rank4]:     self.fit_loop.run()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank4]:     self.advance()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank4]:     self.epoch_loop.run(self._data_fetcher)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank4]:     self.advance(data_fetcher)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 212, in advance
[rank4]:     batch, _, __ = next(data_fetcher)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 133, in __next__
[rank4]:     batch = super().__next__()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fetchers.py", line 60, in __next__
[rank4]:     batch = next(self.iterator)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 341, in __next__
[rank4]:     out = next(self._iterator)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/utilities/combined_loader.py", line 78, in __next__
[rank4]:     out[i] = next(self.iterators[i])
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
[rank4]:     data = self._next_data()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
[rank4]:     return self._process_data(data)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
[rank4]:     data.reraise()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/_utils.py", line 706, in reraise
[rank4]:     raise exception
[rank4]: RuntimeError: Caught RuntimeError in DataLoader worker process 0.
[rank4]: Original Traceback (most recent call last):
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
[rank4]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
[rank4]:     return self.collate_fn(data)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 317, in default_collate
[rank4]:     return collate(batch, collate_fn_map=default_collate_fn_map)

Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:
                - NVIDIA A10G
                - NVIDIA A10G
                - NVIDIA A10G
                - NVIDIA A10G
        - available:         True
        - version:           12.1
* Lightning:
        - lightning:         2.3.3
        - lightning-utilities: 0.11.6
        - pytorch-lightning: 1.9.5
        - torch:             2.4.0
        - torchao:           0.3.1
        - torchmetrics:      1.4.1
        - torchtune:         0.2.1
        - torchvision:       0.19.0
* Packages:
        - aiofiles:          22.1.0
        - aiohappyeyeballs:  2.4.0
        - aiohttp:           3.10.5
        - aiosignal:         1.3.1
        - aiosqlite:         0.18.0
        - alembic:           1.10.2
        - antlr4-python3-runtime: 4.9.3
        - anyio:             3.6.2
        - appdirs:           1.4.4
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - argparse:          1.4.0
        - asttokens:         2.2.1
        - async-generator:   1.10
        - async-timeout:     4.0.3
        - attrs:             22.2.0
        - autofaiss:         2.15.8
        - babel:             2.12.1
        - backcall:          0.2.0
        - beautifulsoup4:    4.11.2
        - bleach:            6.0.0
        - blobfile:          2.1.1
        - boto3:             1.35.2
        - botocore:          1.35.2
        - certifi:           2022.12.7
        - certipy:           0.1.3
        - cffi:              1.15.1
        - charset-normalizer: 3.1.0
        - click:             8.1.7
        - coloredlogs:       15.0.1
        - comm:              0.1.2
        - cryptography:      39.0.2
        - datasets:          2.21.0
        - debugpy:           1.6.6
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.8
        - docker-pycreds:    0.4.0
        - embedding-reader:  1.7.0
        - executing:         1.2.0
        - fairscale:         0.4.6
        - faiss-cpu:         1.8.0.post1
        - fastjsonschema:    2.16.3
        - filelock:          3.15.4
        - fire:              0.4.0
        - flatbuffers:       24.3.25
        - frozenlist:        1.4.1
        - fsspec:            2024.6.1
        - fuzzywuzzy:        0.18.0
        - gitdb:             4.0.11
        - gitpython:         3.1.43
        - greenlet:          2.0.2
        - hnswlib:           0.7.0
        - huggingface-hub:   0.24.6
        - humanfriendly:     10.0
        - idna:              3.4
        - importlib-metadata: 6.1.0
        - importlib-resources: 5.12.0
        - ipykernel:         6.21.3
        - ipython:           8.11.0
        - ipython-genutils:  0.2.0
        - jedi:              0.18.2
        - jinja2:            3.1.2
        - jmespath:          1.0.1
        - joblib:            1.3.2
        - json5:             0.9.11
        - jsonschema:        4.17.3
        - jupyter-client:    8.0.3
        - jupyter-core:      5.3.0
        - jupyter-events:    0.6.3
        - jupyter-server:    2.5.0
        - jupyter-server-fileid: 0.8.0
        - jupyter-server-terminals: 0.4.4
        - jupyter-server-ydoc: 0.8.0
        - jupyter-telemetry: 0.1.0
        - jupyter-ydoc:      0.2.3
        - jupyterhub:        3.1.1
        - jupyterlab:        3.6.2
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-server: 2.20.0
        - levenshtein:       0.23.0
        - lightning:         2.3.3
        - lightning-utilities: 0.11.6
        - lxml:              4.9.4
        - mako:              1.2.4
        - markupsafe:        2.1.2
        - matplotlib-inline: 0.1.6
        - mistune:           2.0.5
        - mpmath:            1.3.0
        - multidict:         6.0.5
        - multiprocess:      0.70.16
        - nbclassic:         0.5.3
        - nbclient:          0.7.2
        - nbconvert:         7.2.10
        - nbformat:          5.7.3
        - nest-asyncio:      1.5.6
        - networkx:          3.1
        - nltk:              3.9.1
        - notebook:          6.5.3
        - notebook-shim:     0.2.2
        - numpy:             1.24.4
        - nvidia-cublas-cu12: 12.1.3.1
        - nvidia-cuda-cupti-cu12: 12.1.105
        - nvidia-cuda-nvrtc-cu12: 12.1.105
        - nvidia-cuda-runtime-cu12: 12.1.105
        - nvidia-cudnn-cu12: 9.1.0.70
        - nvidia-cufft-cu12: 11.0.2.54
        - nvidia-curand-cu12: 10.3.2.106
        - nvidia-cusolver-cu12: 11.4.5.107
        - nvidia-cusparse-cu12: 12.1.0.106
        - nvidia-nccl-cu12:  2.20.5
        - nvidia-nvjitlink-cu12: 12.6.20
        - nvidia-nvtx-cu12:  12.1.105
        - oauthlib:          3.2.2
        - omegaconf:         2.3.0
        - onnx:              1.14.1
        - onnxruntime-gpu:   1.14.1
        - optimum:           1.16.2
        - packaging:         23.0
        - pamela:            1.0.0
        - pandas:            1.3.5
        - pandocfilters:     1.5.0
        - parso:             0.8.3
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            10.4.0
        - pip:               20.0.2
        - pkgutil-resolve-name: 1.3.10
        - platformdirs:      3.1.1
        - prometheus-client: 0.16.0
        - prompt-toolkit:    3.0.38
        - protobuf:          5.27.3
        - psutil:            5.9.5
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyarrow:           15.0.2
        - pycparser:         2.21
        - pycryptodomex:     3.20.0
        - pygments:          2.14.0
        - pyopenssl:         23.0.0
        - pyrsistent:        0.19.3
        - python-dateutil:   2.8.2
        - python-json-logger: 2.0.7
        - python-levenshtein: 0.23.0
        - pytorch-lightning: 1.9.5
        - pytz:              2022.7.1
        - pyyaml:            6.0
        - pyzmq:             26.0.3
        - rapidfuzz:         3.4.0
        - regex:             2024.7.24
        - requests:          2.28.2
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - ruamel.yaml:       0.17.21
        - ruamel.yaml.clib:  0.2.7
        - s3transfer:        0.10.2
        - safetensors:       0.4.4
        - scikit-learn:      1.3.2
        - scipy:             1.10.1
        - send2trash:        1.8.0
        - sentence-transformers: 2.2.2
        - sentencepiece:     0.2.0
        - sentry-sdk:        2.13.0
        - setproctitle:      1.3.3
        - setuptools:        45.2.0
        - six:               1.16.0
        - smmap:             5.0.1
        - sniffio:           1.3.0
        - soupsieve:         2.4
        - sqlalchemy:        2.0.7
        - stack-data:        0.6.2
        - sympy:             1.13.2
        - termcolor:         2.4.0
        - terminado:         0.17.1
        - threadpoolctl:     3.5.0
        - tiktoken:          0.7.0
        - tinycss2:          1.2.1
        - tokenizers:        0.19.1
        - tomli:             2.0.1
        - torch:             2.4.0
        - torchao:           0.3.1
        - torchmetrics:      1.4.1
        - torchtune:         0.2.1
        - torchvision:       0.19.0
        - tornado:           6.2
        - tqdm:              4.66.1
        - traitlets:         5.9.0
        - transformers:      4.44.0
        - triton:            3.0.0
        - typing-extensions: 4.5.0
        - urllib3:           1.26.15
        - wandb:             0.17.7
        - wcwidth:           0.2.6
        - webencodings:      0.5.1
        - websocket-client:  1.5.1
        - wheel:             0.34.2
        - xxhash:            3.5.0
        - y-py:              0.5.9
        - yarl:              1.9.4
        - ypy-websocket:     0.8.4
        - zipp:              3.15.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.10
        - release:           5.10.219-208.866.amzn2.x86_64
        - version:           #1 SMP Tue Jun 18 14:00:06 UTC 2024

</details>

More info

The reason for this modification is that the provided example, which trains on random tokens, is too abstract to test a meaningful, complete training run, i.e. one with a real tokenizer and a real dataset; torchtune provides both.
It would be a good addition to adapt the example in this way.
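A sketch of what such an adaptation might look like on the dataset side, under stated assumptions: each sample is truncated or right-padded to `seq_length + 1` tokens (the per-sample shape of the random tokens in the original example) and returned as a `torch.long` tensor, so the default collate function can stack samples into a `(batch_size, seq_length + 1)` tensor. `FixedLengthTokenDataset`, its `_fix_length` helper, and the padding ID of 0 are hypothetical; a real tokenizer's padding ID may differ, and `token_rows` stands in for tokenized rows such as those returned by `get_text_completion_dataset_tokens`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FixedLengthTokenDataset(Dataset):
    """Hypothetical replacement for RandomTokenDataset that serves real token IDs.

    Every sample is truncated or right-padded to exactly seq_length + 1 tokens
    and returned as a tensor, so the default collate function can stack them.
    """

    def __init__(self, token_rows, seq_length: int, pad_id: int = 0):
        # token_rows: a list of token-ID lists of arbitrary lengths
        self.samples = [self._fix_length(row, seq_length + 1, pad_id) for row in token_rows]

    @staticmethod
    def _fix_length(row, length: int, pad_id: int) -> torch.Tensor:
        row = list(row)[:length]                    # truncate long rows
        row = row + [pad_id] * (length - len(row))  # right-pad short rows
        return torch.tensor(row, dtype=torch.long)

    def __len__(self) -> int:
        # Use the real number of samples instead of a hard-coded 128
        return len(self.samples)

    def __getitem__(self, item: int) -> torch.Tensor:
        return self.samples[item]


# Usage with made-up, ragged token rows (stand-ins for tokenized text lines)
rows = [[1, 2, 3], [4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [14], [15, 16, 17, 18]]
dataset = FixedLengthTokenDataset(rows, seq_length=8)
batch = next(iter(DataLoader(dataset, batch_size=2)))
print(batch.shape)  # torch.Size([2, 9])
```

If padding is used, padded target positions should probably be excluded from the loss, for example via the `ignore_index` argument of `torch.nn.functional.cross_entropy`.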
