[DataPipe] Simple graph snapshotting (pytorch#79479)
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting, but it should work for all DataPipes. I will be adding more efficient implementations for different types of DataPipes in future PRs.

### Implementation

The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator` (see the sketch below)
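
As a rough sketch, the restore step boils down to the following (the helper name here is hypothetical and the body is simplified; the actual `_simple_graph_snapshot_restoration` added in this PR also guards against restoring an already-restored graph and raises if the DataPipe is exhausted early):

```python
import torch
import torch.utils.data.graph_settings
from torch.utils.data.datapipes._hook_iterator import _SnapshotState


def _restore_simple_snapshot(datapipe, n_iterations, rng=None):
    # Re-seed any shufflers so the fresh iterator replays the original ordering
    if rng is not None:
        torch.utils.data.graph_settings.apply_shuffle_seed(datapipe, rng)

    # Steps 1-2: create a new iterator and move it forward by `n_iterations`
    it = iter(datapipe)
    for _ in range(n_iterations):
        next(it)

    # Step 3: save the advanced iterator on the DataPipe
    datapipe._fast_forward_iterator = it
    # Step 4: the next `iter()` call on the DataPipe picks up `_fast_forward_iterator`
    # instead of starting over (see the `_hook_iterator.py` changes below)
    datapipe._snapshot_state = _SnapshotState.Restored
```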

### Usage
As of this implementation, the usage will look something like:
```python
import pickle

import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration

rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
    next(it)
serialized_graph = pickle.dumps(datapipe)

# The serialized object carries most of the information needed for a simple snapshot
# (except for the initial RNG state). It can be deserialized at a later point in time
# or by a different process.
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` so that it
# can be saved by the API that we later use.
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded

_simple_graph_snapshot_restoration(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The whole DataPipe graph now has the same state as before serialization, such that:
assert list(it) == list(deserialized_graph)  # True
```

### Next Steps
If this looks acceptable, the next step is to modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to expose a `save_snapshot` method that returns the `(serialized_graph, initial_rng)` pair, along with a `restore_snapshot` method. This should work for single-worker data loading.
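
For illustration only, here is one possible shape of that single-worker API. The `save_snapshot`/`restore_snapshot` names come from the paragraph above; the helper class, method signatures, and everything else are assumptions made for this sketch, not the planned implementation:

```python
import pickle

import torch
import torch.utils.data.graph_settings
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration


class _SingleWorkerSnapshotHelper:
    """Hypothetical helper a ReadingService could use to checkpoint a DataPipe graph."""

    def initialize(self, datapipe):
        # Remember the RNG state used to seed the shufflers at the start of the epoch
        rng = torch.Generator()
        self._initial_rng_state = rng.get_state()
        torch.utils.data.graph_settings.apply_shuffle_seed(datapipe, rng)
        self._datapipe = datapipe
        return datapipe

    def save_snapshot(self):
        # The pickled graph already carries `_number_of_samples_yielded`
        return pickle.dumps(self._datapipe), self._initial_rng_state

    def restore_snapshot(self, serialized_graph, initial_rng_state):
        datapipe = pickle.loads(serialized_graph)
        rng = torch.Generator()
        rng.set_state(initial_rng_state)
        _simple_graph_snapshot_restoration(
            datapipe, datapipe._number_of_samples_yielded, rng=rng)
        self._datapipe = datapipe
        self._initial_rng_state = initial_rng_state
        return datapipe
```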

Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and the simple snapshot are still a good fall-back option for edge cases where the buffer can't be stored.

Differential Revision: [D37943406](https://our.internmc.facebook.com/intern/diff/D37943406)
Pull Request resolved: pytorch#79479
Approved by: https://github.com/ejguan
NivekT authored and pytorchmergebot committed Jul 23, 2022
1 parent cb63ffc commit 35d97e2
Showing 9 changed files with 345 additions and 24 deletions.
220 changes: 220 additions & 0 deletions test/test_datapipe.py
@@ -49,6 +49,9 @@
from torch.utils.data.datapipes.utils.decoder import (
basichandlers as decoder_basichandlers,
)
from torch.utils.data.datapipes.utils.snapshot import (
_simple_graph_snapshot_restoration
)
from torch.utils.data.datapipes.dataframe import CaptureDataFrame
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper

@@ -2986,5 +2989,222 @@ def reset(self):
list(it)
self.assertEqual(3, datapipe._number_of_samples_yielded)


class _CustomNonGeneratorTestDataPipe(IterDataPipe):
def __init__(self):
self.n = 10
self.source = list(range(self.n))

# This class's `__iter__` is not a generator function
def __iter__(self):
return iter(self.source)

def __len__(self):
return self.n


class _CustomSelfNextTestDataPipe(IterDataPipe):
def __init__(self):
self.n = 10
self.iter = iter(range(self.n))

def __iter__(self):
return self

def __next__(self):
return next(self.iter)

def reset(self):
self.iter = iter(range(self.n))

def __len__(self):
return self.n


class TestIterDataPipeGraphFastForward(TestCase):

def _fast_forward_graph_test_helper(self, datapipe, fast_forward_fn, expected_res, n_iterations=3, rng=None):
if rng is None:
rng = torch.Generator()
rng = rng.manual_seed(0)
torch.utils.data.graph_settings.apply_shuffle_seed(datapipe, rng)

# Test Case: fast forward works with list
rng.manual_seed(0)
fast_forward_fn(datapipe, n_iterations, rng)
actual_res = list(datapipe)
self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
self.assertEqual(expected_res[n_iterations:], actual_res)

# Test Case: fast forward works with iterator
rng.manual_seed(0)
fast_forward_fn(datapipe, n_iterations, rng)
it = iter(datapipe)
actual_res = list(it)
self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
self.assertEqual(expected_res[n_iterations:], actual_res)
with self.assertRaises(StopIteration):
next(it)

def test_simple_snapshot_graph(self):
graph1 = dp.iter.IterableWrapper(range(10))
res1 = list(range(10))
self._fast_forward_graph_test_helper(graph1, _simple_graph_snapshot_restoration,
expected_res=res1)

graph2 = graph1.map(_mul_10)
res2 = [10 * x for x in res1]
self._fast_forward_graph_test_helper(graph2, _simple_graph_snapshot_restoration,
expected_res=res2)

rng = torch.Generator()
graph3 = graph2.shuffle()
rng.manual_seed(0)
torch.utils.data.graph_settings.apply_shuffle_seed(graph3, rng)
res3 = list(graph3)
self._fast_forward_graph_test_helper(graph3, _simple_graph_snapshot_restoration,
expected_res=res3)

graph4 = graph3.map(_mul_10)
res4 = [10 * x for x in res3]
self._fast_forward_graph_test_helper(graph4, _simple_graph_snapshot_restoration,
expected_res=res4)

batch_size = 2
graph5 = graph4.batch(batch_size)
res5 = [res4[i:i + batch_size] for i in range(0, len(res4), batch_size)] # .batch(2)
self._fast_forward_graph_test_helper(graph5, _simple_graph_snapshot_restoration,
expected_res=res5)

# With `fork` and `zip`
cdp1, cdp2 = graph5.fork(2)
graph6 = cdp1.zip(cdp2)
rng = rng.manual_seed(100)
torch.utils.data.graph_settings.apply_shuffle_seed(graph6, rng)
res6 = [(x, x) for x in res5]
self._fast_forward_graph_test_helper(graph6, _simple_graph_snapshot_restoration,
expected_res=res6)

# With `fork` and `concat`
graph7 = cdp1.concat(cdp2)
res7 = res5 * 2
self._fast_forward_graph_test_helper(graph7, _simple_graph_snapshot_restoration,
expected_res=res7)

# Raises an exception if the graph has already been restored
with self.assertRaisesRegex(RuntimeError, "Snapshot restoration cannot be applied."):
_simple_graph_snapshot_restoration(graph7, 1)
_simple_graph_snapshot_restoration(graph7, 1)

def test_simple_snapshot_custom_non_generator(self):
graph = _CustomNonGeneratorTestDataPipe()
self._fast_forward_graph_test_helper(graph, _simple_graph_snapshot_restoration, expected_res=range(10))

def test_simple_snapshot_custom_self_next(self):
graph = _CustomSelfNextTestDataPipe()
self._fast_forward_graph_test_helper(graph, _simple_graph_snapshot_restoration, expected_res=range(10))

def _snapshot_test_helper(self, datapipe, expected_res, n_iter=3, rng=None):
"""
Extend the previous test with serialization and deserialization test.
"""
if rng is None:
rng = torch.Generator()
rng.manual_seed(0)
torch.utils.data.graph_settings.apply_shuffle_seed(datapipe, rng)
it = iter(datapipe)
for _ in range(n_iter):
next(it)
serialized_graph = pickle.dumps(datapipe)
deserialized_graph = pickle.loads(serialized_graph)
self.assertEqual(n_iter, datapipe._number_of_samples_yielded)
self.assertEqual(n_iter, deserialized_graph._number_of_samples_yielded)

rng_for_deserialized = torch.Generator()
rng_for_deserialized.manual_seed(0)
_simple_graph_snapshot_restoration(deserialized_graph, n_iter, rng=rng_for_deserialized)
self.assertEqual(expected_res[n_iter:], list(it))
self.assertEqual(expected_res[n_iter:], list(deserialized_graph))

def test_simple_snapshot_graph_with_serialization(self):
graph1 = dp.iter.IterableWrapper(range(10))
res1 = list(range(10))
self._snapshot_test_helper(graph1, expected_res=res1)

graph2 = graph1.map(_mul_10)
res2 = [10 * x for x in res1]
self._snapshot_test_helper(graph2, expected_res=res2)

rng = torch.Generator()
graph3 = graph2.shuffle()
rng.manual_seed(0)
torch.utils.data.graph_settings.apply_shuffle_seed(graph3, rng)
res3 = list(graph3)
self._snapshot_test_helper(graph3, expected_res=res3)

graph4 = graph3.map(_mul_10)
res4 = [10 * x for x in res3]
self._snapshot_test_helper(graph4, expected_res=res4)

batch_size = 2
graph5 = graph4.batch(batch_size)
res5 = [res4[i:i + batch_size] for i in range(0, len(res4), batch_size)] # .batch(2)
self._snapshot_test_helper(graph5, expected_res=res5)

# With `fork` and `zip`
cdp1, cdp2 = graph5.fork(2)
graph6 = cdp1.zip(cdp2)
res6 = [(x, x) for x in res5]
self._snapshot_test_helper(graph6, expected_res=res6)

# With `fork` and `concat`
graph7 = cdp1.concat(cdp2)
res7 = res5 * 2
self._snapshot_test_helper(graph7, expected_res=res7)

def test_simple_snapshot_graph_repeated(self):
cdp1, cdp2 = dp.iter.IterableWrapper(range(10)).map(_mul_10).shuffle().map(_mul_10).map(_mul_10).fork(2)
graph = cdp1.zip(cdp2)

rng = torch.Generator()
rng.manual_seed(0)
torch.utils.data.graph_settings.apply_shuffle_seed(graph, rng)

# Get expected result
expected_res = list(graph)

rng.manual_seed(0)
torch.utils.data.graph_settings.apply_shuffle_seed(graph, rng)
it = iter(graph)
n_iter = 3
for _ in range(n_iter):
next(it)

# First serialization/deserialization
serialized_graph = pickle.dumps(graph)
deserialized_graph = pickle.loads(serialized_graph)

rng_for_deserialized = torch.Generator()
rng_for_deserialized.manual_seed(0)
_simple_graph_snapshot_restoration(deserialized_graph, deserialized_graph._number_of_samples_yielded,
rng=rng_for_deserialized)

it = iter(deserialized_graph)
# Get the next element and ensure it is as expected
self.assertEqual(expected_res[3], next(it))

        # Serialize/deserialize and fast-forward again to ensure it still works
serialized_graph2 = pickle.dumps(deserialized_graph)
deserialized_graph2 = pickle.loads(serialized_graph2)

rng_for_deserialized = torch.Generator()
rng_for_deserialized.manual_seed(0)
_simple_graph_snapshot_restoration(deserialized_graph2, deserialized_graph._number_of_samples_yielded,
rng=rng_for_deserialized)

        # Get the remaining elements and ensure they are as expected
self.assertEqual(expected_res[4:], list(deserialized_graph2))


if __name__ == '__main__':
run_tests()
30 changes: 29 additions & 1 deletion torch/utils/data/datapipes/_hook_iterator.py
@@ -1,8 +1,22 @@
import inspect
import functools
from enum import Enum

import torch.autograd


class _SnapshotState(Enum):
r"""
These are the snapshotting-related states that IterDataPipes can be in.
`NotStarted` - allows you to restore a snapshot and create an iterator without reset
`Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
`Iterating` - can restore, will reset if you create a new iterator
"""
NotStarted = 0
Restored = 1
Iterating = 2


def _simplify_obj_name(obj) -> str:
"""
Simplify the display strings of objects for the purpose of rendering within DataPipe error messages.
@@ -140,6 +154,15 @@ def __getattr__(self, name):
def wrap_generator(*args, **kwargs):
gen = func(*args, **kwargs)
datapipe = args[0]
if datapipe._fast_forward_iterator:
it = datapipe._fast_forward_iterator
datapipe._fast_forward_iterator = None
datapipe._snapshot_state = _SnapshotState.Iterating
while True:
try:
yield next(it)
except StopIteration:
return
iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator
_profiler_enabled = torch.autograd._profiler_enabled()
try:
@@ -161,7 +184,7 @@ def wrap_generator(*args, **kwargs):
_check_iterator_valid(datapipe, iterator_id)
response = gen.send(request)
except StopIteration as e:
return e.value
return
except Exception as e:
# TODO: Simplify the traceback message to skip over `response = gen.send(None)`
# Part of https://github.com/pytorch/data/issues/284
@@ -206,6 +229,11 @@ def wrap_next(*args, **kwargs):
def wrap_iter(*args, **kwargs):
iter_ret = func(*args, **kwargs)
datapipe = args[0]
datapipe._snapshot_state = _SnapshotState.Iterating
if datapipe._fast_forward_iterator:
iter_ret = datapipe._fast_forward_iterator
datapipe._fast_forward_iterator = None
return iter_ret
iterator_id = _set_datapipe_valid_iterator_id(datapipe) # This ID is tied to each created iterator
return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace)

28 changes: 8 additions & 20 deletions torch/utils/data/datapipes/_typing.py
@@ -5,7 +5,8 @@
import functools
import numbers
import sys
from torch.utils.data.datapipes._hook_iterator import hook_iterator

from torch.utils.data.datapipes._hook_iterator import hook_iterator, _SnapshotState
from typing import (Any, Dict, Iterator, Generic, List, Set, Tuple, TypeVar, Union,
get_type_hints)
from typing import _eval_type, _tp_cache, _type_check, _type_repr # type: ignore[attr-defined]
@@ -350,33 +351,20 @@ def __new__(cls, name, bases, namespace, **kwargs):
@functools.wraps(reset_func)
def conditional_reset(*args, **kwargs):
r"""
Only execute DataPipe's `reset()` method if `_restored` is False. This allows recently
Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating`. This allows recently
restored DataPipe to preserve its restored state during the initial `__iter__` call.
"""
datapipe = args[0]
if datapipe._restored is True:
datapipe._restored = False
else:
if datapipe._snapshot_state == _SnapshotState.Iterating:
# Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have
# already begun iterating.
datapipe._number_of_samples_yielded = 0
datapipe._fast_forward_iterator = None
reset_func(*args, **kwargs)
datapipe._snapshot_state = _SnapshotState.Iterating

namespace['reset'] = conditional_reset

if '__setstate__' in namespace:
setstate_func = namespace['__setstate__']

@functools.wraps(setstate_func)
def wrap_setstate(*args, **kwargs):
r"""
Set `_restored` to True during `__setstate__`, such that the next `reset()` call during
iterator creation will not actually reset the state of the DataPipe.
"""
datapipe = args[0]
datapipe._restored = True
return setstate_func(*args, **kwargs)

namespace['__setstate__'] = wrap_setstate

if '__iter__' in namespace:
hook_iterator(namespace, 'enumerate(DataPipe)#{}'.format(name))
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
6 changes: 4 additions & 2 deletions torch/utils/data/datapipes/datapipe.py
@@ -3,6 +3,7 @@
from typing import Dict, Callable, Optional, TypeVar, Generic, Iterator

from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes.utils.common import (
_deprecation_warning,
_iter_deprecated_functional_names,
@@ -111,7 +112,8 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
repr_hook: Optional[Callable] = None
_valid_iterator_id: Optional[int] = None
_number_of_samples_yielded: int = 0
_restored: bool = False
_snapshot_state: _SnapshotState = _SnapshotState.NotStarted
_fast_forward_iterator: Optional[Iterator] = None

def __getattr__(self, attribute_name):
if attribute_name in IterDataPipe.functions:
@@ -186,7 +188,7 @@ def __str__(self):
# Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
return str(self.__class__.__qualname__)

def reset(self):
def reset(self) -> None:
r"""
Reset the `IterDataPipe` to the initial state. By default, no-op. For subclasses of `IterDataPipe`,
depending on their functionalities, they may want to override this method with implementations that
4 changes: 3 additions & 1 deletion torch/utils/data/datapipes/datapipe.pyi.in
@@ -4,6 +4,7 @@
# classes/objects here, even though we are not injecting extra code into them at the moment.

from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar, Union
from torch.utils.data import Dataset, IterableDataset, default_collate

@@ -39,7 +40,8 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
str_hook: Optional[Callable] = ...
repr_hook: Optional[Callable] = ...
_number_of_samples_yielded: int = ...
_restored: bool = False
_snapshot_state: _SnapshotState = _SnapshotState.Iterating
_fast_forward_iterator: Optional[Iterator] = ...
def __getattr__(self, attribute_name: Any): ...
@classmethod
def register_function(cls, function_name: Any, function: Any) -> None: ...