Closed
Description
In load_state function, the order of missing and unexpected is not consistent with load_state method
def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs):
"""Copy parameters and buffers from :attr:`state_dict` into
this module and its descendants.
Args:
target: DynamicalSystem. The dynamical system to load its states.
state_dict: dict. A dict containing parameters and persistent buffers.
Returns:
-------
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
"""
nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique()
missing_keys = []
unexpected_keys = []
for name, node in nodes.items():
r = node.load_state(state_dict[name], **kwargs)
if r is not None:
missing, unexpected = r
missing_keys.extend([f'{name}.{key}' for key in missing])
unexpected_keys.extend([f'{name}.{key}' for key in unexpected])
return StateLoadResult(missing_keys, unexpected_keys)
in BrainPyObject
def __load_state__(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
"""Load states from the external objects."""
variables = self.vars(include_self=True, level=0).unique()
keys1 = set(state_dict.keys())
keys2 = set(variables.keys())
for key in keys2.intersection(keys1):
variables[key].value = jax.numpy.asarray(state_dict[key])
unexpected_keys = list(keys1 - keys2)
missing_keys = list(keys2 - keys1)
return unexpected_keys, missing_keys