From 7a25f6ba121bc4f8914921b2c5b5008dd28e7542 Mon Sep 17 00:00:00 2001 From: cbaakman Date: Thu, 21 Dec 2023 11:26:04 +0100 Subject: [PATCH] a fix for the output files of prediction --- deeprank/learn/NeuralNet.py | 5 +++-- test/test_learn.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deeprank/learn/NeuralNet.py b/deeprank/learn/NeuralNet.py index 8933bd3..dfca5ab 100644 --- a/deeprank/learn/NeuralNet.py +++ b/deeprank/learn/NeuralNet.py @@ -736,9 +736,10 @@ def _epoch(self, epoch_number, pass_name, data_loader, train_model): if targets is not None: target_values += targets.tolist() + else: + target_values += [-1] * outputs.shape[0] - if len(target_values) > 0: - self._metrics_output.process(pass_name, epoch_number, entry_names, output_values, target_values) + self._metrics_output.process(pass_name, epoch_number, entry_names, output_values, target_values) if count_data_entries > 0: epoch_loss = sum_of_losses / count_data_entries diff --git a/test/test_learn.py b/test/test_learn.py index e0a5d15..48b4326 100644 --- a/test/test_learn.py +++ b/test/test_learn.py @@ -68,11 +68,14 @@ def test_predict(): metrics_directory = os.path.join(work_dir_path, "runs") + output_exporter = OutputExporter(metrics_directory) + neural_net = NeuralNet(dataset, cnn_class, model_type='3d',task='class', pretrained_model="test/data/models/best_valid_model.pth.tar", - cuda=False, metrics_exporters=[OutputExporter(metrics_directory), - TensorboardBinaryClassificationExporter(metrics_directory)]) + cuda=False, metrics_exporters=[output_exporter]) neural_net.test() + + assert os.path.isfile(output_exporter.get_filename("test", 0)) finally: rmtree(work_dir_path)