Simplify existing resnet API for pretrained flag #7047

@surajpaib

Is your feature request related to a problem? Please describe.
The pretrained flag in the resnet50 function currently only accepts a boolean value. If it's set to True, the function throws an error pointing the user to download MedicalNet weights. Then, the user has to download these and manually load the state_dict into the model.

Describe the solution you'd like
This process could be greatly simplified by allowing the pretrained flag to also accept a str value that points to a weights file, and automatically loading the state_dict from that path.

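import logging
from pathlib import Path

import torch

logger = logging.getLogger(__name__)

# Build the network without pretrained weights first.
model = resnet50_monai(pretrained=False, n_input_channels=n_input_channels, widen_factor=widen_factor, conv1_t_stride=conv1_t_stride, feed_forward=feed_forward, bias_downsample=bias_downsample)
model = model.to(device)
if pretrained:
    # A truthy `pretrained` is treated as a path to a checkpoint file.
    if not Path(pretrained).exists():
        raise FileNotFoundError(f"Pretrained weights not found at {pretrained}")
    logger.info(f"Loading weights from {pretrained}...")
    checkpoint = torch.load(pretrained, map_location=device)

    # Unwrap the "state_dict" entry if present; otherwise assume the checkpoint is the state_dict itself.
    model_state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    # Drop the "module." prefix that DataParallel adds to parameter names.
    model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
    model.load_state_dict(model_state_dict, strict=True)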

This would make loading the MedicalNet weights easier and could open up this API to loading other models that have been trained with MONAI as well. If we don't want the pretrained flag to accept both bool and str values (duck typing), we could add a separate flag to the function instead.
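To make the bool-or-path behaviour concrete, here is a minimal sketch of how the flag could be interpreted. The helper names `resolve_pretrained` and `extract_state_dict` are hypothetical and not part of MONAI. Unlike the snippet above, the second helper strips "module." only when it is a leading prefix, which is an assumption about how the checkpoint keys are laid out.

```python
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Union


def resolve_pretrained(pretrained: Union[bool, str, Path]) -> Optional[Path]:
    """Return the checkpoint path to load, or None if no weights are requested.

    Hypothetical helper for illustration; not an existing MONAI function.
    """
    # Check bool first: True carries no information about where the weights are.
    if isinstance(pretrained, bool):
        if pretrained:
            raise ValueError(
                "pretrained=True does not say where the weights are; "
                "pass a path to a checkpoint file instead."
            )
        return None
    path = Path(pretrained)
    if not path.is_file():
        raise FileNotFoundError(f"Pretrained weights not found at {path}")
    return path


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Unwrap an optional "state_dict" entry and strip a leading "module." prefix."""
    # Assumption: a checkpoint without a "state_dict" entry is itself the state_dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```

With these in place, the loading step could become `path = resolve_pretrained(pretrained)`, followed (when `path` is not None) by `torch.load(path, map_location=device)` and `model.load_state_dict(extract_state_dict(checkpoint), strict=True)`.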

Additional context

I do understand that MONAI is moving in a different direction with loading bundles etc., but since this pretrained flag still exists, it might be useful to simplify it.
