Skip to content

Commit

Permalink
Clean MapsManager (#658)
Browse files Browse the repository at this point in the history
* create the get_dataloader function

* remove the train multi
  • Loading branch information
camillebrianceau authored Oct 1, 2024
1 parent 8a1589e commit bf67c1c
Show file tree
Hide file tree
Showing 4 changed files with 809 additions and 224 deletions.
59 changes: 56 additions & 3 deletions clinicadl/API_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pathlib import Path

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.caps_dataset.data import return_dataset
from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData
from clinicadl.trainer.config.classification import ClassificationConfig
from clinicadl.trainer.trainer import Trainer
Expand All @@ -12,6 +15,56 @@

DeepLearningPrepareData(image_config)

config = ClassificationConfig()
trainer = Trainer(config)
trainer.train(split_list=config.cross_validation.split, overwrite=True)
dataset = return_dataset(
input_dir,
data_df,
preprocessing_dict,
transforms_config,
label,
label_code,
cnn_index,
label_presence,
multi_cohort,
)

split_config = SplitConfig()
splitter = Splitter(split_config)

validator_config = ValidatorConfig()
validator = Validator(validator_config)

train_config = ClassificationConfig()
trainer = Trainer(train_config, validator)

for split in splitter.split_iterator():
for network in range(
first_network, self.maps_manager.num_networks
): # for multi_network
###### actual _train_single method of the Trainer ############
train_loader = trainer.get_dataloader(dataset, split, network, "train", config)
valid_loader = validator.get_dataloader(
dataset, split, network, "valid", config
) # ?? validatior, trainer ?

trainer._train(
train_loader,
valid_loader,
split=split,
network=network,
resume=resume, # in a config class
callbacks=[CodeCarbonTracker], # in a config class ?
)

validator._ensemble_prediction(
self.maps_manager,
"train",
split,
self.config.validation.selection_metrics,
)
validator._ensemble_prediction(
self.maps_manager,
"validation",
split,
self.config.validation.selection_metrics,
)
###### end ############
Loading

0 comments on commit bf67c1c

Please sign in to comment.