
load_old_state_dict of DiffusionModelUNet is wrong #8029

Closed
@yiheng-wang-nv

Describe the bug
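Some weights in an old MONAI Generative state dict are not handled when it is loaded into MONAI's `DiffusionModelUNet` with `load_old_state_dict`.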

To Reproduce
Using monai-generative==0.2.2 and the dev branch of MONAI to compare:

```python
# Old implementation (monai-generative==0.2.2)
from generative.networks.nets import DiffusionModelUNet
# New implementation (MONAI dev branch)
from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet as MonaiDiffusionModelUNet
import torch

input_params = {
    "spatial_dims": 2,
    "in_channels": 1,
    "out_channels": 1,
    "num_channels": [32, 64, 128, 256],
    "attention_levels": [False, True, True, True],
    "num_head_channels": [0, 32, 32, 32],
    "num_res_blocks": 2,
}

old_network = DiffusionModelUNet(**input_params)

# MONAI renames num_channels to channels and adds the
# include_fc / use_combined_linear attention options.
new_params = input_params.copy()
new_params.pop("num_channels")
new_params["channels"] = input_params["num_channels"]
new_params["include_fc"] = False
new_params["use_combined_linear"] = False

new_network = MonaiDiffusionModelUNet(**new_params)

# Print the query-projection weight keys of each network.
for k in new_network.state_dict().keys():
    if "to_q.weight" in k:
        print(k)

for k in old_network.state_dict().keys():
    if "to_q.weight" in k:
        print(k)
```
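To list every key the two networks disagree on (not only the `to_q` ones), the two key sets can be compared directly. A minimal sketch: `diff_state_dict_keys` is a hypothetical helper, not part of MONAI, and in practice its inputs would be `old_network.state_dict().keys()` and `new_network.state_dict().keys()`. The keys in the example are the ones reported in this issue plus one illustrative shared key.

```python
from typing import Iterable


def diff_state_dict_keys(
    old_keys: Iterable[str], new_keys: Iterable[str]
) -> tuple[list[str], list[str]]:
    """Return (keys only in the old dict, keys only in the new dict), sorted."""
    old_set, new_set = set(old_keys), set(new_keys)
    only_old = sorted(old_set - new_set)
    only_new = sorted(new_set - old_set)
    return only_old, only_new


# "some_shared_layer.weight" is illustrative: a key present in both layouts.
old = ["down_blocks.1.attentions.0.to_q.weight", "some_shared_layer.weight"]
new = ["down_blocks.1.attentions.0.attn.to_q.weight", "some_shared_layer.weight"]
only_old, only_new = diff_state_dict_keys(old, new)
print(only_old)  # ['down_blocks.1.attentions.0.to_q.weight']
print(only_new)  # ['down_blocks.1.attentions.0.attn.to_q.weight']
```

Any key printed on either side is one that a conversion function has to rename (or deliberately drop) for a strict load to succeed.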

The new network has keys like:

```
down_blocks.1.attentions.0.attn.to_q.weight
down_blocks.1.attentions.1.attn.to_q.weight
```

The old network has keys like:

```
down_blocks.1.attentions.0.to_q.weight
down_blocks.1.attentions.1.to_q.weight
```

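However, `load_old_state_dict` does not handle these keys: the old `attentions.<i>.to_q.weight` entries are not mapped to the new `attentions.<i>.attn.to_q.weight` names.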
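A minimal sketch of the missing renaming, assuming the only difference for these layers is the extra `.attn` segment seen in the keys above. `remap_attention_keys` is hypothetical and is not MONAI's `load_old_state_dict`; extending the rule to `to_k` and `to_v` is an assumption that those projections follow the same pattern as the `to_q` keys shown in this issue, and any other renamed parameters would need their own rules.

```python
import re
from typing import Any

# Hypothetical rule built from the key names reported in this issue:
#   "<prefix>.attentions.<i>.to_q.<param>" -> "<prefix>.attentions.<i>.attn.to_q.<param>"
# to_k / to_v are ASSUMED to follow the same pattern as to_q.
_OLD_ATTN_KEY = re.compile(
    r"^(?P<prefix>.*\.attentions\.\d+)\.(?P<proj>to_q|to_k|to_v)\.(?P<param>weight|bias)$"
)


def remap_attention_keys(old_state_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of old_state_dict with old-style attention keys renamed.

    Keys that do not match the old pattern (including keys already in the
    new ".attn." layout) are copied through unchanged.
    """
    remapped: dict[str, Any] = {}
    for key, value in old_state_dict.items():
        match = _OLD_ATTN_KEY.match(key)
        if match:
            key = f"{match['prefix']}.attn.{match['proj']}.{match['param']}"
        remapped[key] = value
    return remapped


# Values would be tensors in practice; strings keep the example self-contained.
old_state_dict = {
    "down_blocks.1.attentions.0.to_q.weight": "w0",
    "down_blocks.1.attentions.1.to_q.weight": "w1",
    "some_shared_layer.weight": "s",  # illustrative key present in both layouts
}
remapped = remap_attention_keys(old_state_dict)
print(sorted(remapped))
# ['down_blocks.1.attentions.0.attn.to_q.weight', 'down_blocks.1.attentions.1.attn.to_q.weight', 'some_shared_layer.weight']
```

The remapped dict could then be passed to `new_network.load_state_dict(...)` with the default `strict=True`, which raises if any other keys still do not line up, so unhandled renames surface as errors instead of being silently skipped.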
