forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Checkpoint][2D][4/N] Add nested_dict for distributed checkpoint to core distributed (pytorch#89537)
This PR moves nested_dict and its test to torch.distributed.checkpoint. This is a prerequisite for enabling 2D checkpoint. It provides the functionality to flatten a nested dict and to unflatten a flattened dict. Docstrings will be added in a following PR. Pull Request resolved: pytorch#89537 Approved by: https://github.com/fduwjj, https://github.com/wanchaol
- Loading branch information
1 parent
a378ba2
commit 23ee675
Showing
2 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Owner(s): ["oncall: distributed"] | ||
|
||
import torch | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
from torch.distributed.checkpoint.nested_dict import ( | ||
flatten_state_dict, | ||
unflatten_state_dict, | ||
) | ||
|
||
|
||
class TestFlattening(TestCase):
    """Checks for ``flatten_state_dict`` and ``unflatten_state_dict``."""

    def test_flattening_round_trip(self) -> None:
        """Flattening and then unflattening must give back an equal state_dict."""
        original = {
            "key0": 1,
            "key1": [1, 2],
            "key2": {1: 2, 2: 3},
            "key3": torch.tensor([1]),
            "key4": [[torch.tensor(2), "x"], [1, 2, 3], {"key6": [44]}],
        }

        flat, paths = flatten_state_dict(original)
        rebuilt = unflatten_state_dict(flat, paths)

        self.assertEqual(original, rebuilt)

    def test_mapping(self) -> None:
        """The mapping returned by flattening must contain the expected object paths."""
        nested = {
            "k0": [1],
            "k2": [torch.tensor([1]), 99, [{"k3": torch.tensor(1)}]],
            "k3": ["x", 99, [{"k3": "y"}]],
        }
        # Checked in this order; the first missing path fails the test.
        expected_paths = [
            ("k0",),
            ("k2", 0),
            ("k2", 1),
            ("k2", 2, 0, "k3"),
            ("k3",),
        ]

        mapping = flatten_state_dict(nested)[1]
        for obj_path in expected_paths:
            self.assertIn(obj_path, mapping.values())
|
||
|
||
# Script entry point: hand control to run_tests() when this module is executed directly. | ||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates | ||
from typing import Dict, Tuple | ||
|
||
from torch.distributed.checkpoint.metadata import ( | ||
STATE_DICT_TYPE, | ||
) | ||
|
||
from .traverse import ( | ||
traverse_state_dict, | ||
set_element, | ||
OBJ_PATH, | ||
STATE_DICT_ITEM, | ||
) | ||
|
||
""" | ||
TODO: | ||
Need to add ability to handle tuple, OrderedDict, NamedTuple. | ||
Update mappings from dict to a class. | ||
Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple. | ||
""" | ||
|
||
|
||
# Type alias: maps each flattened (string) key to the object path it was derived from. | ||
FLATTEN_MAPPING = Dict[str, OBJ_PATH] | ||
|
||
|
||
# TODO: Update Docstring for nested_dict.py | ||
def flatten_state_dict(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
    """
    Collapse a ``state_dict`` made of nested dicts and lists into a single-level dict.

    Each ``(path, value)`` pair that ``traverse_state_dict`` hands to the visitor is
    stored under a key built by converting every component of the object path to
    ``str`` and joining them with dots. For example, ``{'a': {'b': ...}}`` results
    in the key ``a.b``.

    Returns:
        A tuple ``(flattened, mapping)``: the flattened state_dict, and a dict from
        each flattened key to the object path it was derived from. Pass both to
        ``unflatten_state_dict`` to revert the process.

    Raises:
        ValueError: if two object paths produce the same flattened key.
    """
    flat_values: STATE_DICT_TYPE = {}
    key_to_path: FLATTEN_MAPPING = {}

    def _record(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        # Path components may be non-strings (e.g. list indices), hence the str().
        new_fqn = ".".join(str(part) for part in path)
        if new_fqn not in flat_values:
            flat_values[new_fqn] = value
            key_to_path[new_fqn] = path
            return
        # Distinct paths can stringify identically; refuse to overwrite silently.
        raise ValueError(f"duplicated flatten key {new_fqn}")

    traverse_state_dict(state_dict, _record)
    return flat_values, key_to_path
|
||
|
||
def unflatten_state_dict(
    state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
) -> STATE_DICT_TYPE:
    """
    Rebuild a nested state_dict from a flattened ``state_dict``.

    For every entry, the flattened key is looked up in ``mapping`` to obtain its
    object path, and ``set_element`` places the value at that path inside a fresh
    dict. The lookup is a plain ``mapping[key]`` subscript, so every key of
    ``state_dict`` is expected to be present in ``mapping``.

    Returns:
        The newly built nested state_dict.
    """
    rebuilt: STATE_DICT_TYPE = {}
    for flat_key in state_dict:
        set_element(rebuilt, mapping[flat_key], state_dict[flat_key])
    return rebuilt