-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial Commit of a self-supervised framework bootstrapped from nnU-Net
- Loading branch information
Showing
209 changed files
with
6,027 additions
and
11,549 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
{ | ||
// Use IntelliSense to learn about possible attributes. | ||
// Hover to view descriptions of existing attributes. | ||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | ||
"version": "0.2.0", | ||
"configurations": [ | ||
{ | ||
"name": "Python: Current File", | ||
"type": "python", | ||
"request": "launch", | ||
"program": "${file}", | ||
"console": "integratedTerminal", | ||
"justMyCode": true | ||
}, | ||
{ | ||
"name": "Python: Preprocess SSL Dataset", | ||
"type": "python", | ||
"request": "launch", | ||
"program": "${workspaceFolder}/nnssl/experiment_planning/plan_and_preprocess_entrypoints.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": true, | ||
"args": [ | ||
"-d", | ||
"601", | ||
"-fpe", | ||
"DatasetFingerprintExtractor", | ||
"-npfp", | ||
"16", | ||
"--clean", | ||
] | ||
}, | ||
{ | ||
"name": "Python: Training SSL Dataset", | ||
"type": "python", | ||
"request": "launch", | ||
"program": "${workspaceFolder}/nnssl/run/run_training.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": true, | ||
"args": [ | ||
"601", | ||
"3d_fullres", | ||
"-tr", | ||
"nnsslDummyMAETrainer", | ||
] | ||
} | ||
] | ||
} |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from nnssl.architectures.get_network_from_plans import get_network_from_plans | ||
from torch import nn | ||
|
||
from nnssl.experiment_planning.experiment_planners.plan import ConfigurationPlan, Plan | ||
|
||
|
||
def build_network_architecture( | ||
config_plan: ConfigurationPlan, | ||
num_input_channels: int, | ||
num_output_channels: int, | ||
) -> nn.Module: | ||
""" | ||
This is where you build the architecture according to the plans. There is no obligation to use | ||
get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what | ||
you want. Even ignore the plans and just return something static (as long as it can process the requested | ||
patch size) | ||
but don't bug us with your bugs arising from fiddling with this :-P | ||
This is the function that is called in inference as well! This is needed so that all network architecture | ||
variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for | ||
training, so if you change the network architecture during training by deriving a new trainer class then | ||
inference will know about it). | ||
If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet: | ||
> label_manager = plans_manager.get_label_manager(dataset_json) | ||
> label_manager.num_segmentation_heads | ||
(why so complicated? -> We can have either classical training (classes) or regions. If we have regions, | ||
the number of outputs is != the number of classes. Also there is the ignore label for which no output | ||
should be generated. label_manager takes care of all that for you.) | ||
""" | ||
return get_network_from_plans( | ||
configuration_plan=config_plan, | ||
num_input_channels=num_input_channels, | ||
num_output_channels=num_output_channels, | ||
deep_supervision=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from torch import nn | ||
|
||
|
||
def convert_to_spark_cnn(architecture: nn.Module): | ||
raise NotImplementedError("TODO: implement convert_to_spark_cnn") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet | ||
from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op | ||
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 | ||
from nnssl.experiment_planning.experiment_planners.plan import ConfigurationPlan, Plan | ||
from nnssl.utilities.network_initialization import InitWeights_He | ||
from torch import nn | ||
|
||
|
||
def get_network_from_plans( | ||
configuration_plan: ConfigurationPlan, | ||
num_input_channels: int, | ||
num_output_channels: int, | ||
deep_supervision: bool = True, | ||
): | ||
""" | ||
we may have to change this in the future to accommodate other plans -> network mappings | ||
num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the | ||
trainer rather than inferring it again from the plans here. | ||
""" | ||
num_stages = len(configuration_plan.conv_kernel_sizes) | ||
|
||
dim = len(configuration_plan.conv_kernel_sizes[0]) | ||
conv_op = convert_dim_to_conv_op(dim) | ||
|
||
segmentation_network_class_name = configuration_plan.UNet_class_name | ||
mapping = {"PlainConvUNet": PlainConvUNet, "ResidualEncoderUNet": ResidualEncoderUNet} | ||
kwargs = { | ||
"PlainConvUNet": { | ||
"conv_bias": True, | ||
"norm_op": get_matching_instancenorm(conv_op), | ||
"norm_op_kwargs": {"eps": 1e-5, "affine": True}, | ||
"dropout_op": None, | ||
"dropout_op_kwargs": None, | ||
"nonlin": nn.LeakyReLU, | ||
"nonlin_kwargs": {"inplace": True}, | ||
}, | ||
"ResidualEncoderUNet": { | ||
"conv_bias": True, | ||
"norm_op": get_matching_instancenorm(conv_op), | ||
"norm_op_kwargs": {"eps": 1e-5, "affine": True}, | ||
"dropout_op": None, | ||
"dropout_op_kwargs": None, | ||
"nonlin": nn.LeakyReLU, | ||
"nonlin_kwargs": {"inplace": True}, | ||
}, | ||
} | ||
assert segmentation_network_class_name in mapping.keys(), ( | ||
"The network architecture specified by the plans file " | ||
"is non-standard (maybe your own?). Yo'll have to dive " | ||
"into either this " | ||
"function (get_network_from_plans) or " | ||
"the init of your nnUNetModule to accommodate that." | ||
) | ||
network_class = mapping[segmentation_network_class_name] | ||
|
||
conv_or_blocks_per_stage = { | ||
"n_conv_per_stage" | ||
if network_class != ResidualEncoderUNet | ||
else "n_blocks_per_stage": configuration_plan.n_conv_per_stage_encoder, | ||
"n_conv_per_stage_decoder": configuration_plan.n_conv_per_stage_decoder, | ||
} | ||
# network class name!! | ||
model = network_class( | ||
input_channels=num_input_channels, | ||
n_stages=num_stages, | ||
features_per_stage=[ | ||
min(configuration_plan.UNet_base_num_features * 2**i, configuration_plan.unet_max_num_features) | ||
for i in range(num_stages) | ||
], | ||
conv_op=conv_op, | ||
kernel_sizes=configuration_plan.conv_kernel_sizes, | ||
strides=configuration_plan.pool_op_kernel_sizes, | ||
num_classes=num_output_channels, | ||
deep_supervision=deep_supervision, | ||
**conv_or_blocks_per_stage, | ||
**kwargs[segmentation_network_class_name], | ||
) | ||
model.apply(InitWeights_He(1e-2)) | ||
if network_class == ResidualEncoderUNet: | ||
model.apply(init_last_bn_before_add_to_0) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.