
Commit e7ff3f1

FIX weighted loss issue (#94)
* Changed tests for losses and how weighted strategy is handled in the base trainer
* Addressed comments from francisco
* Fix training test
* Re-arranged tests and moved test_setup to pytest
* Reduced search space for dummy forward backward pass of backbones
* Fix typo
* ADD Doc string to loss function
1 parent dce6a5c commit e7ff3f1

28 files changed: +322 -193 lines changed

autoPyTorch/pipeline/components/training/losses.py

Lines changed: 51 additions & 7 deletions
@@ -1,3 +1,15 @@
+"""
+Loss functions available in autoPyTorch
+
+Classification:
+    CrossEntropyLoss: supports multiclass, binary output types
+    BCEWithLogitsLoss: supports binary output types
+    Default: CrossEntropyLoss
+Regression:
+    MSELoss: supports continuous output types
+    L1Loss: supports continuous output types
+    Default: MSELoss
+"""
 from typing import Any, Dict, Optional, Type
 
 from torch.nn.modules.loss import (
@@ -11,21 +23,30 @@
 from autoPyTorch.constants import BINARY, CLASSIFICATION_TASKS, CONTINUOUS, MULTICLASS, REGRESSION_TASKS, \
     STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES, TASK_TYPES_TO_STRING
 
+
 losses = dict(classification=dict(
     CrossEntropyLoss=dict(
-        module=CrossEntropyLoss, supported_output_type=MULTICLASS),
+        module=CrossEntropyLoss, supported_output_types=[MULTICLASS, BINARY]),
     BCEWithLogitsLoss=dict(
-        module=BCEWithLogitsLoss, supported_output_type=BINARY)),
+        module=BCEWithLogitsLoss, supported_output_types=[BINARY])),
     regression=dict(
         MSELoss=dict(
-            module=MSELoss, supported_output_type=CONTINUOUS),
+            module=MSELoss, supported_output_types=[CONTINUOUS]),
         L1Loss=dict(
-            module=L1Loss, supported_output_type=CONTINUOUS)))
+            module=L1Loss, supported_output_types=[CONTINUOUS])))
 
 default_losses = dict(classification=CrossEntropyLoss, regression=MSELoss)
 
 
 def get_default(task: int) -> Type[Loss]:
+    """
+    Utility function to get default loss for the task
+    Args:
+        task (int):
+
+    Returns:
+        Type[torch.nn.modules.loss._Loss]
+    """
     if task in CLASSIFICATION_TASKS:
         return default_losses['classification']
     elif task in REGRESSION_TASKS:
@@ -35,19 +56,42 @@ def get_default(task: int) -> Type[Loss]:
 
 
 def get_supported_losses(task: int, output_type: int) -> Dict[str, Type[Loss]]:
+    """
+    Utility function to get supported losses for a given task and output type
+    Args:
+        task (int): integer identifier for the task
+        output_type: integer identifier for the output type of the task
+
+    Returns:
+        Returns a dictionary containing the losses supported for the given
+        inputs. Key-Name, Value-Module
+    """
     supported_losses = dict()
     if task in CLASSIFICATION_TASKS:
         for key, value in losses['classification'].items():
-            if output_type == value['supported_output_type']:
+            if output_type in value['supported_output_types']:
                 supported_losses[key] = value['module']
     elif task in REGRESSION_TASKS:
         for key, value in losses['regression'].items():
-            if output_type == value['supported_output_type']:
+            if output_type in value['supported_output_types']:
                 supported_losses[key] = value['module']
     return supported_losses
 
 
-def get_loss_instance(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
+def get_loss(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
+    """
+    Utility function to get losses for the given dataset properties.
+    If name is mentioned, checks if the loss is compatible with
+    the dataset properties and returns the specific loss
+    Args:
+        dataset_properties (Dict[str, Any]): Dictionary containing
+            properties of the dataset. Must contain task_type and
+            output_type as strings.
+        name (Optional[str]): name of the specific loss
+
+    Returns:
+        Type[torch.nn.modules.loss._Loss]
+    """
     assert 'task_type' in dataset_properties, \
         "Expected dataset_properties to have task_type got {}".format(dataset_properties.keys())
     assert 'output_type' in dataset_properties, \
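
The renamed entry point returns a loss class rather than an instance; the trainer instantiates it later, possibly with weight kwargs. A minimal sketch of a call, using only names that appear in this diff (the fall-back to the task default when `name` is omitted is inferred from the docstring):

```python
from autoPyTorch.pipeline.components.training.losses import get_loss

# 'binary' output is now accepted by CrossEntropyLoss as well as BCEWithLogitsLoss,
# since supported_output_types is a list instead of a single value.
dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}

loss_class = get_loss(dataset_properties, name='BCEWithLogitsLoss')
criterion = loss_class()  # the trainer builds the instance, optionally with weight kwargs
```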

autoPyTorch/pipeline/components/training/trainer/base_trainer.py

Lines changed: 14 additions & 18 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 
@@ -10,6 +10,7 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.tensorboard.writer import SummaryWriter
 
+
 from autoPyTorch.constants import REGRESSION_TASKS
 from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
@@ -173,14 +174,13 @@ def prepare(
         self,
         metrics: List[Any],
         model: torch.nn.Module,
-        criterion: torch.nn.Module,
+        criterion: Type[torch.nn.Module],
         budget_tracker: BudgetTracker,
         optimizer: Optimizer,
         device: torch.device,
         metrics_during_training: bool,
         scheduler: _LRScheduler,
         task_type: int,
-        output_type: int,
         labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
     ) -> None:
 
@@ -191,19 +191,12 @@
         self.metrics = metrics
 
         # Weights for the loss function
-        weights = None
-        kwargs: Dict[str, Any] = {}
-        # if self.weighted_loss:
-        #     weights = self.get_class_weights(output_type, labels)
-        #     if output_type == BINARY:
-        #         kwargs['pos_weight'] = weights
-        #         pass
-        #     else:
-        #         kwargs['weight'] = weights
+        kwargs = {}
+        if self.weighted_loss:
+            kwargs = self.get_class_weights(criterion, labels)
 
         # Setup the loss function
-        self.criterion = criterion(**kwargs) if weights is not None else criterion()
-
+        self.criterion = criterion(**kwargs)
         # setup the model
         self.model = model.to(device)
 
@@ -384,13 +377,16 @@ def compute_metrics(self, outputs_data: np.ndarray, targets_data: np.ndarray
         targets_data = torch.cat(targets_data, dim=0).numpy()
         return calculate_score(targets_data, outputs_data, self.task_type, self.metrics)
 
-    def get_class_weights(self, output_type: int, labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
-                          ) -> np.ndarray:
-        strategy = get_loss_weight_strategy(output_type)
+    def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
+                          ) -> Dict[str, np.ndarray]:
+        strategy = get_loss_weight_strategy(criterion)
         weights = strategy(y=labels)
         weights = torch.from_numpy(weights)
         weights = weights.float().to(self.device)
-        return weights
+        if criterion.__name__ == 'BCEWithLogitsLoss':
+            return {'pos_weight': weights}
+        else:
+            return {'weight': weights}
 
     def data_preparation(self, X: np.ndarray, y: np.ndarray,
                          ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
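
get_class_weights now returns a kwargs dictionary instead of a bare tensor because the two classification losses take their weights through different constructor arguments. That part is standard PyTorch behaviour, illustrated here with hypothetical weight values:

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

# CrossEntropyLoss expects one weight per class, passed as `weight` ...
ce = CrossEntropyLoss(weight=torch.tensor([0.3, 0.7]))

# ... while BCEWithLogitsLoss re-weights only the positive class via `pos_weight`.
bce = BCEWithLogitsLoss(pos_weight=torch.tensor([2.5]))
```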

autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py

Lines changed: 4 additions & 5 deletions
@@ -19,14 +19,14 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.tensorboard.writer import SummaryWriter
 
-from autoPyTorch.constants import STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES
+from autoPyTorch.constants import STRING_TO_TASK_TYPES
 from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
 from autoPyTorch.pipeline.components.base_component import (
     ThirdPartyComponents,
     autoPyTorchComponent,
     find_components,
 )
-from autoPyTorch.pipeline.components.training.losses import get_loss_instance
+from autoPyTorch.pipeline.components.training.losses import get_loss
 from autoPyTorch.pipeline.components.training.metrics.utils import get_metrics
 from autoPyTorch.pipeline.components.training.trainer.base_trainer import (
     BaseTrainerComponent,
@@ -265,15 +265,14 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Modu
             model=X['network'],
             metrics=get_metrics(dataset_properties=X['dataset_properties'],
                                 names=additional_metrics),
-            criterion=get_loss_instance(X['dataset_properties'],
-                                        name=additional_losses),
+            criterion=get_loss(X['dataset_properties'],
+                               name=additional_losses),
             budget_tracker=budget_tracker,
             optimizer=X['optimizer'],
             device=get_device_from_fit_dictionary(X),
             metrics_during_training=X['metrics_during_training'],
             scheduler=X['lr_scheduler'],
             task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
-            output_type=STRING_TO_OUTPUT_TYPES[X['dataset_properties']['output_type']],
             labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
         )
         total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])

autoPyTorch/utils/implementations.py

Lines changed: 21 additions & 6 deletions
@@ -1,17 +1,24 @@
-from typing import Callable, Union
+from typing import Any, Callable, Dict, Type, Union
 
 import numpy as np
 
 import torch
 
-from autoPyTorch.constants import BINARY
 
-
-def get_loss_weight_strategy(output_type: int) -> Callable:
-    if output_type == BINARY:
+def get_loss_weight_strategy(loss: Type[torch.nn.Module]) -> Callable:
+    """
+    Utility function that returns strategy for the given loss
+    Args:
+        loss (Type[torch.nn.Module]): type of the loss function
+    Returns:
+        (Callable): Relevant Callable strategy
+    """
+    if loss.__name__ in LossWeightStrategyWeightedBinary.get_properties()['supported_losses']:
         return LossWeightStrategyWeightedBinary()
-    else:
+    elif loss.__name__ in LossWeightStrategyWeighted.get_properties()['supported_losses']:
         return LossWeightStrategyWeighted()
+    else:
+        raise ValueError("No strategy currently supports the given loss, {}".format(loss.__name__))
 
 
 class LossWeightStrategyWeighted():
@@ -34,6 +41,10 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
 
         return weights
 
+    @staticmethod
+    def get_properties() -> Dict[str, Any]:
+        return {'supported_losses': ['CrossEntropyLoss']}
+
 
 class LossWeightStrategyWeightedBinary():
     def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
@@ -46,3 +57,7 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
         weights = counts_zero / np.maximum(counts_one, 1)
 
         return np.array(weights)
+
+    @staticmethod
+    def get_properties() -> Dict[str, Any]:
+        return {'supported_losses': ['BCEWithLogitsLoss']}
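
A short sketch of how the new dispatch could be exercised. The body of LossWeightStrategyWeightedBinary.__call__ is only partially visible in this hunk, so the expected weight in the final comment is inferred from the visible counts_zero / counts_one line:

```python
import numpy as np
from torch.nn import BCEWithLogitsLoss

from autoPyTorch.utils.implementations import get_loss_weight_strategy

labels = np.array([0] * 15 + [1] * 5)  # imbalanced binary labels

# Dispatch is on the class name, matched against each strategy's
# get_properties()['supported_losses'] list.
strategy = get_loss_weight_strategy(BCEWithLogitsLoss)

weights = strategy(y=labels)  # presumably counts_zero / counts_one = 3.0
```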

test/conftest.py

Lines changed: 60 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 from sklearn.datasets import fetch_openml, make_classification, make_regression
 
+import torch
+
 from autoPyTorch.data.tabular_validator import TabularInputValidator
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.utils.backend import create
@@ -357,3 +359,61 @@ def error_search_space_updates():
                    value_range=[0, 0.5],
                    default_value=0.2)
     return updates
+
+
+@pytest.fixture
+def loss_cross_entropy_multiclass():
+    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'multiclass'}
+    predictions = torch.randn(4, 4, requires_grad=True)
+    name = 'CrossEntropyLoss'
+    targets = torch.empty(4, dtype=torch.long).random_(4)
+    # to ensure we have all classes in the labels
+    while True:
+        labels = torch.empty(20, dtype=torch.long).random_(4)
+        if len(torch.unique(labels)) == 4:
+            break
+
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_cross_entropy_binary():
+    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
+    predictions = torch.randn(4, 2, requires_grad=True)
+    name = 'CrossEntropyLoss'
+    targets = torch.empty(4, dtype=torch.long).random_(2)
+    # to ensure we have all classes in the labels
+    while True:
+        labels = torch.empty(20, dtype=torch.long).random_(2)
+        if len(torch.unique(labels)) == 2:
+            break
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_bce():
+    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
+    predictions = torch.empty(4).random_(2)
+    name = 'BCEWithLogitsLoss'
+    targets = torch.empty(4).random_(2)
+    # to ensure we have all classes in the labels
+    while True:
+        labels = torch.empty(20, dtype=torch.long).random_(2)
+        if len(torch.unique(labels)) == 2:
+            break
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_mse():
+    dataset_properties = {'task_type': 'tabular_regression', 'output_type': 'continuous'}
+    predictions = torch.randn(4)
+    name = 'MSELoss'
+    targets = torch.randn(4)
+    labels = None
+    return dataset_properties, predictions, name, targets, labels
+
+
+@pytest.fixture
+def loss_details(request):
+    return request.getfixturevalue(request.param)
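
loss_details is an indirection point: a test parametrizes it with the name of one of the concrete fixtures above, and pytest resolves it through request.getfixturevalue. A hypothetical consumer (the real tests live in the re-arranged loss test module, so the test name here is illustrative) might look like:

```python
import pytest

from autoPyTorch.pipeline.components.training.losses import get_loss


@pytest.mark.parametrize('loss_details', ['loss_cross_entropy_multiclass',
                                          'loss_cross_entropy_binary',
                                          'loss_bce',
                                          'loss_mse'], indirect=True)
def test_get_loss_by_name(loss_details):
    dataset_properties, predictions, name, targets, labels = loss_details
    loss_class = get_loss(dataset_properties, name=name)
    assert loss_class.__name__ == name
```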

test/test_pipeline/components/base.py

Lines changed: 0 additions & 1 deletion
@@ -97,7 +97,6 @@ def prepare_trainer(self,
             device=device,
             metrics_during_training=True,
             task_type=task_type,
-            output_type=output_type,
             labels=y
         )
         return trainer, model, optimizer, loader, criterion, epochs, logger

test/test_pipeline/components/preprocessing/__init__.py

Whitespace-only changes.

test/test_pipeline/components/setup/__init__.py

Whitespace-only changes.
