Skip to content

Commit dc44ce9

Browse files
committed
CLI EVALUATION_COMPARISONS check fix
1 parent d8bb58b commit dc44ce9

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/diffupath/cli.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
import click
1111
from bio2bel.constants import get_global_connection
12-
1312
from diffupath.ltoo import ltoo_by_method
14-
1513
from diffupy.constants import EMOJI, RAW, CSV, JSON
1614
from diffupy.diffuse import diffuse as run_diffusion
1715
from diffupy.kernels import regularised_laplacian_kernel
@@ -206,7 +204,7 @@ def evaluate(
206204
data_path: Optional[str] = os.path.join(ROOT_RESULTS_DIR, 'data', 'input_mappings'),
207205
graph: Optional[str] = GRAPH_PATH,
208206
kernel: Optional[str] = KERNEL_PATH,
209-
output: Optional[str] =os.path.join(OUTPUT_DIR, 'evaluation_metrics.json'),
207+
output: Optional[str] = os.path.join(OUTPUT_DIR, 'evaluation_metrics.json'),
210208
iterations: Optional[int] = 100,
211209
):
212210
"""Evaluate a kernel/network on one of the three presented datasets.
@@ -219,12 +217,12 @@ def evaluate(
219217
:param iterations: Number of iterations of the Cross-Validation.
220218
221219
"""
222-
click.secho(f'{EMOJI} Loading network for random cross-validation... {EMOJI}')
220+
click.secho(f'{EMOJI} Loading network for validation... {EMOJI}')
223221

224222
graph = process_graph_from_file(graph)
225223
kernel = process_kernel_from_file(kernel)
226224

227-
click.secho(f'{EMOJI} Loading data for cross-validation... {EMOJI}')
225+
click.secho(f'{EMOJI} Loading data for validation... {EMOJI}')
228226

229227
mapping_path_dataset_1 = os.path.join(data_path, 'dataset_1_mapping_absolute_value_bp.json')
230228
dataset1_mapping_by_database_and_entity = from_json(mapping_path_dataset_1)
@@ -268,7 +266,7 @@ def evaluate(
268266
k=iterations
269267
)
270268

271-
if comparison == BY_METHOD:
269+
elif comparison == BY_METHOD:
272270
dataset1_mapping_all_labels = reduce_dict_two_dimensional(dataset1_mapping_by_database_and_entity)
273271
dataset2_mapping_all_labels = reduce_dict_two_dimensional(dataset2_mapping_by_database_and_entity)
274272
dataset3_mapping_all_labels = reduce_dict_two_dimensional(dataset3_mapping_by_database_and_entity)
@@ -282,21 +280,24 @@ def evaluate(
282280
dataset1_mapping_all_labels,
283281
graph,
284282
kernel,
285-
k=iterations)
283+
k=iterations
284+
)
286285

287286
click.secho(f'{EMOJI} Running cross_validation_by_method for Dataset 2... {EMOJI}')
288287
metrics['auroc']['Dataset 2'], metrics['auprc']['Dataset 2'] = cross_validation_by_method(
289288
dataset2_mapping_all_labels,
290289
graph,
291290
kernel,
292-
k=iterations)
291+
k=iterations
292+
)
293293

294294
click.secho(f'{EMOJI} Running cross_validation_by_method for Dataset 3... {EMOJI}')
295295
metrics['auroc']['Dataset 3'], metrics['auprc']['Dataset 3'] = cross_validation_by_method(
296296
dataset3_mapping_all_labels,
297297
graph,
298298
kernel,
299-
k=iterations)
299+
k=iterations
300+
)
300301

301302
elif comparison == BY_DB:
302303
dataset1_mapping_all_labels = reduce_dict_two_dimensional(dataset1_mapping_by_database_and_entity)

0 commit comments

Comments
 (0)