Skip to content

Commit 873df9a

Browse files
authored
[FIX] ROC AUC for multi class classification (#482)
* fixed cut mix * remove unnecessary comment * change all_supported_metrics * fix roc_auc for multiclass * remove unnecessary code
1 parent d29d11b commit 873df9a

File tree

4 files changed

+6
-7
lines changed

4 files changed

+6
-7
lines changed

autoPyTorch/pipeline/components/setup/network/base_network.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,14 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:
5656

5757
self.network = torch.nn.Sequential(X['network_embedding'], X['network_backbone'], X['network_head'])
5858

59+
if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS:
60+
self.network = torch.nn.Sequential(self.network, nn.Softmax(dim=1))
5961
# Properly set the network training device
6062
if self.device is None:
6163
self.device = get_device_from_fit_dictionary(X)
6264

6365
self.to(self.device)
6466

65-
if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS:
66-
self.final_activation = nn.Softmax(dim=1)
67-
6867
self.is_fitted_ = True
6968

7069
return self

autoPyTorch/pipeline/components/training/metrics/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def __call__(
173173
Score function applied to prediction of estimator on X.
174174
"""
175175
y_type = type_of_target(y_true)
176-
if y_type not in ("binary", "multilabel-indicator"):
176+
if y_type not in ("binary", "multilabel-indicator") and self.name != 'roc_auc':
177177
raise ValueError("{0} format is not supported".format(y_type))
178178

179179
if y_type == "binary":

autoPyTorch/pipeline/components/training/metrics/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858

5959
# Score functions that need decision values
60-
roc_auc = make_metric('roc_auc', sklearn.metrics.roc_auc_score, needs_threshold=True)
60+
roc_auc = make_metric('roc_auc', sklearn.metrics.roc_auc_score, needs_threshold=True, multi_class= 'ovo')
6161
average_precision = make_metric('average_precision',
6262
sklearn.metrics.average_precision_score,
6363
needs_threshold=True)

autoPyTorch/pipeline/components/training/metrics/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def get_metrics(dataset_properties: Dict[str, Any],
9999
if names is not None:
100100
for name in names:
101101
if name not in supported_metrics.keys():
102-
raise ValueError("Invalid name entered for task {}, currently "
103-
"supported metrics for task include {}".format(dataset_properties['task_type'],
102+
raise ValueError("Invalid name {} entered for task {}, currently "
103+
"supported metrics for task include {}".format(name, dataset_properties['task_type'],
104104
list(supported_metrics.keys())))
105105
else:
106106
metric = supported_metrics[name]

0 commit comments

Comments
 (0)