Closed
Description
When a loss module wraps a neural network that has no parameters (such as `torch.nn.Tanh()`), reloading the loss from its own state dict fails.
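```python
import torch
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

model = torch.nn.Tanh()  # no parameters: reloading fails
# model = torch.nn.Linear(1, 1)  # has parameters: reloading works
value = QValueActor(module=model, in_keys=["obs"], action_space="one_hot")
loss = DQNLoss(value_network=value, action_space="one_hot")
state = loss.state_dict()

# Build a fresh loss around the same network and load the saved state
loss = DQNLoss(value_network=value, action_space="one_hot")
loss.load_state_dict(state)
```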
```
Traceback (most recent call last):
  File "/Users/matbet/PycharmProjects/rl/prova.py", line 16, in <module>
    loss.load_state_dict(state)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2027, in load_state_dict
    load(self, state_dict)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2015, in load
    load(child, child_state_dict, child_prefix)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2009, in load
    module._load_from_state_dict(
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/params.py", line 792, in _load_from_state_dict
    self.data.load_state_dict(data)
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/tensordict.py", line 834, in load_state_dict
    raise RuntimeError(
RuntimeError: Cannot load state-dict because the key sets don't match: got state_dict extra keys
set()
and tensordict extra keys
{'module'}
```
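The error reports an extra `'module'` key on the tensordict side and no keys at all on the state-dict side. One plausible reading, which is an assumption about tensordict internals and not something confirmed in this issue, is that the loss keeps the network's parameters in a nested container where a parameter-free network still leaves an empty `module` entry, while the state dict is built only from leaf tensors, so that empty entry does not survive the round trip. The sketch below models this with plain Python dicts; `flatten`, `unflatten` and `check_keys` are hypothetical helpers, not tensordict functions.

```python
# Hypothetical, simplified model of the failure (not tensordict's code):
# parameters live in nested dicts and are serialized as flat "a.b" keys.

def flatten(nested, prefix=""):
    """Flatten nested dicts into {"a.b": leaf}; empty sub-dicts vanish."""
    flat = {}
    for key, value in nested.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            # An empty dict contributes no leaves, so its key is lost here.
            flat.update(flatten(value, prefix=full_key + "."))
        else:
            flat[full_key] = value
    return flat


def unflatten(flat):
    """Rebuild nested dicts from flat "a.b" keys."""
    nested = {}
    for full_key, value in flat.items():
        *parents, leaf = full_key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def check_keys(expected, loaded):
    """Strict top-level key comparison, mimicking the message above."""
    extra_loaded = set(loaded) - set(expected)
    extra_expected = set(expected) - set(loaded)
    if extra_loaded or extra_expected:
        raise RuntimeError(
            "key sets don't match: got state_dict extra keys "
            f"{extra_loaded} and tensordict extra keys {extra_expected}"
        )


# Linear-like network: "module" holds leaves, so it survives the round trip.
linear_params = {"module": {"weight": [[0.5]], "bias": [0.1]}}
check_keys(linear_params, unflatten(flatten(linear_params)))  # passes

# Tanh-like network: "module" exists but holds no parameters.
tanh_params = {"module": {}}
try:
    check_keys(tanh_params, unflatten(flatten(tanh_params)))
except RuntimeError as err:
    # prints: key sets don't match: got state_dict extra keys set() and tensordict extra keys {'module'}
    print(err)
```

Under this model, treating empty nested entries as absent when comparing key sets would make the check tolerant of parameter-free networks; whether that is the right fix for tensordict is not established here.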
An example use case is the VDN module in MARL: it has no parameters (it just sums its inputs), so it triggers this error in `QMixerLoss`.
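For reference, VDN (value decomposition networks) combines the per-agent action values additively, which is why the mixer has nothing to learn. With $n$ agents, where $Q_i$ is agent $i$'s value for its local observation $o_i$ and action $a_i$ (the notation here is illustrative):

```latex
Q_{\text{tot}}(o_1, \dots, o_n, a_1, \dots, a_n) = \sum_{i=1}^{n} Q_i(o_i, a_i)
```

Any gradient flows only into the agents' networks, so a loss built on this mixer holds a mixing module with an empty parameter set.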