diff --git a/clinicadl/caps_dataset/data.py b/clinicadl/caps_dataset/data.py index 638f49e9d..78553bcf9 100644 --- a/clinicadl/caps_dataset/data.py +++ b/clinicadl/caps_dataset/data.py @@ -580,11 +580,11 @@ def _get_mask_paths_and_tensors( else: for template_ in Template: if preprocessing_.name == template_.name: - template_name = template_ + template_name = template_.value for pattern_ in Pattern: if preprocessing_.name == pattern_.name: - pattern = pattern_ + pattern = pattern_.value mask_location = caps_directory / "masks" / f"tpl-{template_name}" diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py index dc28d0acd..b7e65234f 100644 --- a/clinicadl/trainer/tasks_utils.py +++ b/clinicadl/trainer/tasks_utils.py @@ -1,31 +1,20 @@ -from abc import abstractmethod from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union import numpy as np import pandas as pd import torch -import torch.distributed as dist -from pydantic import ( - BaseModel, - ConfigDict, - computed_field, - model_validator, -) from torch import Tensor, nn -from torch.amp import autocast from torch.nn.functional import softmax -from torch.nn.modules.loss import _Loss -from torch.utils.data import DataLoader, Sampler, sampler +from torch.utils.data import Sampler, sampler from torch.utils.data.distributed import DistributedSampler from clinicadl.caps_dataset.data import CapsDataset from clinicadl.metrics.metric_module import MetricModule -from clinicadl.network.network import Network from clinicadl.trainer.config.train import TrainConfig -from clinicadl.utils import cluster from clinicadl.utils.enum import ( ClassificationLoss, ClassificationMetric, + Mode, ReconstructionLoss, ReconstructionMetric, RegressionLoss, @@ -249,7 +238,7 @@ def save_outputs(network_task: Union[str, Task]): def generate_test_row( network_task: Union[str, Task], - mode: str, + mode: Mode, metrics_module, n_classes: int, idx: int, @@ -274,7 +263,7 @@ def generate_test_row( [ data["participant_id"][idx], data["session_id"][idx], - data[f"{mode}_id"][idx].item(), + data[f"{mode.value}_id"][idx].item(), data["label"][idx].item(), prediction, ] @@ -286,7 +275,7 @@ def generate_test_row( [ data["participant_id"][idx], data["session_id"][idx], - data[f"{mode}_id"][idx].item(), + data[f"{mode.value}_id"][idx].item(), data["label"][idx].item(), outputs[idx].item(), ] @@ -298,7 +287,7 @@ def generate_test_row( row = [ data["participant_id"][idx], data["session_id"][idx], - data[f"{mode}_id"][idx].item(), + data[f"{mode.value}_id"][idx].item(), ] for metric in evaluation_metrics(Task.RECONSTRUCTION):