loading a StatefulDataLoader state_dict creates a slightly different one in the dataloader, potentially dropping an epoch #1437

Closed
@gailweiss

Description

🐛 Describe the bug

Loading a state_dict taken from a StatefulDataLoader that has just completed an epoch yields what I would describe as a "just finishing" rather than "just finished" state: the first iteration over the restored dataloader yields nothing (instead of a full new epoch), and only the following iteration continues as expected.

Reproduction: I run a dataloader for two epochs, save its state dict, remake the dataloader, load the state dict, and run for two more epochs. Only 3 epochs actually run. They do align with the first 3 epochs of a separate 4-epoch run, but there are 3 of them instead of 4.

(Side comment: the fact that the runs align here despite my not setting a random seed shows that the shuffling of the StatefulDataLoader ignores the current state of the global random number generators. This could be made clearer in the documentation, as it is not equivalent to torch.utils.data.DataLoader in this regard.)
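To illustrate that side comment: the effect is consistent with a sampler that owns its generator. The sketch below is a hypothetical pure-Python model (the class name and fields are mine, not torchdata code) of a sampler whose shuffle order depends only on its own seed, so reseeding the global RNG between runs changes nothing:

```python
import random

class OwnRNGSampler:
    """Sampler that owns its generator (illustrative sketch, not torchdata
    code): the shuffle order depends only on the sampler's own seed, never
    on the global RNG state."""
    def __init__(self, n, seed=0):
        self.order = list(range(n))
        random.Random(seed).shuffle(self.order)  # private generator

random.seed(1)
a = OwnRNGSampler(10).order
random.seed(999)               # reseeding the *global* RNG...
b = OwnRNGSampler(10).order
print(a == b)                  # True: ...has no effect on the sampler
```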

from torchdata.stateful_dataloader import StatefulDataLoader

def get_dl():
    d = list(range(100))
    return StatefulDataLoader(d, batch_size=1, shuffle=True)

def run_through(dl):
    # consume one full epoch, printing only the first batch
    for i, b in enumerate(dl):
        if i == 0:
            print(b)

def run_for_goes(goes):
    sd = None
    c = 0
    for n in goes:
        dl = get_dl()

        if sd is not None:
            print("loading state dict:", sd)
            dl.load_state_dict(sd)
            print("recall: loaded:", sd)
            print("state dict is now:", dl.state_dict())

        for j in range(n):
            print(c, j)
            run_through(dl)
            c += 1

        sd = dl.state_dict()

print("===")
run_for_goes([2,2])
print("===")
run_for_goes([4])

expected output:

===
0 0
tensor([45])
1 1
tensor([33])
loading state dict: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
recall: loaded: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
state dict is now: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
2 0
tensor([62])
3 1
tensor([19])
===
0 0
tensor([45])
1 1
tensor([33])
2 2
tensor([62])
3 3
tensor([19])

obtained output:

===
0 0
tensor([45])
1 1
tensor([33])
loading state dict: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
recall: loaded: {'_index_sampler_state': {'samples_yielded': 100, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 100}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
state dict is now: {'_index_sampler_state': {'samples_yielded': 0, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 0}}, '_sampler_iter_state': None, '_sampler_iter_yielded': 100, '_num_yielded': 100, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
2 0
3 1
tensor([62])
===
0 0
tensor([45])
1 1
tensor([33])
2 2
tensor([62])
3 3
tensor([19])
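To make the "just finishing" vs "just finished" distinction concrete, here is a minimal pure-Python model of what the behaviour looks like from the outside (an illustrative sketch only; ToyLoader and its fields are hypothetical, not torchdata's implementation):

```python
class ToyLoader:
    """Toy loader whose restored state replays the *remainder* of the saved
    epoch, so restoring a state saved right after an epoch completes yields
    one empty pass (illustrative sketch, not torchdata code)."""
    def __init__(self, n):
        self.n = n
        self._yielded = 0       # samples yielded in the current epoch
        self._restored = False  # whether the next iteration resumes a saved epoch

    def __iter__(self):
        if not self._restored:
            self._yielded = 0   # a fresh iterator starts a fresh epoch
        self._restored = False
        while self._yielded < self.n:
            x = self._yielded
            self._yielded += 1
            yield x

    def state_dict(self):
        return {"yielded": self._yielded}

    def load_state_dict(self, sd):
        self._yielded = sd["yielded"]
        self._restored = True   # resume the saved epoch, even if nothing is left

dl = ToyLoader(3)
print(list(dl))         # [0, 1, 2]: a full epoch
sd = dl.state_dict()    # saved right after the epoch: {'yielded': 3}

dl2 = ToyLoader(3)
dl2.load_state_dict(sd)
print(list(dl2))        # []: the empty "just finishing" pass
print(list(dl2))        # [0, 1, 2]: normal behaviour resumes afterwards
```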

Versions

PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.8
[pip3] numpy==1.26.3
[pip3] pytorch-lightning==2.0.3
[pip3] torch==2.2.0
[pip3] torchaudio==2.2.0
[pip3] torchdata==0.10.1
[pip3] torchmetrics==1.1.2
[pip3] torchvision==0.15.2a0
[conda] msgpack-numpy 0.4.8 pypi_0 pypi
[conda] numpy 1.26.3 py311he598dae_0
[conda] numpy-base 1.26.3 py311hfbfe69c_0
[conda] pytorch 2.2.0 py3.11_0 pytorch
[conda] pytorch-lightning 2.0.3 py311hca03da5_0
[conda] torchaudio 2.2.0 py311_cpu pytorch
[conda] torchdata 0.10.1 pypi_0 pypi
[conda] torchmetrics 1.1.2 py311hca03da5_0
[conda] torchvision 0.15.2 cpu_py311he74fb5d_0
