@@ -267,6 +267,7 @@ def __init__(self) -> None:
267
267
self ._forward_pre_hooks : Dict [int , Callable ] = OrderedDict ()
268
268
self ._state_dict_hooks : Dict [int , Callable ] = OrderedDict ()
269
269
self ._load_state_dict_pre_hooks : Dict [int , Callable ] = OrderedDict ()
270
+ self ._load_state_dict_post_hooks : Dict [int , Callable ] = OrderedDict ()
270
271
self ._modules : Dict [str , Optional ['Module' ]] = OrderedDict ()
271
272
272
273
forward : Callable [..., Any ] = _forward_unimplemented
@@ -1183,6 +1184,8 @@ def __setstate__(self, state):
1183
1184
self ._state_dict_hooks = OrderedDict ()
1184
1185
if '_load_state_dict_pre_hooks' not in self .__dict__ :
1185
1186
self ._load_state_dict_pre_hooks = OrderedDict ()
1187
+ if '_load_state_dict_post_hooks' not in self .__dict__ :
1188
+ self ._load_state_dict_post_hooks = OrderedDict ()
1186
1189
if '_non_persistent_buffers_set' not in self .__dict__ :
1187
1190
self ._non_persistent_buffers_set = set ()
1188
1191
if '_is_full_backward_hook' not in self .__dict__ :
@@ -1403,6 +1406,37 @@ def _register_load_state_dict_pre_hook(self, hook, with_module=False):
1403
1406
self ._load_state_dict_pre_hooks [handle .id ] = hook
1404
1407
return handle
1405
1408
1409
+ def register_load_state_dict_post_hook (self , hook ):
1410
+ r"""Registers a post hook to be run after module's ``load_state_dict``
1411
+ is called.
1412
+
1413
+ It should have the following signature::
1414
+ hook(module, incompatible_keys) -> None
1415
+
1416
+ The ``module`` argument is the current module that this hook is registered
1417
+ on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
1418
+ of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
1419
+ is a ``list`` of ``str`` containing the missing keys and
1420
+ ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
1421
+
1422
+ The given incompatible_keys can be modified inplace if needed.
1423
+
1424
+ Note that the checks performed when calling :func:`load_state_dict` with
1425
+ ``strict=True`` are affected by modifications the hook makes to
1426
+ ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
1427
+ set of keys will result in an error being thrown when ``strict=True``, and
1428
+ clearning out both missing and unexpected keys will avoid an error.
1429
+
1430
+ Returns:
1431
+ :class:`torch.utils.hooks.RemovableHandle`:
1432
+ a handle that can be used to remove the added hook by calling
1433
+ ``handle.remove()``
1434
+ """
1435
+ handle = hooks .RemovableHandle (self ._load_state_dict_post_hooks )
1436
+ self ._load_state_dict_post_hooks [handle .id ] = hook
1437
+ return handle
1438
+
1439
+
1406
1440
def _load_from_state_dict (self , state_dict , prefix , local_metadata , strict ,
1407
1441
missing_keys , unexpected_keys , error_msgs ):
1408
1442
r"""Copies parameters and buffers from :attr:`state_dict` into only
@@ -1540,6 +1574,16 @@ def load(module, prefix=''):
1540
1574
if child is not None :
1541
1575
load (child , prefix + name + '.' )
1542
1576
1577
+ # Note that the hook can modify missing_keys and unexpected_keys.
1578
+ incompatible_keys = _IncompatibleKeys (missing_keys , unexpected_keys )
1579
+ for hook in module ._load_state_dict_post_hooks .values ():
1580
+ out = hook (module , incompatible_keys )
1581
+ assert out is None , (
1582
+ "Hooks registered with ``register_load_state_dict_post_hook`` are not"
1583
+ "expected to return new values, if incompatible_keys need to be modified,"
1584
+ "it should be done inplace."
1585
+ )
1586
+
1543
1587
load (self )
1544
1588
del load
1545
1589
0 commit comments