Skip to content

Commit 56bed0d

Browse files
rohan-varmapytorchmergebot
authored andcommitted
Load state dict post hook
Implements `register_load_state_dict_post_hook` API as discussed in pytorch#75287. Unittests cover: - Ensuring hooks are called with the correct module - Hook is called with `IncompatibleKeys` field - If hook modifies this, load_state_dict returns the modified result Pull Request resolved: pytorch#76823 Approved by: https://github.com/albanD
1 parent f84d4d9 commit 56bed0d

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed

test/test_nn.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21329,6 +21329,62 @@ def __init__(self, mod):
2132921329
m.load_state_dict(state_dict)
2133021330
self.assertEqual(2, hook_called)
2133121331

21332+
def test_load_state_dict_post_hook(self):
21333+
hook_called = 0
21334+
21335+
class MyModule(nn.Module):
21336+
def __init__(self):
21337+
super(MyModule, self).__init__()
21338+
self.foo = torch.nn.Parameter(torch.rand(10))
21339+
21340+
def my_post_load_hook(self, module, incompatible_keys):
21341+
assert module is self
21342+
nonlocal hook_called
21343+
incompatible_keys.missing_keys.append("foo")
21344+
incompatible_keys.unexpected_keys.append("bar")
21345+
hook_called += 1
21346+
21347+
nested = MyModule()
21348+
wrapped = nn.ModuleList([nested])
21349+
handle = nested.register_load_state_dict_post_hook(
21350+
nested.my_post_load_hook,
21351+
)
21352+
# Hook must be called even if it is wrapped
21353+
ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
21354+
self.assertEqual(hook_called, 1)
21355+
# Ensure that the hook modified missing_keys and unexpected_keys
21356+
missing = ret.missing_keys
21357+
unexpected = ret.unexpected_keys
21358+
self.assertEqual(missing, ["foo"])
21359+
self.assertEqual(unexpected, ["bar"])
21360+
# When called with strict=True, the error raised should mention the
21361+
# missing and unexpected keys the hook added.
21362+
with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"):
21363+
wrapped.load_state_dict(wrapped.state_dict(), strict=True)
21364+
self.assertEqual(hook_called, 2)
21365+
# Removing the hook via handle.remove() should cause it not to
21366+
# fire anymore.
21367+
handle.remove()
21368+
# Hook did not run so it should not have added any keys
21369+
ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
21370+
self.assertEqual(ret.missing_keys, [])
21371+
self.assertEqual(ret.unexpected_keys, [])
21372+
# hook_called should not have been incremented
21373+
self.assertEqual(hook_called, 2)
21374+
21375+
def load_hook_clear_incompatible(module, incompatible_keys):
21376+
incompatible_keys.missing_keys.clear()
21377+
incompatible_keys.unexpected_keys.clear()
21378+
21379+
nested.register_load_state_dict_post_hook(load_hook_clear_incompatible)
21380+
state_dict = wrapped.state_dict()
21381+
state_dict["extra"] = torch.ones(1)
21382+
# load state_dict with strict=True should not throw.
21383+
ret = wrapped.load_state_dict(state_dict, strict=True)
21384+
# explicitly ensure that the post hook clearned out incompatible_keys
21385+
self.assertEqual([], ret.missing_keys)
21386+
self.assertEqual([], ret.unexpected_keys)
21387+
2133221388

2133321389
instantiate_device_type_tests(TestNNDeviceType, globals())
2133421390
instantiate_parametrized_tests(TestNN)

torch/distributed/nn/api/remote_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"_forward_pre_hooks",
6666
"_state_dict_hooks",
6767
"_load_state_dict_pre_hooks",
68+
"_load_state_dict_post_hooks",
6869
"_modules",
6970
# The two attributes below are generated methods, not available at pickling time.
7071
"forward_async",

torch/nn/modules/module.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def __init__(self) -> None:
267267
self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
268268
self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
269269
self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
270+
self._load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict()
270271
self._modules: Dict[str, Optional['Module']] = OrderedDict()
271272

272273
forward: Callable[..., Any] = _forward_unimplemented
@@ -1183,6 +1184,8 @@ def __setstate__(self, state):
11831184
self._state_dict_hooks = OrderedDict()
11841185
if '_load_state_dict_pre_hooks' not in self.__dict__:
11851186
self._load_state_dict_pre_hooks = OrderedDict()
1187+
if '_load_state_dict_post_hooks' not in self.__dict__:
1188+
self._load_state_dict_post_hooks = OrderedDict()
11861189
if '_non_persistent_buffers_set' not in self.__dict__:
11871190
self._non_persistent_buffers_set = set()
11881191
if '_is_full_backward_hook' not in self.__dict__:
@@ -1403,6 +1406,37 @@ def _register_load_state_dict_pre_hook(self, hook, with_module=False):
14031406
self._load_state_dict_pre_hooks[handle.id] = hook
14041407
return handle
14051408

1409+
def register_load_state_dict_post_hook(self, hook):
1410+
r"""Registers a post hook to be run after module's ``load_state_dict``
1411+
is called.
1412+
1413+
It should have the following signature::
1414+
hook(module, incompatible_keys) -> None
1415+
1416+
The ``module`` argument is the current module that this hook is registered
1417+
on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
1418+
of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
1419+
is a ``list`` of ``str`` containing the missing keys and
1420+
``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
1421+
1422+
The given incompatible_keys can be modified inplace if needed.
1423+
1424+
Note that the checks performed when calling :func:`load_state_dict` with
1425+
``strict=True`` are affected by modifications the hook makes to
1426+
``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
1427+
set of keys will result in an error being thrown when ``strict=True``, and
1428+
clearning out both missing and unexpected keys will avoid an error.
1429+
1430+
Returns:
1431+
:class:`torch.utils.hooks.RemovableHandle`:
1432+
a handle that can be used to remove the added hook by calling
1433+
``handle.remove()``
1434+
"""
1435+
handle = hooks.RemovableHandle(self._load_state_dict_post_hooks)
1436+
self._load_state_dict_post_hooks[handle.id] = hook
1437+
return handle
1438+
1439+
14061440
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
14071441
missing_keys, unexpected_keys, error_msgs):
14081442
r"""Copies parameters and buffers from :attr:`state_dict` into only
@@ -1540,6 +1574,16 @@ def load(module, prefix=''):
15401574
if child is not None:
15411575
load(child, prefix + name + '.')
15421576

1577+
# Note that the hook can modify missing_keys and unexpected_keys.
1578+
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
1579+
for hook in module._load_state_dict_post_hooks.values():
1580+
out = hook(module, incompatible_keys)
1581+
assert out is None, (
1582+
"Hooks registered with ``register_load_state_dict_post_hook`` are not"
1583+
"expected to return new values, if incompatible_keys need to be modified,"
1584+
"it should be done inplace."
1585+
)
1586+
15431587
load(self)
15441588
del load
15451589

0 commit comments

Comments
 (0)