Skip to content

Commit

Permalink
Sortir les étapes de validation du MapsManager (#657)
Browse files Browse the repository at this point in the history
* first try to create a validator

* test to integrate validator config
  • Loading branch information
camillebrianceau authored Sep 26, 2024
1 parent 17ef236 commit 8a1589e
Show file tree
Hide file tree
Showing 7 changed files with 614 additions and 472 deletions.
17 changes: 17 additions & 0 deletions clinicadl/API_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData
from clinicadl.trainer.config.classification import ClassificationConfig
from clinicadl.trainer.trainer import Trainer
from clinicadl.utils.enum import ExtractionMethod, Preprocessing, Task
from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options

image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
extraction=ExtractionMethod.IMAGE,
preprocessing_type=Preprocessing.T1_LINEAR,
)

DeepLearningPrepareData(image_config)

config = ClassificationConfig()
trainer = Trainer(config)
trainer.train(split_list=config.cross_validation.split, overwrite=True)
273 changes: 0 additions & 273 deletions clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import pandas as pd
import torch
import torch.distributed as dist
from torch.amp import autocast

from clinicadl.caps_dataset.caps_dataset_utils import read_json
from clinicadl.caps_dataset.data import (
Expand All @@ -17,16 +15,13 @@
from clinicadl.metrics.metric_module import MetricModule
from clinicadl.metrics.utils import (
check_selection_metric,
find_selection_metrics,
)
from clinicadl.predict.utils import get_prediction
from clinicadl.trainer.tasks_utils import (
ensemble_prediction,
evaluation_metrics,
generate_label_code,
output_size,
test,
test_da,
)
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils import cluster
Expand Down Expand Up @@ -149,274 +144,6 @@ def __getattr__(self, name):
###################################
# High-level functions templates #
###################################
def _test_loader(
self,
dataloader,
criterion,
data_group: str,
split: int,
selection_metrics,
use_labels=True,
gpu=None,
amp=False,
network=None,
report_ci=True,
):
"""
Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files.
Args:
dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset.
criterion (torch.nn.modules.loss._Loss): optimization criterion used during training.
data_group (str): name of the data group used for the testing task.
split (int): Index of the split used to train the model tested.
selection_metrics (list[str]): List of metrics used to select the best models which are tested.
use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed.
gpu (bool): If given, a new value for the device of the model will be computed.
amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage).
network (int): Index of the network tested (only used in multi-network setting).
"""
for selection_metric in selection_metrics:
if cluster.master:
log_dir = (
self.maps_path
/ f"{self.split_name}-{split}"
/ f"best-{selection_metric}"
/ data_group
)
self.write_description_log(
log_dir,
data_group,
dataloader.dataset.config.data.caps_dict,
dataloader.dataset.config.data.data_df,
)

# load the best trained model during the training
model, _ = self._init_model(
transfer_path=self.maps_path,
split=split,
transfer_selection=selection_metric,
gpu=gpu,
network=network,
)
model = DDP(model, fsdp=self.fully_sharded_data_parallel, amp=self.amp)

prediction_df, metrics = test(
mode=self.mode,
metrics_module=self.metrics_module,
n_classes=self.n_classes,
network_task=self.network_task,
model=model,
dataloader=dataloader,
criterion=criterion,
use_labels=use_labels,
amp=amp,
report_ci=report_ci,
)
if use_labels:
if network is not None:
metrics[f"{self.mode}_id"] = network

loss_to_log = (
metrics["Metric_values"][-1] if report_ci else metrics["loss"]
)

logger.info(
f"{self.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}"
)

if cluster.master:
# Replace here
self._mode_level_to_tsv(
prediction_df,
metrics,
split,
selection_metric,
data_group=data_group,
)

def _test_loader_ssda(
self,
dataloader,
criterion,
alpha,
data_group,
split,
selection_metrics,
use_labels=True,
gpu=None,
network=None,
target=False,
report_ci=True,
):
"""
Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files.
Args:
dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset.
criterion (torch.nn.modules.loss._Loss): optimization criterion used during training.
data_group (str): name of the data group used for the testing task.
split (int): Index of the split used to train the model tested.
selection_metrics (list[str]): List of metrics used to select the best models which are tested.
use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed.
gpu (bool): If given, a new value for the device of the model will be computed.
network (int): Index of the network tested (only used in multi-network setting).
"""
for selection_metric in selection_metrics:
log_dir = (
self.maps_path
/ f"{self.split_name}-{split}"
/ f"best-{selection_metric}"
/ data_group
)
self.write_description_log(
log_dir,
data_group,
dataloader.dataset.caps_dict,
dataloader.dataset.df,
)

# load the best trained model during the training
model, _ = self._init_model(
transfer_path=self.maps_path,
split=split,
transfer_selection=selection_metric,
gpu=gpu,
network=network,
)
prediction_df, metrics = test_da(
self.network_task,
model,
dataloader,
criterion,
target=target,
report_ci=report_ci,
)
if use_labels:
if network is not None:
metrics[f"{self.mode}_id"] = network

if report_ci:
loss_to_log = metrics["Metric_values"][-1]
else:
loss_to_log = metrics["loss"]

logger.info(
f"{self.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}"
)

# Replace here
self._mode_level_to_tsv(
prediction_df, metrics, split, selection_metric, data_group=data_group
)

@torch.no_grad()
def _compute_output_tensors(
self,
dataset,
data_group,
split,
selection_metrics,
nb_images=None,
gpu=None,
network=None,
):
"""
Compute the output tensors and saves them in the MAPS.
Args:
dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set.
data_group (str): name of the data group used for the task.
split (int): split number.
selection_metrics (list[str]): metrics used for model selection.
nb_images (int): number of full images to write. Default computes the outputs of the whole data set.
gpu (bool): If given, a new value for the device of the model will be computed.
network (int): Index of the network tested (only used in multi-network setting).
"""
for selection_metric in selection_metrics:
# load the best trained model during the training
model, _ = self._init_model(
transfer_path=self.maps_path,
split=split,
transfer_selection=selection_metric,
gpu=gpu,
network=network,
nb_unfrozen_layer=self.nb_unfrozen_layer,
)
model = DDP(model, fsdp=self.fully_sharded_data_parallel, amp=self.amp)
model.eval()

tensor_path = (
self.maps_path
/ f"{self.split_name}-{split}"
/ f"best-{selection_metric}"
/ data_group
/ "tensors"
)
if cluster.master:
tensor_path.mkdir(parents=True, exist_ok=True)
dist.barrier()

if nb_images is None: # Compute outputs for the whole data set
nb_modes = len(dataset)
else:
nb_modes = nb_images * dataset.elem_per_image

for i in [
*range(cluster.rank, nb_modes, cluster.world_size),
*range(int(nb_modes % cluster.world_size <= cluster.rank)),
]:
data = dataset[i]
image = data["image"]
x = image.unsqueeze(0).to(model.device)
with autocast("cuda", enabled=self.std_amp):
output = model(x)
output = output.squeeze(0).cpu().float()
participant_id = data["participant_id"]
session_id = data["session_id"]
mode_id = data[f"{self.mode}_id"]
input_filename = (
f"{participant_id}_{session_id}_{self.mode}-{mode_id}_input.pt"
)
output_filename = (
f"{participant_id}_{session_id}_{self.mode}-{mode_id}_output.pt"
)
torch.save(image, tensor_path / input_filename)
torch.save(output, tensor_path / output_filename)
logger.debug(f"File saved at {[input_filename, output_filename]}")

def _ensemble_prediction(
self,
data_group,
split,
selection_metrics,
use_labels=True,
skip_leak_check=False,
):
"""Computes the results on the image-level."""

if not selection_metrics:
selection_metrics = find_selection_metrics(
self.maps_path, self.split_name, split
)

for selection_metric in selection_metrics:
#####################
# Soft voting
if self.num_networks > 1 and not skip_leak_check:
self._ensemble_to_tsv(
split,
selection=selection_metric,
data_group=data_group,
use_labels=use_labels,
)
elif self.mode != "image" and not skip_leak_check:
self._mode_to_image_tsv(
split,
selection=selection_metric,
data_group=data_group,
use_labels=use_labels,
)

###############################
# Checks #
Expand Down
27 changes: 17 additions & 10 deletions clinicadl/predict/predict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ClinicaDLDataLeakageError,
MAPSError,
)
from clinicadl.validator.validator import Validator

logger = getLogger("clinicadl.predict_manager")
level_list: List[str] = ["warning", "info", "debug"]
Expand All @@ -38,6 +39,7 @@ class PredictManager:
def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None:
self.maps_manager = MapsManager(_config.maps_dir)
self._config = _config
self.validator = Validator()

def predict(
self,
Expand Down Expand Up @@ -183,7 +185,8 @@ def predict(
split_selection_metrics,
)
if cluster.master:
self.maps_manager._ensemble_prediction(
self.validator._ensemble_prediction(
self.maps_manager,
self._config.data_group,
split,
self._config.selection_metrics,
Expand Down Expand Up @@ -288,20 +291,22 @@ def _predict_multi(
if self._config.n_proc is not None
else self.maps_manager.n_proc,
)
self.maps_manager._test_loader(
test_loader,
criterion,
self._config.data_group,
split,
split_selection_metrics,
self.validator._test_loader(
maps_manager=self.maps_manager,
dataloader=test_loader,
criterion=criterion,
data_group=self._config.data_group,
split=split,
selection_metrics=split_selection_metrics,
use_labels=self._config.use_labels,
gpu=self._config.gpu,
amp=self._config.amp,
network=network,
)
if self._config.save_tensor:
logger.debug("Saving tensors")
self.maps_manager._compute_output_tensors(
self.validator._compute_output_tensors(
self.maps_manager,
data_test,
self._config.data_group,
split,
Expand Down Expand Up @@ -416,7 +421,8 @@ def _predict_single(
if self._config.n_proc is not None
else self.maps_manager.n_proc,
)
self.maps_manager._test_loader(
self.validator._test_loader(
self.maps_manager,
test_loader,
criterion,
self._config.data_group,
Expand All @@ -428,7 +434,8 @@ def _predict_single(
)
if self._config.save_tensor:
logger.debug("Saving tensors")
self.maps_manager._compute_output_tensors(
self.validator._compute_output_tensors(
self.maps_manager,
data_test,
self._config.data_group,
split,
Expand Down
Loading

0 comments on commit 8a1589e

Please sign in to comment.