forked from pytorch/pytorch
[DataPipe] Simple graph snapshotting (pytorch#79479)
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting, but it should work for all DataPipes. I will be adding more efficient implementations for different types of DataPipes in future PRs.

### Implementation

The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`

### Usage

As of this implementation, the usage will be something like:

```python
import pickle

import torch
from torch.utils.data import IterDataPipe

rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
    next(it)
serialized_graph = pickle.dumps(datapipe)

# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)

# The whole DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph))  # True
```

### Next Steps

If this looks acceptable, the next step is that I will modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to have a `save_snapshot` method that returns `(serialized graph, initial_rng)` and a matching `restore_snapshot` method. This should work for single-worker data loading.

Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and simple snapshot are still a good fall-back option for some edge cases where the buffer can't be stored.

Differential Revision: [D37943406](https://our.internmc.facebook.com/intern/diff/D37943406)
Pull Request resolved: pytorch#79479
Approved by: https://github.com/ejguan
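
For readers who want to see the four steps under "Implementation" in isolation, here is a minimal sketch that does not use PyTorch at all. `ToyShufflePipe`, `simple_snapshot_restore`, and `n_yielded` are hypothetical names made up for this illustration, and `random.Random` stands in for `torch.Generator`; this is not the code added by the commit, only a toy model of the same idea under those assumptions.

```python
import pickle
import random


class ToyShufflePipe:
    """Hypothetical stand-in for a shuffling DataPipe (not a PyTorch class)."""

    def __init__(self, data, seed):
        self.data = list(data)
        self.rng = random.Random(seed)       # stand-in for torch.Generator
        self.n_yielded = 0                   # plays the role of `_number_of_samples_yielded`
        self._fast_forward_iterator = None

    def _fresh_iter(self):
        # Shuffle using the current RNG state, then yield while counting samples.
        order = list(self.data)
        self.rng.shuffle(order)
        self.n_yielded = 0
        for item in order:
            self.n_yielded += 1
            yield item

    def __iter__(self):
        # Step 4: if a fast-forwarded iterator was stashed, hand it out once.
        if self._fast_forward_iterator is not None:
            it, self._fast_forward_iterator = self._fast_forward_iterator, None
            return it
        return self._fresh_iter()

    def __getstate__(self):
        # Generators cannot be pickled, so the stashed iterator is dropped.
        state = self.__dict__.copy()
        state["_fast_forward_iterator"] = None
        return state


def simple_snapshot_restore(pipe, n_iterations, initial_rng_state):
    """Toy version of the restore step: replay from the initial RNG state."""
    pipe.rng.setstate(initial_rng_state)     # rewind randomness to the start
    it = iter(pipe)                          # step 1: create a new iterator
    for _ in range(n_iterations):            # step 2: move it forward
        next(it)
    pipe._fast_forward_iterator = it         # step 3: save it on the pipe


pipe = ToyShufflePipe(range(10), seed=0)
initial_rng_state = pipe.rng.getstate()      # must be captured before iterating
it = iter(pipe)
first = [next(it) for _ in range(4)]

restored = pickle.loads(pickle.dumps(pipe))
simple_snapshot_restore(restored, restored.n_yielded, initial_rng_state)

remaining_original = list(it)
remaining_restored = list(restored)
assert remaining_original == remaining_restored
```

The sketch also shows why the initial RNG state has to travel alongside the pickled object: the pickled RNG has already advanced past the shuffle, so replaying from it would produce a different order. The cost of this approach is that restoring re-does all the work for the first `n_iterations` samples, which is presumably why the message calls it the most basic version.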
1 parent cb63ffc, commit 35d97e2
Showing 9 changed files with 345 additions and 24 deletions.