From 4b18b0171d9500c2d914d1940dcd46f182a5527f Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Thu, 17 Oct 2024 14:01:29 +0200 Subject: [PATCH] remove SSDA --- clinicadl/commandline/modules_options/ssda.py | 45 -- .../pipelines/train/classification/cli.py | 7 - .../pipelines/train/reconstruction/cli.py | 7 - .../pipelines/train/regression/cli.py | 7 - clinicadl/config/config/ssda.py | 41 - clinicadl/nn/networks/__init__.py | 1 - clinicadl/predictor/predictor.py | 79 -- .../random_search/random_search_utils.py | 1 - clinicadl/resources/config/train_config.toml | 1 - clinicadl/trainer/config/train.py | 4 - clinicadl/trainer/trainer.py | 701 +----------------- clinicadl/utils/cli_param/option.py | 7 - tests/unittests/nn/networks/test_ssda.py | 11 - tests/unittests/train/test_utils.py | 3 - .../train/trainer/test_training_config.py | 26 - 15 files changed, 38 insertions(+), 903 deletions(-) delete mode 100644 clinicadl/commandline/modules_options/ssda.py delete mode 100644 clinicadl/config/config/ssda.py delete mode 100644 tests/unittests/nn/networks/test_ssda.py diff --git a/clinicadl/commandline/modules_options/ssda.py b/clinicadl/commandline/modules_options/ssda.py deleted file mode 100644 index 8119726ef..000000000 --- a/clinicadl/commandline/modules_options/ssda.py +++ /dev/null @@ -1,45 +0,0 @@ -import click - -from clinicadl.config.config.ssda import SSDAConfig -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type - -# SSDA -caps_target = click.option( - "--caps_target", - "-d", - type=get_type("caps_target", SSDAConfig), - default=get_default("caps_target", SSDAConfig), - help="CAPS of target data.", - show_default=True, -) -preprocessing_json_target = click.option( - "--preprocessing_json_target", - "-d", - type=get_type("preprocessing_json_target", SSDAConfig), - default=get_default("preprocessing_json_target", SSDAConfig), - help="Path to json target.", - show_default=True, -) -ssda_network = click.option( - "--ssda_network/--single_network", - default=get_default("ssda_network", SSDAConfig), - help="If provided uses a ssda-network framework.", - show_default=True, -) -tsv_target_lab = click.option( - "--tsv_target_lab", - "-d", - type=get_type("tsv_target_lab", SSDAConfig), - default=get_default("tsv_target_lab", SSDAConfig), - help="TSV of labeled target data.", - show_default=True, -) -tsv_target_unlab = click.option( - "--tsv_target_unlab", - "-d", - type=get_type("tsv_target_unlab", SSDAConfig), - default=get_default("tsv_target_unlab", SSDAConfig), - help="TSV of unllabeled target data.", - show_default=True, -) diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py index d4a3b68a1..8ac287402 100644 --- a/clinicadl/commandline/pipelines/train/classification/cli.py +++ b/clinicadl/commandline/pipelines/train/classification/cli.py @@ -13,7 +13,6 @@ optimizer, reproducibility, split, - ssda, transforms, validation, ) @@ -63,12 +62,6 @@ @dataloader.batch_size @dataloader.sampler @dataloader.n_proc -# ssda option -@ssda.ssda_network -@ssda.caps_target -@ssda.tsv_target_lab -@ssda.tsv_target_unlab -@ssda.preprocessing_json_target # Cross validation @split.n_splits @split.split diff --git a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py index d63bf63f8..fc39ef54e 100644 --- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py +++ b/clinicadl/commandline/pipelines/train/reconstruction/cli.py @@ -13,7 +13,6 @@ optimizer, reproducibility, split, - ssda, transforms, validation, ) @@ -63,12 +62,6 @@ @dataloader.batch_size @dataloader.sampler @dataloader.n_proc -# ssda option -@ssda.ssda_network -@ssda.caps_target -@ssda.tsv_target_lab -@ssda.tsv_target_unlab -@ssda.preprocessing_json_target # Cross validation @split.n_splits @split.split diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py index ff6dd68ca..59e816192 100644 --- a/clinicadl/commandline/pipelines/train/regression/cli.py +++ b/clinicadl/commandline/pipelines/train/regression/cli.py @@ -13,7 +13,6 @@ optimizer, reproducibility, split, - ssda, transforms, validation, ) @@ -61,12 +60,6 @@ @dataloader.batch_size @dataloader.sampler @dataloader.n_proc -# ssda o -@ssda.ssda_network -@ssda.caps_target -@ssda.tsv_target_lab -@ssda.tsv_target_unlab -@ssda.preprocessing_json_target # Cross validation @split.n_splits @split.split diff --git a/clinicadl/config/config/ssda.py b/clinicadl/config/config/ssda.py deleted file mode 100644 index caf52634d..000000000 --- a/clinicadl/config/config/ssda.py +++ /dev/null @@ -1,41 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Any, Dict - -from pydantic import BaseModel, ConfigDict, computed_field - -from clinicadl.utils.iotools.utils import read_preprocessing - -logger = getLogger("clinicadl.ssda_config") - - -class SSDAConfig(BaseModel): - """Config class to perform SSDA.""" - - caps_target: Path = Path("") - preprocessing_json_target: Path = Path("") - ssda_network: bool = False - tsv_target_lab: Path = Path("") - tsv_target_unlab: Path = Path("") - # pydantic config - model_config = ConfigDict(validate_assignment=True) - - @computed_field - @property - def preprocessing_dict_target(self) -> Dict[str, Any]: - """ - Gets the preprocessing dictionary from a target preprocessing json file. - - Returns - ------- - Dict[str, Any] - The preprocessing dictionary. - """ - if not self.ssda_network: - return {} - - preprocessing_json_target = ( - self.caps_target / "tensor_extraction" / self.preprocessing_json_target - ) - - return read_preprocessing(preprocessing_json_target) diff --git a/clinicadl/nn/networks/__init__.py b/clinicadl/nn/networks/__init__.py index c77097e60..3b88830fb 100644 --- a/clinicadl/nn/networks/__init__.py +++ b/clinicadl/nn/networks/__init__.py @@ -8,7 +8,6 @@ resnet18, ) from .random import RandomArchitecture -from .ssda import Conv5_FC3_SSDA from .unet import UNet from .vae import ( CVAE_3D, diff --git a/clinicadl/predictor/predictor.py b/clinicadl/predictor/predictor.py index d686c944d..30fbbe5b8 100644 --- a/clinicadl/predictor/predictor.py +++ b/clinicadl/predictor/predictor.py @@ -1044,85 +1044,6 @@ def _test_loader( data_group=data_group, ) - def _test_loader_ssda( - self, - maps_manager: MapsManager, - dataloader, - criterion, - alpha, - data_group, - split, - selection_metrics, - use_labels=True, - gpu=None, - network=None, - target=False, - report_ci=True, - ): - """ - Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. - - Args: - dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. - criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. - data_group (str): name of the data group used for the testing task. - split (int): Index of the split used to train the model tested. - selection_metrics (list[str]): List of metrics used to select the best models which are tested. - use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. - gpu (bool): If given, a new value for the device of the model will be computed. - network (int): Index of the network tested (only used in multi-network setting). - """ - for selection_metric in selection_metrics: - log_dir = ( - maps_manager.maps_path - / f"split-{split}" - / f"best-{selection_metric}" - / data_group - ) - maps_manager.write_description_log( - log_dir, - data_group, - dataloader.dataset.caps_dict, - dataloader.dataset.df, - ) - - # load the best trained model during the training - model, _ = maps_manager._init_model( - transfer_path=maps_manager.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - ) - prediction_df, metrics = self.test_da( - network_task=maps_manager.network_task, - model=model, - dataloader=dataloader, - criterion=criterion, - target=target, - report_ci=report_ci, - mode=maps_manager.mode, - metrics_module=maps_manager.metrics_module, - n_classes=maps_manager.n_classes, - ) - if use_labels: - if network is not None: - metrics[f"{maps_manager.mode}_id"] = network - - if report_ci: - loss_to_log = metrics["Metric_values"][-1] - else: - loss_to_log = metrics["loss"] - - logger.info( - f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" - ) - - # Replace here - maps_manager._mode_level_to_tsv( - prediction_df, metrics, split, selection_metric, data_group=data_group - ) - @torch.no_grad() def _compute_output_tensors( self, diff --git a/clinicadl/random_search/random_search_utils.py b/clinicadl/random_search/random_search_utils.py index ed164ea0c..f8f3bca9a 100644 --- a/clinicadl/random_search/random_search_utils.py +++ b/clinicadl/random_search/random_search_utils.py @@ -124,7 +124,6 @@ def random_sampling(rs_options: Dict[str, Any]) -> Dict[str, Any]: "mode": "fixed", "multi_cohort": "fixed", "multi_network": "choice", - "ssda_netork": "fixed", "n_fcblocks": "randint", "n_splits": "fixed", "n_proc": "fixed", diff --git a/clinicadl/resources/config/train_config.toml b/clinicadl/resources/config/train_config.toml index f4f2afe30..9e5f54657 100644 --- a/clinicadl/resources/config/train_config.toml +++ b/clinicadl/resources/config/train_config.toml @@ -4,7 +4,6 @@ [Model] architecture = "default" # ex : Conv5_FC3 multi_network = false -ssda_network = false [Architecture] # CNN diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index a1e949997..30a92c92a 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -14,7 +14,6 @@ from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.config.config.lr_scheduler import LRschedulerConfig from clinicadl.config.config.reproducibility import ReproducibilityConfig -from clinicadl.config.config.ssda import SSDAConfig from clinicadl.maps_manager.config import MapsManagerConfig from clinicadl.network.config import NetworkConfig from clinicadl.optimizer.optimization import OptimizationConfig @@ -50,7 +49,6 @@ class TrainConfig(BaseModel, ABC): optimizer: OptimizerConfig reproducibility: ReproducibilityConfig split: SplitConfig - ssda: SSDAConfig transfer_learning: TransferLearningConfig transforms: TransformsConfig validation: ValidationConfig @@ -77,7 +75,6 @@ def __init__(self, **kwargs): optimizer=kwargs, reproducibility=kwargs, split=kwargs, - ssda=kwargs, transfer_learning=kwargs, transforms=kwargs, validation=kwargs, @@ -97,7 +94,6 @@ def _update(self, config_dict: Dict[str, Any]) -> None: self.optimizer.__dict__.update(config_dict) self.reproducibility.__dict__.update(config_dict) self.split.__dict__.update(config_dict) - self.ssda.__dict__.update(config_dict) self.transfer_learning.__dict__.update(config_dict) self.transforms.__dict__.update(config_dict) self.validation.__dict__.update(config_dict) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 4ab344144..775ecd2c6 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -234,35 +234,31 @@ def train( If splits specified in input already exist and overwrite is False. """ - if self.config.ssda.ssda_network: - self._train_ssda(split_list, resume=False) - - else: - # splitter_config = SplitterConfig(**self.config.get_dict()) - # self.splitter = Splitter(splitter_config) - # self.splitter.check_split_list(self.config.maps_manager.maps_dir, self.config.maps_manager.overwrite) - self.splitter.check_split_list( - self.config.maps_manager.maps_dir, - overwrite, # overwrite change so careful it is not the maps manager overwrite parameters here + # splitter_config = SplitterConfig(**self.config.get_dict()) + # self.splitter = Splitter(splitter_config) + # self.splitter.check_split_list(self.config.maps_manager.maps_dir, self.config.maps_manager.overwrite) + self.splitter.check_split_list( + self.config.maps_manager.maps_dir, + overwrite, # overwrite change so careful it is not the maps manager overwrite parameters here + ) + for split in self.splitter.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, ) - for split in self.splitter.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - split_df_dict = self.splitter[split] + split_df_dict = self.splitter[split] - if self.config.model.multi_network: - resume, first_network = self.init_first_network(False, split) - for network in range(first_network, self.maps_manager.num_networks): - self._train_single( - split, split_df_dict, network=network, resume=resume - ) - else: - self._train_single(split, split_df_dict, resume=False) + if self.config.model.multi_network: + resume, first_network = self.init_first_network(False, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=False) # def check_split_list(self, split_list, overwrite): # existing_splits = [] @@ -315,26 +311,23 @@ def _resume( f"Please try train command on these splits and resume only others." ) - if self.config.ssda.ssda_network: - self._train_ssda(split_list, resume=True) - else: - for split in self.splitter.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) + for split in self.splitter.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, + ) - split_df_dict = self.splitter[split] - if self.config.model.multi_network: - resume, first_network = self.init_first_network(True, split) - for network in range(first_network, self.maps_manager.num_networks): - self._train_single( - split, split_df_dict, network=network, resume=resume - ) - else: - self._train_single(split, split_df_dict, resume=True) + split_df_dict = self.splitter[split] + if self.config.model.multi_network: + resume, first_network = self.init_first_network(True, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=True) def init_first_network(self, resume: bool, split: int): first_network = 0 @@ -478,218 +471,6 @@ def _train_single( self.maps_manager._erase_tmp(split) - def _train_ssda( - self, - split_list: Optional[List[int]] = None, - resume: bool = False, - ) -> None: - """ - Trains a single CNN for a source and target domain using semi-supervised domain adaptation. - - Parameters - ---------- - split_list : Optional[List[int]] (optional, default=None) - List of splits on which the training task is performed. - If None, performs training on all splits of the cross-validation. - resume : bool (optional, default=False) - If True, the job is resumed from checkpoint. - """ - - splitter_config = SplitterConfig(**self.config.get_dict()) - - self.splitter = Splitter(splitter_config) - self.splitter_target_lab = Splitter(splitter_config) - - for split in self.splitter.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - - split_df_dict = self.splitter[split] - split_df_dict_target_lab = self.splitter_target_lab[split] - - logger.debug("Loading source training data...") - data_train_source = return_dataset( - self.config.data.caps_directory, - split_df_dict["train"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - - logger.debug("Loading target labelled training data...") - data_train_target_labeled = return_dataset( - Path(self.config.ssda.caps_target), # TO CHECK - split_df_dict_target_lab["train"], - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, # A checker - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - from torch.utils.data import ConcatDataset - - combined_dataset = ConcatDataset( - [data_train_source, data_train_target_labeled] - ) - - logger.debug("Loading target unlabelled training data...") - data_target_unlabeled = return_dataset( - Path(self.config.ssda.caps_target), - pd.read_csv(self.config.ssda.tsv_target_unlab, sep="\t"), - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, # A checker - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - - logger.debug("Loading validation source data...") - data_valid_source = return_dataset( - self.config.data.caps_directory, - split_df_dict["validation"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - logger.debug("Loading validation target labelled data...") - data_valid_target_labeled = return_dataset( - Path(self.config.ssda.caps_target), - split_df_dict_target_lab["validation"], - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - train_source_sampler = generate_sampler( - self.maps_manager.network_task, - data_train_source, - self.config.dataloader.sampler, - ) - - logger.info( - f"Getting train and validation loader with batch size {self.config.dataloader.batch_size}" - ) - - ## Oversampling of the target dataset - from torch.utils.data import SubsetRandomSampler - - # Create index lists for target labeled dataset - labeled_indices = list(range(len(data_train_target_labeled))) - - # Oversample the indices for the target labelled dataset to match the size of the labeled source dataset - data_train_source_size = ( - len(data_train_source) // self.config.dataloader.batch_size - ) - labeled_oversampled_indices = labeled_indices * ( - data_train_source_size // len(labeled_indices) - ) - - # Append remaining indices to match the size of the largest dataset - labeled_oversampled_indices += labeled_indices[ - : data_train_source_size % len(labeled_indices) - ] - - # Create SubsetRandomSamplers using the oversampled indices - labeled_sampler = SubsetRandomSampler(labeled_oversampled_indices) - - train_source_loader = DataLoader( - data_train_source, - batch_size=self.config.dataloader.batch_size, - sampler=train_source_sampler, - # shuffle=True, # len(data_train_source) < len(data_train_target_labeled), - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - drop_last=True, - ) - logger.info( - f"Train source loader size is {len(train_source_loader)*self.config.dataloader.batch_size}" - ) - train_target_loader = DataLoader( - data_train_target_labeled, - batch_size=1, # To limit the need of oversampling - # sampler=train_target_sampler, - sampler=labeled_sampler, - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - # shuffle=True, # len(data_train_target_labeled) < len(data_train_source), - drop_last=True, - ) - logger.info( - f"Train target labeled loader size oversample is {len(train_target_loader)}" - ) - - data_train_target_labeled.df = data_train_target_labeled.df[ - ["participant_id", "session_id", "diagnosis", "cohort", "domain"] - ] - - train_target_unl_loader = DataLoader( - data_target_unlabeled, - batch_size=self.config.dataloader.batch_size, - num_workers=self.config.dataloader.n_proc, - # sampler=unlabeled_sampler, - worker_init_fn=pl_worker_init_function, - shuffle=True, - drop_last=True, - ) - - logger.info( - f"Train target unlabeled loader size is {len(train_target_unl_loader)*self.config.dataloader.batch_size}" - ) - - valid_loader_source = DataLoader( - data_valid_source, - batch_size=self.config.dataloader.batch_size, - shuffle=False, - num_workers=self.config.dataloader.n_proc, - ) - logger.info( - f"Validation loader source size is {len(valid_loader_source)*self.config.dataloader.batch_size}" - ) - - valid_loader_target = DataLoader( - data_valid_target_labeled, - batch_size=self.config.dataloader.batch_size, # To check - shuffle=False, - num_workers=self.config.dataloader.n_proc, - ) - logger.info( - f"Validation loader target size is {len(valid_loader_target)*self.config.dataloader.batch_size}" - ) - - self._train_ssdann( - train_source_loader, - train_target_loader, - train_target_unl_loader, - valid_loader_target, - valid_loader_source, - split, - resume=resume, - ) - - self.validator._ensemble_prediction( - self.maps_manager, - "train", - split, - self.config.validation.selection_metrics, - ) - self.validator._ensemble_prediction( - self.maps_manager, - "validation", - split, - self.config.validation.selection_metrics, - ) - - self.maps_manager._erase_tmp(split) - def _train( self, train_loader: DataLoader, @@ -1011,412 +792,6 @@ def _train( self.callback_handler.on_train_end(parameters=self.maps_manager.parameters) - def _train_ssdann( - self, - train_source_loader: DataLoader, - train_target_loader: DataLoader, - train_target_unl_loader: DataLoader, - valid_loader: DataLoader, - valid_source_loader: DataLoader, - split: int, - network: Optional[Any] = None, - resume: bool = False, - evaluate_source: bool = True, # TO MODIFY - ): - """ - _summary_ - - Parameters - ---------- - train_source_loader : torch.utils.data.DataLoader - _description_ - train_target_loader : torch.utils.data.DataLoader - _description_ - train_target_unl_loader : torch.utils.data.DataLoader - _description_ - valid_loader : torch.utils.data.DataLoader - _description_ - valid_source_loader : torch.utils.data.DataLoader - _description_ - split : int - _description_ - network : Optional[Any] (optional, default=None) - _description_ - resume : bool (optional, default=False) - _description_ - evaluate_source : bool (optional, default=True) - _description_ - - Raises - ------ - Exception - _description_ - """ - model, beginning_epoch = self.maps_manager._init_model( - split=split, - resume=resume, - transfer_path=self.config.transfer_learning.transfer_path, - transfer_selection=self.config.transfer_learning.transfer_selection_metric, - ) - - criterion = get_criterion( - self.maps_manager.network_task, self.config.model.loss - ) - logger.debug(f"Criterion for {self.config.network_task} is {criterion}") - optimizer = self._init_optimizer(model, split=split, resume=resume) - - logger.debug(f"Optimizer used for training is optimizer") - - model.train() - train_source_loader.dataset.train() - train_target_loader.dataset.train() - train_target_unl_loader.dataset.train() - - early_stopping = EarlyStopping( - "min", - min_delta=self.config.early_stopping.tolerance, - patience=self.config.early_stopping.patience, - ) - - metrics_valid_target = {"loss": None} - metrics_valid_source = {"loss": None} - - log_writer = LogWriter( - self.maps_manager.maps_path, - evaluation_metrics(self.maps_manager.network_task) + ["loss"], - split, - resume=resume, - beginning_epoch=beginning_epoch, - network=network, - ) - epoch = log_writer.beginning_epoch - - retain_best = RetainBest( - selection_metrics=list(self.config.validation.selection_metrics) - ) - import numpy as np - - while epoch < self.config.optimization.epochs and not early_stopping.step( - metrics_valid_target["loss"] - ): - logger.info(f"Beginning epoch {epoch}.") - - model.zero_grad() - evaluation_flag, step_flag = True, True - - for i, (data_source, data_target, data_target_unl) in enumerate( - zip(train_source_loader, train_target_loader, train_target_unl_loader) - ): - p = ( - float(epoch * len(train_target_loader)) - / 10 - / len(train_target_loader) - ) - alpha = 2.0 / (1.0 + np.exp(-10 * p)) - 1 - # alpha = 0 - _, _, loss_dict = model.compute_outputs_and_loss( - data_source, data_target, data_target_unl, criterion, alpha - ) # TO CHECK - logger.debug(f"Train loss dictionary {loss_dict}") - loss = loss_dict["loss"] - loss.backward() - if (i + 1) % self.config.optimization.accumulation_steps == 0: - step_flag = False - optimizer.step() - optimizer.zero_grad() - - del loss - - # Evaluate the model only when no gradients are accumulated - if ( - self.config.validation.evaluation_steps != 0 - and (i + 1) % self.config.validation.evaluation_steps == 0 - ): - evaluation_flag = False - - # Evaluate on target data - logger.info("Evaluation on target data") - ( - _, - metrics_train_target, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_target_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) # TO CHECK - - ( - _, - metrics_valid_target, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - - model.train() - train_target_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_target, - metrics_valid_target, - len(train_target_loader), - "training_target.tsv", - ) - logger.info( - f"{self.config.data.mode} level training loss for target data is {metrics_train_target['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for target data is {metrics_valid_target['loss']} " - f"at the end of iteration {i}" - ) - - # Evaluate on source data - logger.info("Evaluation on source data") - ( - _, - metrics_train_source, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_source_loader, - criterion=criterion, - alpha=alpha, - ) - ( - _, - metrics_valid_source, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_source_loader, - criterion=criterion, - alpha=alpha, - ) - - model.train() - train_source_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_source, - metrics_valid_source, - len(train_source_loader), - ) - logger.info( - f"{self.config.data.mode} level training loss for source data is {metrics_train_source['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for source data is {metrics_valid_source['loss']} " - f"at the end of iteration {i}" - ) - - # If no step has been performed, raise Exception - if step_flag: - raise Exception( - "The model has not been updated once in the epoch. The accumulation step may be too large." - ) - - # If no evaluation has been performed, warn the user - elif evaluation_flag and self.config.validation.evaluation_steps != 0: - logger.warning( - f"Your evaluation steps {self.config.validation.evaluation_steps} are too big " - f"compared to the size of the dataset. " - f"The model is evaluated only once at the end epochs." - ) - - # Update weights one last time if gradients were computed without update - if (i + 1) % self.config.optimization.accumulation_steps != 0: - optimizer.step() - optimizer.zero_grad() - # Always test the results and save them once at the end of the epoch - model.zero_grad() - logger.debug(f"Last checkpoint at the end of the epoch {epoch}") - - if evaluate_source: - logger.info( - f"Evaluate source data at the end of the epoch {epoch} with alpha: {alpha}." - ) - _, metrics_train_source = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_source_loader, - criterion=criterion, - alpha=alpha, - target=True, - report_ci=False, - ) - _, metrics_valid_source = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_source_loader, - criterion=criterion, - alpha=alpha, - target=True, - report_ci=False, - ) - - log_writer.step( - epoch, - i, - metrics_train_source, - metrics_valid_source, - len(train_source_loader), - ) - - logger.info( - f"{self.config.data.mode} level training loss for source data is {metrics_train_source['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for source data is {metrics_valid_source['loss']} " - f"at the end of iteration {i}" - ) - - _, metrics_train_target = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_target_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - _, metrics_valid_target = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - - model.train() - train_source_loader.dataset.train() - train_target_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_target, - metrics_valid_target, - len(train_target_loader), - "training_target.tsv", - ) - - logger.info( - f"{self.config.data.mode} level training loss for target data is {metrics_train_target['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for target data is {metrics_valid_target['loss']} " - f"at the end of iteration {i}" - ) - - # Save checkpoints and best models - best_dict = retain_best.step(metrics_valid_target) - self.maps_manager._write_weights( - { - "model": model.state_dict(), - "epoch": epoch, - "name": self.config.model.architecture, - }, - best_dict, - split, - network=network, - save_all_models=False, - ) - self.maps_manager._write_weights( - { - "optimizer": optimizer.state_dict(), # TO MODIFY - "epoch": epoch, - "name": self.config.optimizer, - }, - None, - split, - filename="optimizer.pth.tar", - save_all_models=False, - ) - - epoch += 1 - - self.validator._test_loader_ssda( - self.maps_manager, - train_target_loader, - criterion, - data_group="train", - split=split, - selection_metrics=self.config.validation.selection_metrics, - network=network, - target=True, - alpha=0, - ) - self.validator._test_loader_ssda( - self.maps_manager, - valid_loader, - criterion, - data_group="validation", - split=split, - selection_metrics=self.config.validation.selection_metrics, - network=network, - target=True, - alpha=0, - ) - - if save_outputs(self.maps_manager.network_task): - self.validator._compute_output_tensors( - self.maps_manager, - train_target_loader.dataset, - "train", - split, - self.config.validation.selection_metrics, - nb_images=1, - network=network, - ) - self.validator._compute_output_tensors( - self.maps_manager, - train_target_loader.dataset, - "validation", - split, - self.config.validation.selection_metrics, - nb_images=1, - network=network, - ) - def _init_callbacks(self) -> None: """ Initializes training callbacks. diff --git a/clinicadl/utils/cli_param/option.py b/clinicadl/utils/cli_param/option.py index 6ff86cda2..75438ceda 100644 --- a/clinicadl/utils/cli_param/option.py +++ b/clinicadl/utils/cli_param/option.py @@ -58,13 +58,6 @@ multiple=True, default=None, ) -ssda_network = click.option( - "--ssda_network", - type=bool, - default=False, - show_default=True, - help="ssda training.", -) valid_longitudinal = click.option( "--valid_longitudinal/--valid_baseline", type=bool, diff --git a/tests/unittests/nn/networks/test_ssda.py b/tests/unittests/nn/networks/test_ssda.py deleted file mode 100644 index 06da85ff2..000000000 --- a/tests/unittests/nn/networks/test_ssda.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from clinicadl.nn.networks.ssda import Conv5_FC3_SSDA - - -def test_UNet(): - input_ = torch.randn(2, 1, 64, 63, 62) - network = Conv5_FC3_SSDA(input_size=(1, 64, 63, 62), output_size=3) - output = network(input_) - for out in output: - assert out.shape == torch.Size((2, 3)) diff --git a/tests/unittests/train/test_utils.py b/tests/unittests/train/test_utils.py index 6b33787eb..2914d2d9b 100644 --- a/tests/unittests/train/test_utils.py +++ b/tests/unittests/train/test_utils.py @@ -7,7 +7,6 @@ expected_classification = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, @@ -65,7 +64,6 @@ expected_regression = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, @@ -121,7 +119,6 @@ expected_reconstruction = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index b25dc20bb..c6b130cb8 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -5,7 +5,6 @@ from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config.ssda import SSDAConfig from clinicadl.network.config import NetworkConfig from clinicadl.predictor.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig @@ -70,31 +69,6 @@ def test_model_config(): ) -def test_ssda_config(caps_example): - preprocessing_json_target = ( - caps_example / "tensor_extraction" / "preprocessing.json" - ) - c = SSDAConfig( - ssda_network=True, - preprocessing_json_target=preprocessing_json_target, - ) - expected_preprocessing_dict = { - "preprocessing": "t1-linear", - "mode": "image", - "use_uncropped_image": False, - "prepare_dl": False, - "extract_json": "t1-linear_mode-image.json", - "file_type": { - "pattern": "*space-MNI152NLin2009cSym_desc-Crop_res-1x1x1_T1w.nii.gz", - "description": "T1W Image registered using t1-linear and cropped (matrix size 169\u00d7208\u00d7179, 1 mm isotropic voxels)", - "needed_pipeline": "t1-linear", - }, - } - assert c.preprocessing_dict_target == expected_preprocessing_dict - c = SSDAConfig() - assert c.preprocessing_dict_target == {} - - def test_transferlearning_config(): c = TransferLearningConfig(transfer_path=False) assert c.transfer_path is None