Skip to content

Commit

Permalink
caps directory optional
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 17, 2024
1 parent c5219f1 commit c2d4a23
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 10 deletions.
2 changes: 1 addition & 1 deletion clinicadl/caps_dataset/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DataConfig(BaseModel): # TODO : put in data module
that must be passed by the user.
"""

caps_directory: Path
caps_directory: Optional[Path] = None
baseline: bool = False
diagnoses: Tuple[str, ...] = ("AD", "CN")
data_df: Optional[pd.DataFrame] = None
Expand Down
6 changes: 1 addition & 5 deletions clinicadl/interpret/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, field_validator

from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig
from clinicadl.caps_dataset.data_config import DataConfig
from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp
from clinicadl.maps_manager.config import MapsManagerConfig as MapsManagerConfigBase
Expand All @@ -30,10 +30,6 @@ def check_output_saving_tensor(self, network_task: str) -> None:
)


class DataConfig(DataBaseConfig):
caps_directory: Optional[Path] = None


class InterpretBaseConfig(BaseModel):
name: str
method: InterpretationMethod = InterpretationMethod.GRADIENTS
Expand Down
1 change: 0 additions & 1 deletion clinicadl/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ def interpret(self):
group_df = self._config.data.create_groupe_df()
self._check_data_group(group_df)

assert self._config.split
for split in self.splitter.split_iterator():
logger.info(f"Interpretation of split {split}")
df_group, parameters_group = self.get_group_info(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ def run_interpret(cnn_input, tmp_out_dir, ref_dir):
"name": f"test-{method}",
"method_cls": method,
}
options = merge_options_and_maps_json_options(maps_path / "maps.json", **dict_)
interpret_config = InterpretConfig(**options)
# options = merge_options_and_maps_json_options(maps_path / "maps.json", **dict_)
interpret_config = InterpretConfig(**dict_)

interpret_manager = Predictor(interpret_config)
interpret_manager.interpret()
interpret_map = interpret_manager.get_interpretation(
"train", f"test-{interpret_config.method}"
"train", f"test-{interpret_config.interpret.method}"
)

0 comments on commit c2d4a23

Please sign in to comment.