Skip to content

Commit

Permalink
Initial Commit of a self-supervised framework bootstrapped from nnU-Net
Browse files Browse the repository at this point in the history
  • Loading branch information
TaWald committed Jan 2, 2024
1 parent 6309155 commit 184191e
Show file tree
Hide file tree
Showing 209 changed files with 6,027 additions and 11,549 deletions.
47 changes: 47 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: Preprocess SSL Dataset",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/nnssl/experiment_planning/plan_and_preprocess_entrypoints.py",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"-d",
"601",
"-fpe",
"DatasetFingerprintExtractor",
"-npfp",
"16",
"--clean",
]
},
{
"name": "Python: Training SSL Dataset",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/nnssl/run/run_training.py",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"601",
"3d_fullres",
"-tr",
"nnsslDummyMAETrainer",
]
}
]
}
File renamed without changes.
File renamed without changes.
36 changes: 36 additions & 0 deletions nnssl/architectures/build_architecture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from nnssl.architectures.get_network_from_plans import get_network_from_plans
from torch import nn

from nnssl.experiment_planning.experiment_planners.plan import ConfigurationPlan, Plan


def build_network_architecture(
config_plan: ConfigurationPlan,
num_input_channels: int,
num_output_channels: int,
) -> nn.Module:
"""
This is where you build the architecture according to the plans. There is no obligation to use
get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what
you want. Even ignore the plans and just return something static (as long as it can process the requested
patch size)
but don't bug us with your bugs arising from fiddling with this :-P
This is the function that is called in inference as well! This is needed so that all network architecture
variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for
training, so if you change the network architecture during training by deriving a new trainer class then
inference will know about it).
If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet:
> label_manager = plans_manager.get_label_manager(dataset_json)
> label_manager.num_segmentation_heads
(why so complicated? -> We can have either classical training (classes) or regions. If we have regions,
the number of outputs is != the number of classes. Also there is the ignore label for which no output
should be generated. label_manager takes care of all that for you.)
"""
return get_network_from_plans(
configuration_plan=config_plan,
num_input_channels=num_input_channels,
num_output_channels=num_output_channels,
deep_supervision=False,
)
5 changes: 5 additions & 0 deletions nnssl/architectures/convert_to_spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from torch import nn


def convert_to_spark_cnn(architecture: nn.Module):
raise NotImplementedError("TODO: implement convert_to_spark_cnn")
82 changes: 82 additions & 0 deletions nnssl/architectures/get_network_from_plans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
from nnssl.experiment_planning.experiment_planners.plan import ConfigurationPlan, Plan
from nnssl.utilities.network_initialization import InitWeights_He
from torch import nn


def get_network_from_plans(
configuration_plan: ConfigurationPlan,
num_input_channels: int,
num_output_channels: int,
deep_supervision: bool = True,
):
"""
we may have to change this in the future to accommodate other plans -> network mappings
num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the
trainer rather than inferring it again from the plans here.
"""
num_stages = len(configuration_plan.conv_kernel_sizes)

dim = len(configuration_plan.conv_kernel_sizes[0])
conv_op = convert_dim_to_conv_op(dim)

segmentation_network_class_name = configuration_plan.UNet_class_name
mapping = {"PlainConvUNet": PlainConvUNet, "ResidualEncoderUNet": ResidualEncoderUNet}
kwargs = {
"PlainConvUNet": {
"conv_bias": True,
"norm_op": get_matching_instancenorm(conv_op),
"norm_op_kwargs": {"eps": 1e-5, "affine": True},
"dropout_op": None,
"dropout_op_kwargs": None,
"nonlin": nn.LeakyReLU,
"nonlin_kwargs": {"inplace": True},
},
"ResidualEncoderUNet": {
"conv_bias": True,
"norm_op": get_matching_instancenorm(conv_op),
"norm_op_kwargs": {"eps": 1e-5, "affine": True},
"dropout_op": None,
"dropout_op_kwargs": None,
"nonlin": nn.LeakyReLU,
"nonlin_kwargs": {"inplace": True},
},
}
assert segmentation_network_class_name in mapping.keys(), (
"The network architecture specified by the plans file "
"is non-standard (maybe your own?). Yo'll have to dive "
"into either this "
"function (get_network_from_plans) or "
"the init of your nnUNetModule to accommodate that."
)
network_class = mapping[segmentation_network_class_name]

conv_or_blocks_per_stage = {
"n_conv_per_stage"
if network_class != ResidualEncoderUNet
else "n_blocks_per_stage": configuration_plan.n_conv_per_stage_encoder,
"n_conv_per_stage_decoder": configuration_plan.n_conv_per_stage_decoder,
}
# network class name!!
model = network_class(
input_channels=num_input_channels,
n_stages=num_stages,
features_per_stage=[
min(configuration_plan.UNet_base_num_features * 2**i, configuration_plan.unet_max_num_features)
for i in range(num_stages)
],
conv_op=conv_op,
kernel_sizes=configuration_plan.conv_kernel_sizes,
strides=configuration_plan.pool_op_kernel_sizes,
num_classes=num_output_channels,
deep_supervision=deep_supervision,
**conv_or_blocks_per_stage,
**kwargs[segmentation_network_class_name],
)
model.apply(InitWeights_He(1e-2))
if network_class == ResidualEncoderUNet:
model.apply(init_last_bn_before_add_to_0)
return model
2 changes: 1 addition & 1 deletion nnunetv2/configuration.py → nnssl/configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnssl.utilities.default_n_proc_DA import get_allowed_n_proc_DA

default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import shutil
from pathlib import Path

from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
from nnssl.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnssl.paths import nnUNet_raw


def make_out_dirs(dataset_id: int, task_name="ACDC"):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from copy import deepcopy

from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json
from nnunetv2.paths import nnUNet_raw
from nnssl.paths import nnUNet_raw


def convert(source_folder, target_dataset_name):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile

from nnunetv2.configuration import default_num_processes
from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from nnssl.configuration import default_num_processes
from nnssl.evaluation.evaluate_predictions import compute_metrics_on_folder
from nnssl.paths import nnUNet_raw, nnssl_preprocessed
from nnssl.utilities.plans_handling.plans_handler import PlansManager


def accumulate_cv_results(trained_model_folder,
Expand Down Expand Up @@ -46,7 +46,7 @@ def accumulate_cv_results(trained_model_folder,
label_manager = plans_manager.get_label_manager(dataset_json)
gt_folder = join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr')
if not isdir(gt_folder):
gt_folder = join(nnUNet_preprocessed, plans_manager.dataset_name, 'gt_segmentations')
gt_folder = join(nnssl_preprocessed, plans_manager.dataset_name, 'gt_segmentations')
compute_metrics_on_folder(gt_folder,
merged_output_folder,
join(merged_output_folder, 'summary.json'),
Expand Down
Loading

0 comments on commit 184191e

Please sign in to comment.