Network Arch Code refactoring to allow replacing the last FC and add a freeze core parameter #6552
Comments
Right now the networks which do have final layers and named layers/blocks can be substituted directly, and the optimiser can be given only the parameters for those layers:

```python
import monai
import torch

net = monai.networks.nets.SegResNet()
print(net(torch.rand(1, 1, 8, 8, 8)).shape)  # torch.Size([1, 2, 8, 8, 8])

# replace the final convolution: 5 output channels instead of 2
net.conv_final = monai.networks.blocks.Convolution(3, 8, 5, kernel_size=1)
print(net(torch.rand(1, 1, 8, 8, 8)).shape)  # torch.Size([1, 5, 8, 8, 8])

# give the optimiser only the parameters of the new layer
opt = torch.optim.Adam(net.conv_final.parameters())
```

I'm not sure what more you're looking for here other than some utilities to help with doing this and renaming some internals of networks. We could add some helpers, but often the code for a forward method is going to need adapting anyway, so making a subclass would still make sense. For changing names in existing networks, this breaks compatibility with existing saved states, so we need to be very careful about what sort of changes we feel are justified for such a refactor.
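For reference, a minimal sketch of the freezing half of the request ("freeze the core parameters and just train the newly added FC"), assuming the same SegResNet example as above and using only plain PyTorch `requires_grad` handling:

```python
import monai
import torch

net = monai.networks.nets.SegResNet()
net.conv_final = monai.networks.blocks.Convolution(3, 8, 5, kernel_size=1)

# Freeze every parameter, then re-enable gradients for the new final layer only.
for p in net.parameters():
    p.requires_grad_(False)
for p in net.conv_final.parameters():
    p.requires_grad_(True)

# Hand the optimiser only the trainable parameters.
opt = torch.optim.Adam(p for p in net.parameters() if p.requires_grad)
```

This keeps the backbone weights fixed while the replaced head trains as normal; no changes to the network class are needed.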
Hi @ericspod
Maybe we can create a class that all networks should implement, where they are forced to add a function in the network to change the last layer.
For the freezing option, again I don't want the user to know the details of the architecture and the names of which parameters to pass to the optimizer. This becomes more complicated if the last layer or FC is actually a set of layers, as in SegResNet (see below). So the steps to do fine-tuning should be as follows:
Maybe to clarify this isolation need: if you look in the SegResNet code you will see that the last layer is more than just a conv, it is normalization, ReLU and conv. The user should not have to worry about these details.
I agree with you that "for changing names in existing networks this breaks compatibility with existing saved states so we need to be very careful". I think that adding this interface and having the network implement it should not break checkpoints, but I might be wrong. Also, exporting TorchScript sometimes has issues with parent classes, which I have faced before, so that should be tested.
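Purely to illustrate the shape such an interface could take (the names `FineTunable`, `replace_final_layer` and `final_layer_parameters` below are hypothetical, not an existing MONAI API):

```python
from abc import ABC, abstractmethod

import monai


class FineTunable(ABC):
    """Hypothetical mixin a network could implement to hide its head details."""

    @abstractmethod
    def replace_final_layer(self, out_channels: int) -> None:
        """Swap the final prediction head for one producing `out_channels` outputs."""

    @abstractmethod
    def final_layer_parameters(self):
        """Return only the parameters of the final head, e.g. for the optimiser."""


class FineTunableSegResNet(monai.networks.nets.SegResNet, FineTunable):
    """Hypothetical subclass exposing the interface for SegResNet."""

    def replace_final_layer(self, out_channels: int) -> None:
        # 8 = default init_filters, matching the example earlier in the thread
        self.conv_final = monai.networks.blocks.Convolution(3, 8, out_channels, kernel_size=1)

    def final_layer_parameters(self):
        return self.conv_final.parameters()
```

A user could then call `net.replace_final_layer(5)` and `torch.optim.Adam(net.final_layer_parameters())` without knowing any layer names.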
Hi @AHarouni thanks for the outline. I think it's clear what you want to do, but I don't think the changes to the networks really justify this. Even if we can make changes that don't break existing save states, these would be for existing networks only and would adhere to a pattern we couldn't expect others to adopt. We had always wanted components in MONAI to be loosely coupled, so we didn't impose requirements on network architecture. I don't really want to move away from this in exchange for a mechanism that simplifies a relatively small aspect of network training.

We have in some networks, like the Regressor, some aspect of what you want, and we could possibly change others, but this covers only one sort of fine-tuning that someone might want to do and isn't a general solution to that problem. This works for classifiers where you want to change what is being classified and how many classes by adapting a final layer, but other architectures wouldn't necessarily be designed in a way that permits changing just one aspect of a network so simply. During training you may also want to have gradients through the network as normal and optimise the new layer as normal, but train the rest of the network at a much slower rate.

As I said, I see where you're coming from and what the solution is for your fine-tuning strategy, but I don't see it as a general enough solution to be worth the consequences. I'm less worried about requiring users to delve into the workings of a network to figure out how to adapt components and then do fine-tuning; I imagine people will need to know about the inner workings quite often to know what can be done anyway, except in simple cases. If you wanted to propose a more concrete set of changes or a full PR we'd be happy to discuss it there as well as here; some of the components at least of what you want to add seem doable. Thanks! CC @Nic-Ma @wyli @atbenmurray for any comments?
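A quick sketch of the point above about optimising the new layer as normal while training the rest of the network at a much slower rate, using standard PyTorch parameter groups (illustrative only, reusing the SegResNet example from the thread):

```python
import monai
import torch

net = monai.networks.nets.SegResNet()
net.conv_final = monai.networks.blocks.Convolution(3, 8, 5, kernel_size=1)

# Split the parameters: new head at the normal learning rate, backbone much slower.
head_params = list(net.conv_final.parameters())
head_ids = {id(p) for p in head_params}
backbone_params = [p for p in net.parameters() if id(p) not in head_ids]

opt = torch.optim.Adam(
    [
        {"params": head_params, "lr": 1e-3},
        {"params": backbone_params, "lr": 1e-5},
    ]
)
```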
Hi @ericspod, for me to figure out the next steps I would like to know:
If it is OK, then the issue is coming up with this interface. I can then propose one and show it works with a fine-tuning workflow using one network as an example. I can then open an issue for each network architecture to create a new network that supports that interface. It is up to the researcher to respond, or maybe NVIDIA's developers can contribute it. The problem is, the solution I proposed above doesn't work for MONAI Label, as the loading of weights is done through handlers. I have been thinking more and more about this. The core problem is that when loading the checkpoint the last layer dimensions don't match. What do you think about adding a function in the network class that would chop off the weights of the FC, or whichever layer the researcher thinks should be fine-tuned? That way the network creation in the fine-tuning or the normal training workflow would remain the same, including loading from a checkpoint to continue training.
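The "chop off the weights of the FC" step can also be sketched outside the network class by filtering the checkpoint before loading. This is illustrative only, with an assumed checkpoint path and the `conv_final` prefix from the SegResNet example:

```python
import monai
import torch

# Build the network with the new head size directly.
net = monai.networks.nets.SegResNet(out_channels=5)

ckpt = torch.load("pretrained_segresnet.pt")  # assumed path
state = ckpt.get("state_dict", ckpt)  # handle either a raw or wrapped state dict

# Drop the old head's weights so only the backbone is loaded.
filtered = {k: v for k, v in state.items() if not k.startswith("conv_final")}
missing, unexpected = net.load_state_dict(filtered, strict=False)
print("head keys left at their fresh initialisation:", missing)
```

With `strict=False` the mismatched final layer simply keeps its freshly initialised weights, so the same network-creation code works for both normal training and fine-tuning.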
Hi @AHarouni,
I just found this tutorial: https://github.com/Project-MONAI/tutorials/blob/main/self_supervised_pretraining/swinunetr_pretrained/swinunetr_finetune.ipynb
Hi @AHarouni I see the usefulness of this code for this network, but it's implementing behaviour that can be generalised so that the same adaptation can be applied to other networks. This would allow us to filter a state dictionary for any network, using a function or some translation table to filter the incoming data. Other networks would benefit from the same thing with only minor differences, so I think it makes sense as a separate utility function. This way we don't rely on a method of a network being present to do this, and it would work for non-MONAI networks as well.

I don't like the idea of requiring certain implementation details of our networks; we've always said that we want to maintain architectural similarity with PyTorch and compatibility with existing PyTorch code as much as possible. The methods of a network should be concerned solely with the construction and operation of the network, so I feel it mixes purposes to add methods of this type. We would also not be able to use this functionality with non-MONAI networks as I said, but decoupled utilities which make few assumptions about the networks they operate on are more flexible in that they can be used with such networks. I would still suggest that we can define generalised utilities; for example, the cell above can be reimplemented as a general function:

```python
import numpy as np
import torch


def load_adapted_state_dict(network, new_weights, filter_func):
    # Generate a new state dict so it can be loaded into the MONAI SwinUNETR model
    model_prior_dict = network.state_dict()
    model_update_dict = dict(model_prior_dict)

    for key, value in new_weights.items():
        new_pair = filter_func(key, value)
        if new_pair is not None:
            model_update_dict[new_pair[0]] = new_pair[1]

    network.load_state_dict(model_update_dict, strict=True)
    model_final_loaded_dict = network.state_dict()

    # Safeguard test to ensure that the weights were loaded successfully
    layer_counter = 0
    for k, _v in model_final_loaded_dict.items():
        if k in model_prior_dict:
            layer_counter = layer_counter + 1
            old_wts = model_prior_dict[k].to("cpu").numpy()
            new_wts = model_final_loaded_dict[k].to("cpu").numpy()
            diff = np.mean(np.abs(old_wts - new_wts))
            print("Layer {}, the update difference is: {}".format(k, diff))
            if diff == 0.0:
                print("Warning: No difference found for layer {}".format(k))

    print("Total updated layers {} / {}".format(layer_counter, len(model_prior_dict)))
    print("Pretrained weights successfully loaded!")


def _filter(k, v):
    # Drop keys that should not be transferred, remap "encoder." keys to "swinViT.",
    # and pass everything else through unchanged.
    if k in [
        "encoder.mask_token",
        "encoder.norm.weight",
        "encoder.norm.bias",
        "out.conv.conv.weight",
        "out.conv.conv.bias",
    ]:
        return None
    if k[:8] == "encoder.":
        if k[8:19] == "patch_embed":
            new_key = "swinViT." + k[8:]
        else:
            new_key = "swinViT." + k[8:18] + k[20:]
        return new_key, v
    else:
        return k, v


load_adapted_state_dict(model, torch.load(pretrained_path)["model"], _filter)
```
Hi @ericspod I think the main thing we disagree about is the responsibilities. I believe that if I would like to use a network and fine-tune it, I should NOT need to know anything about the network layers, which layers to copy weights to and which to ignore. That should be the responsibility of the researcher who created the network. So basically, the question is who should write the
The general purpose code would be something that does inspect the members of a network, like what I proposed here. Researchers providing their own networks can also provide functions using this general purpose code to tweak the members they know should be changed. I think in general, though, this will suffice for only a very small number of fine-tuning or refinement cases, that is, only the things the implementors anticipate someone wanting to do, whereas there are probably many other things people would want to do which will require understanding the network's inner workings anyway. Either way, the code shouldn't be part of the class definition because it's counter to our architectural ideas and introduces close coupling between the network and fine-tuning. It's fine for the implementor of the network to provide their own fine-tuning functions; if we have a more involved architectural pattern, we can think about what classes or other components to add to MONAI to facilitate this.
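As an illustration of what "code that inspects the members of a network" could look like, here is a hand-rolled sketch (not an existing MONAI utility, and only a heuristic):

```python
import torch.nn as nn


def find_last_head(network: nn.Module):
    """Return the (name, module) pair of the last conv/linear leaf in the network.

    Heuristic only: it assumes the final prediction head is the last convolution
    or linear layer registered on the module, which holds for many but not all
    architectures.
    """
    candidates = [
        (name, module)
        for name, module in network.named_modules()
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d))
    ]
    if not candidates:
        raise ValueError("no linear/convolution layers found")
    return candidates[-1]
```

A network author who knows their architecture could then wrap this, or bypass it entirely, in a small function that names the right member explicitly.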
I think this one can be partially addressed by the enhanced
There is definitely overlap with the concepts here, with the
Part of #6552.

### Description
Add `freeze_layers`.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------
Signed-off-by: KumoLiu <yunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
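As a rough idea of what such a utility does, freezing by parameter-name pattern can be hand-rolled in plain PyTorch (a sketch only; refer to the PR itself for the actual `freeze_layers` API):

```python
import re

import torch.nn as nn


def freeze_matching(model: nn.Module, pattern: str) -> None:
    """Freeze every parameter whose name matches the given regular expression."""
    regex = re.compile(pattern)
    for name, param in model.named_parameters():
        if regex.search(name):
            param.requires_grad_(False)


# e.g. freeze everything except the final convolution of the SegResNet example:
# freeze_matching(net, r"^(?!conv_final).*")
```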
…6917) Part of #6552.

### Description
After PR #6835, we have added `copy_model_args` in the `load` API, which can help us update the state_dict flexibly.
https://github.com/KumoLiu/MONAI/blob/93a149a611b66153cf804b31a7b36a939e2e593a/monai/bundle/scripts.py#L397

Given this [issue](#6552), we need to be able to filter the model's weights flexibly. In `copy_model_state`, we already have a "mapping" arg; the filter will be more flexible if we can support regular expressions in the mapping. This PR mainly adds support for regular expressions in the "mapping" arg.

In the [example](#6552 (comment)) in this [issue](#6552), after this PR, we can do something like:
```
exclude_vars = "encoder.mask_token|encoder.norm.weight|encoder.norm.bias|out.conv.conv.weight|out.conv.conv.bias"
mapping = {"encoder.layers(.*).0.0.": "swinViT.layers(.*).0."}
dst_dict, updated_keys, unchanged_keys = copy_model_state(
    model, ssl_weights, exclude_vars=exclude_vars, mapping=mapping
)
```

Additionally, based on Eric's comments [here](#6552 (comment)), I totally agree we could add a handler to make the pipeline easier to implement, but perhaps this task does not need to be set as a "BundleTodo" for MONAI v1.3 and can instead be a near-future enhancement for MONAI. What do you think? @ericspod @wyli @Nic-Ma

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------
Signed-off-by: KumoLiu <yunl@nvidia.com>
Describe the solution you'd like
Once there is a good model out there (either we trained it ourselves, got it from the model zoo, or from other sources), we would like to easily replace the last FC layer (for example, the one with 104 outputs for TotalSegmentator) with one that has a different number of outputs. Currently the user gets errors when loading the checkpoint and needs to do some network surgery.
The second request is that usually the user would want to load the checkpoint, then freeze the core parameters and train only the newly added FC.
Describe alternatives you've considered
The code should be refactored to allow giving the last FC a different name; with that, the checkpoint would be loaded without any errors.
As in Project-MONAI/MONAILabel#1298, I had to write my own SegResNet (below) to change the name of the last layer.
Additional context
We do have code to freeze the layers, but it needs to be a utility that calls into a function of the network to get the FC layer name. For this we should have an interface for network models to implement, so this utility function can be called.
I have it hacked as