Skip to content

Commit

Permalink
Merge pull request #427 from datamol-io/metric_fix
Browse files Browse the repository at this point in the history
Metric fix: auroc avpr and metric averaging
  • Loading branch information
DomInvivo authored Aug 2, 2023
2 parents 6006da5 + 7903fb0 commit 2e3d0bf
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
4 changes: 2 additions & 2 deletions expts/neurips2023_configs/base_config/large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ predictor:
metrics:
l1000_vcap: &classif_metrics
- name: auroc
metric: auroc_ipu
metric: auroc
num_classes: 5
task: multiclass
target_to_int: True
Expand All @@ -370,7 +370,7 @@ metrics:
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
metric: averageprecision
num_classes: 5
task: multiclass
target_to_int: True
Expand Down
13 changes: 9 additions & 4 deletions expts/neurips2023_configs/debug/config_large_gcn_debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,16 @@ predictor:
metrics:
l1000_vcap: &classif_metrics
- name: auroc
metric: auroc_ipu
metric: auroc
num_classes: 5
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
metric: averageprecision
num_classes: 5
task: multiclass
target_to_int: True
Expand All @@ -377,14 +380,16 @@ metrics:
l1000_mcf7: *classif_metrics
pcba_1328:
- name: auroc
metric: auroc_ipu
metric: auroc
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
metric: averageprecision
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
pcqm4m_g25: &pcqm_metrics
- name: mae
Expand Down
7 changes: 7 additions & 0 deletions graphium/trainer/predictor_summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,13 @@ def get_metrics_logs(
task_metrics_logs = {}
for task in self.tasks:
task_metrics_logs[task] = self.task_summaries[task].get_metrics_logs()
# average metrics
for key in task_metrics_logs[task]:
if isinstance(task_metrics_logs[task][key], torch.Tensor):
if task_metrics_logs[task][key].numel() > 1:
task_metrics_logs[task][key] = task_metrics_logs[task][key][
task_metrics_logs[task][key] != 0
].mean()

# Include global (weighted loss)
task_metrics_logs["_global"] = {}
Expand Down

0 comments on commit 2e3d0bf

Please sign in to comment.