
[BUG] Loading losses with modules that have no parameters #1593

Closed (pytorch/tensordict#650)
@matteobettini

Description


Reloading the state dict of a loss fails when the loss wraps a neural network that has no parameters:

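  import torch
  from torchrl.modules import QValueActor
  from torchrl.objectives import DQNLoss

  model = torch.nn.Tanh()  # no parameters: reloading fails
  # model = torch.nn.Linear(1, 1)  # has parameters: reloading works
  value = QValueActor(module=model, in_keys=["obs"], action_space="one_hot")
  loss = DQNLoss(value_network=value, action_space="one_hot")
  state = loss.state_dict()

  # Build a fresh loss and reload the saved state into it
  loss = DQNLoss(value_network=value, action_space="one_hot")
  loss.load_state_dict(state)

This raises: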
Traceback (most recent call last):
  File "/Users/matbet/PycharmProjects/rl/prova.py", line 16, in <module>
    loss.load_state_dict(state)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2027, in load_state_dict
    load(self, state_dict)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2015, in load
    load(child, child_state_dict, child_prefix)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2009, in load
    module._load_from_state_dict(
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/params.py", line 792, in _load_from_state_dict
    self.data.load_state_dict(data)
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/tensordict.py", line 834, in load_state_dict
    raise RuntimeError(
RuntimeError: Cannot load state-dict because the key sets don't match: got state_dict extra keys 
set()
 and tensordict extra keys
{'module'}
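One reading of this message is that the loss keeps its parameters in a nested tensordict that still exposes an empty `module` entry, while the saved state dict only records leaf tensors and so has no keys at all for a parameter-free network. The sketch below reproduces that key-set comparison with plain dicts, assuming this interpretation; `leaf_keys`, `container_keys` and `check_key_sets` are hypothetical helpers, not tensordict APIs.

```python
def leaf_keys(tree, prefix=""):
    """Dotted keys of leaf entries only, as a flat state_dict would store them."""
    keys = set()
    for name, value in tree.items():
        key = prefix + name
        if isinstance(value, dict):
            keys |= leaf_keys(value, key + ".")  # empty dicts add nothing
        else:
            keys.add(key)
    return keys


def container_keys(tree, prefix=""):
    """Dotted keys of leaves plus empty sub-containers, which stay visible."""
    keys = set()
    for name, value in tree.items():
        key = prefix + name
        if isinstance(value, dict) and value:
            keys |= container_keys(value, key + ".")
        else:
            keys.add(key)  # a leaf, or an empty container such as {"module": {}}
    return keys


def check_key_sets(params, state_dict_keys):
    """Hypothetical check that raises when the two key sets differ."""
    own = container_keys(params)
    if own != state_dict_keys:
        raise RuntimeError(
            "key sets don't match: got state_dict extra keys "
            f"{state_dict_keys - own} and tensordict extra keys {own - state_dict_keys}"
        )


# A parameter-free network (e.g. torch.nn.Tanh) leaves an empty "module" entry
tanh_params = {"module": {}}
# A network with parameters (e.g. torch.nn.Linear) has leaves under "module"
linear_params = {"module": {"weight": 0.0, "bias": 0.0}}

check_key_sets(linear_params, leaf_keys(linear_params))  # key sets match
try:
    check_key_sets(tanh_params, leaf_keys(tanh_params))
except RuntimeError as err:
    print(err)  # extra keys: set() vs {'module'}
```

The asymmetry is the point of the sketch: an empty nested container counts as a key on one side but contributes nothing once flattened to leaves.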

An example use case is the VDN module in MARL: it just sums its inputs, so it has no parameters and triggers this error in the QMixerLoss.
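To see why VDN hits this path, note that a mixer that only sums per-agent values owns no parameters, so its `state_dict()` is empty, just like `torch.nn.Tanh` above. A minimal sketch in plain PyTorch; `SumMixer` is a hypothetical stand-in and not torchrl's VDN implementation, and the `[batch, n_agents, 1]` input layout is an assumption for illustration.

```python
import torch
from torch import nn


class SumMixer(nn.Module):
    """Hypothetical VDN-style mixer: the global value is the sum over agents."""

    def forward(self, chosen_action_value: torch.Tensor) -> torch.Tensor:
        # Assumed layout: [batch, n_agents, 1] -> [batch, 1]
        return chosen_action_value.sum(dim=-2)


mixer = SumMixer()
print(len(list(mixer.parameters())))  # 0: nothing to learn
print(len(mixer.state_dict()))        # 0: nothing gets saved

q = torch.ones(4, 3, 1)  # 4 samples, 3 agents
print(mixer(q).shape)    # torch.Size([4, 1])

# For contrast, a module with parameters saves weight and bias
print(len(nn.Linear(1, 1).state_dict()))  # 2
```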

Labels: bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions