Describe the bug
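Some attention weights are not handled by `load_old_state_dict` when converting a monai-generative `DiffusionModelUNet` to MONAI's `DiffusionModelUNet`.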
To Reproduce
Using monai-generative==0.2.2 and the dev branch of MONAI for comparison:
```python
from generative.networks.nets import DiffusionModelUNet
from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet as MonaiDiffusionModelUNet
import torch

input_params = {
    "spatial_dims": 2,
    "in_channels": 1,
    "out_channels": 1,
    "num_channels": [32, 64, 128, 256],
    "attention_levels": [False, True, True, True],
    "num_head_channels": [0, 32, 32, 32],
    "num_res_blocks": 2,
}

# Network from monai-generative 0.2.2
old_network = DiffusionModelUNet(**input_params)

# Same configuration for MONAI, which uses "channels" instead of "num_channels"
new_params = input_params.copy()
new_params.pop("num_channels")
new_params["channels"] = input_params["num_channels"]
new_params["include_fc"] = False
new_params["use_combined_linear"] = False
new_network = MonaiDiffusionModelUNet(**new_params)

# Print the query-projection weight keys of each network
for k in new_network.state_dict().keys():
    if "to_q.weight" in k:
        print(k)
for k in old_network.state_dict().keys():
    if "to_q.weight" in k:
        print(k)
```
The new network has keys such as:
down_blocks.1.attentions.0.attn.to_q.weight
down_blocks.1.attentions.1.attn.to_q.weight
The old network has keys such as:
down_blocks.1.attentions.0.to_q.weight
down_blocks.1.attentions.1.to_q.weight
However, the `load_old_state_dict` function does not handle these layers: the old `attentions.N.to_q` keys are not renamed to the new `attentions.N.attn.to_q` form.
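As a possible stopgap until `load_old_state_dict` covers these layers, the keys could be renamed before loading. The sketch below is hypothetical: the helper name `remap_attention_keys` and its regex are not part of MONAI or monai-generative. It assumes, from the key listings above, that the only difference is an extra `.attn.` segment inserted after `attentions.N`; treating `to_k` and `to_v` the same way as `to_q` is an extrapolation, since only `to_q` keys were printed.

```python
import re
from typing import Dict, TypeVar

V = TypeVar("V")

# Hypothetical pattern for old-style keys such as
# "down_blocks.1.attentions.0.to_q.weight" (no ".attn." segment).
# Assumption: to_k/to_v follow the same naming as the printed to_q keys.
_OLD_QKV = re.compile(r"(attentions\.\d+)\.(to_[qkv])\.")


def remap_attention_keys(state_dict: Dict[str, V]) -> Dict[str, V]:
    """Return a copy of state_dict with old to_q/to_k/to_v keys renamed
    to the ".attn."-prefixed form seen in the new network.

    Assumption: only these three projections are handled; any other
    renamed parameters are left unchanged.
    """
    remapped: Dict[str, V] = {}
    for key, value in state_dict.items():
        # Keys that already contain ".attn." (or are unrelated) do not match
        # the pattern, so they pass through unchanged.
        new_key = _OLD_QKV.sub(r"\1.attn.\2.", key)
        remapped[new_key] = value
    return remapped


if __name__ == "__main__":
    # Illustrative keys; integers stand in for tensors.
    old = {
        "down_blocks.1.attentions.0.to_q.weight": 0,
        "down_blocks.1.attentions.1.to_k.bias": 1,
        "down_blocks.1.attentions.0.attn.to_v.weight": 2,  # already new-style
        "conv_in.conv.weight": 3,  # unrelated key, left untouched
    }
    for k in remap_attention_keys(old):
        print(k)
```

Because it works on a plain dict, the same helper could in principle be applied to `old_network.state_dict()`. It does not replace `load_old_state_dict`: any other parameters that were renamed between the two packages would still need to be mapped separately.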