[Checkpoint][2D][4/N] Add nested_dict for distributed checkpoint to core distributed (pytorch#89537)

This PR moves nested_dict and its test to torch.distributed.checkpoint. This is a pre-req for enabling 2D checkpoint.

This provides the functionality to flatten a nested dict and unflatten a flattened dict.
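As a sketch of the intended usage (the exact flattening granularity depends on the traversal rules in ``.traverse``; the keys and values shown in the comments are illustrative, inferred from the tests below):

import torch
from torch.distributed.checkpoint.nested_dict import (
    flatten_state_dict,
    unflatten_state_dict,
)

# A nested state_dict with a tensor leaf.
state_dict = {"model": {"weight": torch.tensor([1.0, 2.0])}, "step": 10}

# Flatten to a single-level dict keyed by dot-joined object paths,
# e.g. {"model.weight": ..., "step": ...}; `mapping` records each
# new key's original path, e.g. {"model.weight": ("model", "weight"), ...}.
flat, mapping = flatten_state_dict(state_dict)

# Unflattening restores the original nesting.
restored = unflatten_state_dict(flat, mapping)
assert torch.equal(restored["model"]["weight"], state_dict["model"]["weight"])
assert restored["step"] == 10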

Docstring will be added in the following PR.
Pull Request resolved: pytorch#89537
Approved by: https://github.com/fduwjj, https://github.com/wanchaol
wz337 authored and pytorchmergebot committed Nov 28, 2022
1 parent a378ba2 commit 23ee675
Showing 2 changed files with 103 additions and 0 deletions.
42 changes: 42 additions & 0 deletions test/distributed/checkpoint/test_nested_dict.py
@@ -0,0 +1,42 @@
# Owner(s): ["oncall: distributed"]

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.distributed.checkpoint.nested_dict import (
    flatten_state_dict,
    unflatten_state_dict,
)


class TestFlattening(TestCase):
    def test_flattening_round_trip(self) -> None:
        state_dict = {
            "key0": 1,
            "key1": [1, 2],
            "key2": {1: 2, 2: 3},
            "key3": torch.tensor([1]),
            "key4": [[torch.tensor(2), "x"], [1, 2, 3], {"key6": [44]}],
        }

        flatten_dict, mapping = flatten_state_dict(state_dict)
        restored = unflatten_state_dict(flatten_dict, mapping)

        self.assertEqual(state_dict, restored)

    def test_mapping(self) -> None:
        state_dict = {
            "k0": [1],
            "k2": [torch.tensor([1]), 99, [{"k3": torch.tensor(1)}]],
            "k3": ["x", 99, [{"k3": "y"}]],
        }

        _, mapping = flatten_state_dict(state_dict)
        self.assertIn(("k0",), mapping.values())
        self.assertIn(("k2", 0), mapping.values())
        self.assertIn(("k2", 1), mapping.values())
        self.assertIn(("k2", 2, 0, "k3"), mapping.values())
        self.assertIn(("k3",), mapping.values())


if __name__ == "__main__":
    run_tests()
61 changes: 61 additions & 0 deletions torch/distributed/checkpoint/nested_dict.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Tuple

from torch.distributed.checkpoint.metadata import (
    STATE_DICT_TYPE,
)

from .traverse import (
    traverse_state_dict,
    set_element,
    OBJ_PATH,
    STATE_DICT_ITEM,
)

"""
TODO:
Need to add ability to handle tuple, OrderedDict, NamedTuple.
Update mappings from dict to a class.
Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple.
"""


FLATTEN_MAPPING = Dict[str, OBJ_PATH]


# TODO: Update Docstring for nested_dict.py
def flatten_state_dict(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
    """
    Flatten a ``state_dict`` made of nested dicts and lists into a top-level dictionary.
    Use ``unflatten_state_dict`` to revert this process.
    Returns:
        A tuple with the flattened state_dict and a mapping from the new flattened
        keys back to the original object paths.
    N.B. The new keys are derived from the object paths, joined by dot.
        For example: ``{'a': {'b': ...}}`` results in the key ``a.b``.
    """
    flattened: STATE_DICT_TYPE = {}
    mappings: FLATTEN_MAPPING = {}

    def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        new_fqn = ".".join(map(str, path))
        if new_fqn in flattened:
            raise ValueError(f"duplicated flatten key {new_fqn}")
        flattened[new_fqn] = value
        mappings[new_fqn] = path

    traverse_state_dict(state_dict, flat_copy)
    return flattened, mappings


def unflatten_state_dict(
    state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
) -> STATE_DICT_TYPE:
    """
    Restore the original nested state_dict from the flattened ``state_dict``,
    using ``mapping`` to recover each value's original object path.
    """
    nested: STATE_DICT_TYPE = {}
    for key, value in state_dict.items():
        set_element(nested, mapping[key], value)
    return nested
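
Because the flattened keys are dot-joined paths, two distinct object paths can collide on the same flat key, which ``flat_copy`` rejects. A minimal sketch of that edge case, assuming (per the tests above) that containers holding tensors are traversed element-wise:

import torch
from torch.distributed.checkpoint.nested_dict import flatten_state_dict

# The nested path ("a", "b") and the literal key "a.b" both flatten to "a.b".
state_dict = {"a": {"b": torch.tensor([1])}, "a.b": torch.tensor([2])}
try:
    flatten_state_dict(state_dict)
except ValueError as e:
    print(e)  # duplicated flatten key a.b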
