Skip to content

add automapping function in load_pretrain to fix load weight erorr from mindcv when the feature encoder is unfolded to extract intermediate features #246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion mindocr/models/backbones/mindcv_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Some utils while building models
"""
import collections.abc
import difflib
import logging
import os
from itertools import repeat
Expand All @@ -24,7 +25,7 @@ class ConfigDict(dict):
__delattr__ = dict.__delitem__


def load_pretrained(model, default_cfg, num_classes=1000, in_channels=3, filter_fn=None):
def load_pretrained(model, default_cfg, num_classes=1000, in_channels=3, filter_fn=None, auto_mapping=False):
"""load pretrained model depending on cfgs of model"""
if "url" not in default_cfg or not default_cfg["url"]:
logging.warning("Pretrained model URL is invalid")
Expand All @@ -40,6 +41,27 @@ def load_pretrained(model, default_cfg, num_classes=1000, in_channels=3, filter_
except:
print(f'ERROR: Fails to load the checkpoint. Please check whether the checkpoint is downloaded successfully in {download_path} and is not zero-byte. You may try to manually download the checkpoint from ', default_cfg["url"])

if auto_mapping:
net_param = model.get_parameters()
ckpt_param = list(param_dict.keys())
remap = {}
for param in net_param:
if param.name not in ckpt_param:
print('Cannot find a param to load: ', param.name)
poss = difflib.get_close_matches(param.name, ckpt_param, n=3, cutoff=0.6)
if len(poss) > 0:
print('=> Find most matched param: ', poss[0], ', loaded')
param_dict[param.name] = param_dict.pop(poss[0]) # replace
remap[param.name] = poss[0]
else:
raise ValueError('Cannot find any matching param from: ', ckpt_param)

if remap != {}:
print('WARNING: Auto mapping succeed. Please check the found mapping names to ensure correctness')
print('\tNet Param\t<---\tCkpt Param')
for k in remap:
print(f'\t{k}\t<---\t{remap[k]}')

if in_channels == 1:
conv1_name = default_cfg["first_conv"]
logging.info("Converting first conv (%s) from 3 to 1 channel", conv1_name)
Expand Down