Description
Is your feature request related to a problem? Please describe.
The `pretrained` flag of the `resnet50` function currently accepts only a boolean. If it is set to `True`, the function raises an error pointing the user to the MedicalNet weights. The user then has to download these weights and load them into the model manually via its `state_dict`.
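That manual step usually means unwrapping the checkpoint and renaming its keys before calling `load_state_dict`. The sketch below is a hedged illustration, not MONAI API: `extract_state_dict` is a hypothetical helper, and it assumes the checkpoint is a plain dict that may wrap the weights under a `"state_dict"` key and whose keys may carry the `module.` prefix that `torch.nn.DataParallel` adds. Placeholder integers stand in for tensors so it runs without PyTorch.

```python
from typing import Any, Dict, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Hypothetical helper: unwrap a checkpoint and strip a leading key prefix.

    Checkpoints saved as {"state_dict": {...}} are unwrapped; a bare state dict
    is used as-is. Only a *leading* `prefix` is removed from each key.
    """
    state_dict = checkpoint.get("state_dict", checkpoint)
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Usage with placeholder values standing in for tensors
checkpoint = {"state_dict": {"module.conv1.weight": 1, "module.bn1.bias": 2}}
print(extract_state_dict(checkpoint))
# {'conv1.weight': 1, 'bn1.bias': 2}
```

Checking `startswith` rather than calling `str.replace` leaves a key that merely contains `module.` somewhere in the middle untouched.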
Describe the solution you'd like
This process could be simplified considerably by also allowing the `pretrained` flag to take a `str` path and automatically loading the `state_dict` from that path:
```python
import logging
from pathlib import Path

import torch
from monai.networks.nets import resnet50 as resnet50_monai

logger = logging.getLogger(__name__)

model = resnet50_monai(pretrained=False, n_input_channels=n_input_channels, widen_factor=widen_factor, conv1_t_stride=conv1_t_stride, feed_forward=feed_forward, bias_downsample=bias_downsample)
model = model.to(device)
if isinstance(pretrained, (str, Path)):  # a path was passed instead of a bool
    if not Path(pretrained).exists():
        raise FileNotFoundError(f"Pretrained weights not found at {pretrained}")
    logger.info(f"Loading weights from {pretrained}...")
    checkpoint = torch.load(pretrained, map_location=device)
    # Unwrap {"state_dict": ...} checkpoints (as MedicalNet's are); otherwise use the dict as-is
    model_state_dict = checkpoint.get("state_dict", checkpoint)
    # Drop the "module." prefix left by nn.DataParallel
    model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
    model.load_state_dict(model_state_dict, strict=True)
```

This would make the MedicalNet weights easier to load and could also open up this API to loading other models trained with MONAI. If we don't want to rely on duck typing, we could instead add a separate flag to the function.
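One way to support the duck typing described here is to normalise `pretrained` in one place before the model is built. A minimal sketch, assuming a hypothetical helper `resolve_pretrained` (not part of MONAI): `False` means no pretrained weights, a `str` or `os.PathLike` must point to an existing checkpoint file, and `True` keeps the current behaviour, represented here by a `NotImplementedError` with an illustrative message.

```python
import os
from pathlib import Path
from typing import Optional, Union


def resolve_pretrained(pretrained: Union[bool, str, os.PathLike]) -> Optional[Path]:
    """Hypothetical helper: map the `pretrained` argument to a checkpoint path.

    False          -> None (no pretrained weights)
    str / PathLike -> Path to an existing checkpoint file
    True           -> current behaviour (stand-in error below)
    """
    # Handle bool explicitly so True/False keep their current meaning
    if isinstance(pretrained, bool):
        if pretrained:
            # Placeholder for the existing "download MedicalNet weights" error
            raise NotImplementedError("Download the MedicalNet weights and pass their path as `pretrained`.")
        return None
    if isinstance(pretrained, (str, os.PathLike)):
        path = Path(pretrained)
        if not path.is_file():
            raise FileNotFoundError(f"Pretrained weights not found at {path}")
        return path
    raise TypeError(f"`pretrained` must be a bool or a path, got {type(pretrained).__name__}")


print(resolve_pretrained(False))  # None: build the model without pretrained weights
```

The returned path could then be passed to `torch.load`. With a separate flag instead of duck typing, the same path checks would apply to that new argument alone, and `pretrained` could stay a plain boolean.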
Additional context
I understand that MONAI is moving in a different direction (loading bundles, etc.), but since the `pretrained` flag still exists, it might be useful to simplify it.