Skip to content

Commit

Permalink
Remove TaskManager (#648)
Browse files – browse the repository at this point in the history
  • Loading branch information
camillebrianceau authored Sep 11, 2024
1 parent e4f6f80 commit 574f7a0
Show file tree
Hide file tree
Showing 20 changed files with 1,175 additions and 1,114 deletions.
20 changes: 12 additions & 8 deletions clinicadl/caps_dataset/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _get_full_image(self) -> torch.Tensor:

try:
image_path = self._get_image_path(participant_id, session_id, cohort)
image = torch.load(image_path)
image = torch.load(image_path, weights_only=True)
except IndexError:
file_type = self.config.extraction.file_type
results = clinicadl_file_reader(
Expand Down Expand Up @@ -316,7 +316,7 @@ def __getitem__(self, idx):
participant, session, cohort, _, label, domain = self._get_meta_data(idx)

image_path = self._get_image_path(participant, session, cohort)
image = torch.load(image_path)
image = torch.load(image_path, weights_only=True)

train_trf, trf = self.config.transforms.get_transforms()

Expand Down Expand Up @@ -385,10 +385,12 @@ def __getitem__(self, idx):
self.config.extraction.stride_size,
patch_idx,
)
patch_tensor = torch.load(Path(patch_dir).resolve() / patch_filename)
patch_tensor = torch.load(
Path(patch_dir).resolve() / patch_filename, weights_only=True
)

else:
image = torch.load(image_path)
image = torch.load(image_path, weights_only=True)
patch_tensor = extract_patch_tensor(
image,
self.config.extraction.patch_size,
Expand Down Expand Up @@ -504,10 +506,10 @@ def __getitem__(self, idx):
roi_filename = extract_roi_path(
image_path, mask_path, self.config.extraction.roi_uncrop_output
)
roi_tensor = torch.load(Path(roi_dir) / roi_filename)
roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True)

else:
image = torch.load(image_path)
image = torch.load(image_path, weights_only=True)
mask_array = self.mask_arrays[roi_idx]
roi_tensor = extract_roi_tensor(
image, mask_array, self.config.extraction.uncropped_roi
Expand Down Expand Up @@ -653,11 +655,13 @@ def __getitem__(self, idx):
self.config.extraction.slice_mode,
slice_idx,
)
slice_tensor = torch.load(Path(slice_dir) / slice_filename)
slice_tensor = torch.load(
Path(slice_dir) / slice_filename, weights_only=True
)

else:
image_path = self._get_image_path(participant, session, cohort)
image = torch.load(image_path)
image = torch.load(image_path, weights_only=True)
slice_tensor = extract_slice_tensor(
image,
self.config.extraction.slice_direction,
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/commandline/pipelines/train/resume/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from clinicadl.commandline.modules_options import (
cross_validation,
)
from clinicadl.trainer import Trainer
from clinicadl.trainer.trainer import Trainer


@click.command(name="resume", no_args_is_help=True)
Expand Down
6 changes: 3 additions & 3 deletions clinicadl/interpret/gradients.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc

import torch
from torch.cuda.amp import autocast
from torch.amp import autocast

from clinicadl.utils.exceptions import ClinicaDLArgumentError

Expand All @@ -28,7 +28,7 @@ def generate_gradients(
# Forward
input_batch = input_batch.to(self.device)
input_batch.requires_grad = True
with autocast(enabled=amp):
with autocast("cuda", enabled=amp):
if hasattr(self.model, "variational") and self.model.variational:
_, _, _, model_output = self.model(input_batch)
else:
Expand Down Expand Up @@ -94,7 +94,7 @@ def generate_gradients(
# Get last conv feature map
feature_maps = conv_part(input_batch).detach()
feature_maps.requires_grad = True
with autocast(enabled=amp):
with autocast("cuda", enabled=amp):
model_output = fc_part(pre_fc_part(feature_maps))
# Target for backprop
one_hot_output = torch.zeros_like(model_output)
Expand Down
163 changes: 120 additions & 43 deletions clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,26 @@
import pandas as pd
import torch
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch.amp import autocast

from clinicadl.caps_dataset.caps_dataset_utils import read_json
from clinicadl.caps_dataset.data import (
return_dataset,
)
from clinicadl.metrics.metric_module import MetricModule
from clinicadl.metrics.utils import (
check_selection_metric,
find_selection_metrics,
)
from clinicadl.predict.utils import get_prediction
from clinicadl.trainer.tasks_utils import (
ensemble_prediction,
evaluation_metrics,
generate_label_code,
output_size,
test,
test_da,
)
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils import cluster
from clinicadl.utils.computational.ddp import DDP, init_ddp
Expand All @@ -43,7 +52,7 @@ class MapsManager:
def __init__(
self,
maps_path: Path,
parameters: Dict[str, Any] = None,
parameters: Optional[Dict[str, Any]] = None,
verbose: str = "info",
):
"""
Expand All @@ -68,8 +77,38 @@ def __init__(
)
test_parameters = self.get_parameters()
# test_parameters = path_decoder(test_parameters)
# from clinicadl.trainer.task_manager import TaskConfig

self.parameters = add_default_values(test_parameters)
self.task_manager = self._init_task_manager(n_classes=self.output_size)

## to initialize the task parameters

# self.task_manager = self._init_task_manager()

self.n_classes = self.output_size
if self.network_task == "classification":
if self.n_classes is None:
self.n_classes = output_size(
self.network_task, None, self.df, self.label
)
self.metrics_module = MetricModule(
evaluation_metrics(self.network_task), n_classes=self.n_classes
)

elif (
self.network_task == "regression"
or self.network_task == "reconstruction"
):
self.metrics_module = MetricModule(
evaluation_metrics(self.network_task), n_classes=None
)

else:
raise NotImplementedError(
f"Task {self.network_task} is not implemented in ClinicaDL. "
f"Please choose between classification, regression and reconstruction."
)

self.split_name = (
self._check_split_wording()
) # Used only for retro-compatibility
Expand Down Expand Up @@ -162,10 +201,14 @@ def _test_loader(
)
model = DDP(model, fsdp=self.fully_sharded_data_parallel, amp=self.amp)

prediction_df, metrics = self.task_manager.test(
model,
dataloader,
criterion,
prediction_df, metrics = test(
mode=self.mode,
metrics_module=self.metrics_module,
n_classes=self.n_classes,
network_task=self.network_task,
model=model,
dataloader=dataloader,
criterion=criterion,
use_labels=use_labels,
amp=amp,
report_ci=report_ci,
Expand Down Expand Up @@ -241,8 +284,13 @@ def _test_loader_ssda(
gpu=gpu,
network=network,
)
prediction_df, metrics = self.task_manager.test_da(
model, dataloader, criterion, target=target, report_ci=report_ci
prediction_df, metrics = test_da(
self.network_task,
model,
dataloader,
criterion,
target=target,
report_ci=report_ci,
)
if use_labels:
if network is not None:
Expand Down Expand Up @@ -321,7 +369,7 @@ def _compute_output_tensors(
data = dataset[i]
image = data["image"]
x = image.unsqueeze(0).to(model.device)
with autocast(enabled=self.std_amp):
with autocast("cuda", enabled=self.std_amp):
output = model(x)
output = output.squeeze(0).cpu().float()
participant_id = data["participant_id"]
Expand Down Expand Up @@ -404,19 +452,42 @@ def _check_args(self, parameters):
if "label" not in self.parameters:
self.parameters["label"] = None

self.task_manager = self._init_task_manager(df=train_df)
from clinicadl.trainer.tasks_utils import (
get_default_network,
)
from clinicadl.utils.enum import Task

self.network_task = Task(self.parameters["network_task"])
# self.task_config = TaskConfig(self.network_task, self.mode, df=train_df)
# self.task_manager = self._init_task_manager(df=train_df)
if self.network_task == "classification":
self.n_classes = output_size(self.network_task, None, train_df, self.label)
self.metrics_module = MetricModule(
evaluation_metrics(self.network_task), n_classes=self.n_classes
)

elif self.network_task == "regression" or self.network_task == "reconstruction":
self.n_classes = None
self.metrics_module = MetricModule(
evaluation_metrics(self.network_task), n_classes=None
)

else:
raise NotImplementedError(
f"Task {self.network_task} is not implemented in ClinicaDL. "
f"Please choose between classification, regression and reconstruction."
)
if self.parameters["architecture"] == "default":
self.parameters["architecture"] = self.task_manager.get_default_network()
self.parameters["architecture"] = get_default_network(self.network_task)
if "selection_threshold" not in self.parameters:
self.parameters["selection_threshold"] = None
if (
"label_code" not in self.parameters
or len(self.parameters["label_code"]) == 0
or self.parameters["label_code"] is None
): # Allows to set custom label code in TOML
self.parameters["label_code"] = self.task_manager.generate_label_code(
train_df, self.label
self.parameters["label_code"] = generate_label_code(
self.network_task, train_df, self.label
)

full_dataset = return_dataset(
Expand All @@ -431,8 +502,8 @@ def _check_args(self, parameters):
self.parameters.update(
{
"num_networks": full_dataset.elem_per_image,
"output_size": self.task_manager.output_size(
full_dataset.size, full_dataset.df, self.label
"output_size": output_size(
self.network_task, full_dataset.size, full_dataset.df, self.label
),
"input_size": full_dataset.size,
}
Expand All @@ -444,7 +515,7 @@ def _check_args(self, parameters):
f"framework with only {self.parameters['num_networks']} element "
f"per image."
)
possible_selection_metrics_set = set(self.task_manager.evaluation_metrics) | {
possible_selection_metrics_set = set(evaluation_metrics(self.network_task)) | {
"loss"
}
if not set(self.parameters["selection_metrics"]).issubset(
Expand Down Expand Up @@ -708,7 +779,11 @@ def _ensemble_to_tsv(

performance_dir.mkdir(parents=True, exist_ok=True)

df_final, metrics = self.task_manager.ensemble_prediction(
df_final, metrics = ensemble_prediction(
self.mode,
self.metrics_module,
self.n_classes,
self.network_task,
test_df,
validation_df,
selection_threshold=self.selection_threshold,
Expand Down Expand Up @@ -839,7 +914,9 @@ def _init_model(
/ "tmp"
/ "checkpoint.pth.tar"
)
checkpoint_state = torch.load(checkpoint_path, map_location=device)
checkpoint_state = torch.load(
checkpoint_path, map_location=device, weights_only=True
)
model.load_state_dict(checkpoint_state["model"])
current_epoch = checkpoint_state["epoch"]
elif transfer_path:
Expand Down Expand Up @@ -912,29 +989,29 @@ def _init_split_manager_ssda(self, caps_dir, tsv_dir, split_list=None):

return split_class(**kwargs)

def _init_task_manager(
    self, df: Optional[pd.DataFrame] = None, n_classes: Optional[int] = None
):
    """Instantiate the task manager that corresponds to ``self.network_task``.

    Parameters
    ----------
    df : Optional[pd.DataFrame]
        Forwarded to ``ClassificationManager`` (together with ``self.label``)
        when ``n_classes`` is not given.
    n_classes : Optional[int]
        When provided, forwarded to ``ClassificationManager`` instead of
        ``df`` / ``self.label``.

    Returns
    -------
    A ``ClassificationManager``, ``RegressionManager`` or
    ``ReconstructionManager`` built with ``self.mode``.

    Raises
    ------
    NotImplementedError
        If ``self.network_task`` is not one of "classification",
        "regression" or "reconstruction".
    """
    # Imported locally, as in the original implementation.
    from clinicadl.utils.task_manager import (
        ClassificationManager,
        ReconstructionManager,
        RegressionManager,
    )

    task = self.network_task

    if task == "classification":
        # An explicit number of classes takes precedence over df / label.
        if n_classes is None:
            return ClassificationManager(self.mode, df=df, label=self.label)
        return ClassificationManager(self.mode, n_classes=n_classes)

    if task == "regression":
        return RegressionManager(self.mode)

    if task == "reconstruction":
        return ReconstructionManager(self.mode)

    raise NotImplementedError(
        f"Task {self.network_task} is not implemented in ClinicaDL. "
        f"Please choose between classification, regression and reconstruction."
    )
# def _init_task_manager(
# self, df: Optional[pd.DataFrame] = None, n_classes: Optional[int] = None
# ):
# from clinicadl.utils.task_manager import (
# ClassificationManager,
# ReconstructionManager,
# RegressionManager,
# )

# if self.network_task == "classification":
# if n_classes is not None:
# return ClassificationManager(self.mode, n_classes=n_classes)
# else:
# return ClassificationManager(self.mode, df=df, label=self.label)
# elif self.network_task == "regression":
# return RegressionManager(self.mode)
# elif self.network_task == "reconstruction":
# return ReconstructionManager(self.mode)
# else:
# raise NotImplementedError(
# f"Task {self.network_task} is not implemented in ClinicaDL. "
# f"Please choose between classification, regression and reconstruction."
# )

###############################
# Getters #
Expand Down Expand Up @@ -1054,7 +1131,7 @@ def get_state_dict(
f"selected according to best validation {selection_metric} "
f"at path {model_path}."
)
return torch.load(model_path, map_location=map_location)
return torch.load(model_path, map_location=map_location, weights_only=True)

@property
def std_amp(self) -> bool:
Expand Down
1 change: 1 addition & 0 deletions clinicadl/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List

import numpy as np
from sklearn.utils import resample

metric_optimum = {
"MAE": "min",
Expand Down
Loading

0 comments on commit 574f7a0

Please sign in to comment.