[vision] Add problem_type support (huggingface#15851)
* Add problem_type to missing models

* Fix deit test

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
NielsRogge and Niels Rogge authored Mar 1, 2022
1 parent 7ff9d45 commit 286fdc6
Showing 6 changed files with 130 additions and 25 deletions.
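
For context (a note added here, not part of the commit page): after this change, the `problem_type` field on the model config selects the loss used by these vision classification heads, matching the behavior of the text classification heads. A minimal sketch of the new behavior, assuming a randomly initialized ViT — the model choice, shapes, and label values are illustrative only:

import torch
from transformers import ViTConfig, ViTForImageClassification

# Hypothetical multi-label setup: 3 independent binary attributes per image.
config = ViTConfig(num_labels=3, problem_type="multi_label_classification")
model = ViTForImageClassification(config)

pixel_values = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
labels = torch.tensor([[1.0, 0.0, 1.0],     # float multi-hot targets,
                       [0.0, 1.0, 1.0]])    # as BCEWithLogitsLoss expects
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.loss)  # loss computed with BCEWithLogitsLoss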
24 changes: 18 additions & 6 deletions src/transformers/models/beit/modeling_beit.py
@@ -22,7 +22,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -857,14 +857,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)

         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
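
The hunk above (repeated verbatim in the DeiT, SegFormer, and ViT heads below) also infers `problem_type` from `num_labels` and the label dtype when the config leaves it unset. Restated as a standalone helper for clarity — an illustration, not code from this commit:

import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the dispatch added in the diff above.
    if num_labels == 1:
        return "regression"
    elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    else:
        return "multi_label_classification"

assert infer_problem_type(1, torch.tensor([0.7])) == "regression"
assert infer_problem_type(3, torch.tensor([2])) == "single_label_classification"
assert infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])) == "multi_label_classification"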
24 changes: 18 additions & 6 deletions src/transformers/models/deit/modeling_deit.py
@@ -23,7 +23,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -726,14 +726,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)

         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
24 changes: 18 additions & 6 deletions src/transformers/models/segformer/modeling_segformer.py
@@ -21,7 +21,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -590,14 +590,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
24 changes: 18 additions & 6 deletions src/transformers/models/vit/modeling_vit.py
@@ -21,7 +21,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -747,14 +747,26 @@ def forward(

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)

         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
54 changes: 54 additions & 0 deletions tests/deit/test_modeling_deit.py
@@ -17,6 +17,7 @@

 import inspect
 import unittest
+import warnings

 from transformers import DeiTConfig
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
@@ -32,6 +33,8 @@
     from torch import nn

     from transformers import (
+        MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         MODEL_MAPPING,
         DeiTForImageClassification,
         DeiTForImageClassificationWithTeacher,
@@ -379,6 +382,57 @@ def test_training_gradient_checkpointing(self):
             loss = model(**inputs).loss
             loss.backward()

+    def test_problem_types(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        problem_types = [
+            {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
+            {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
+            {"title": "regression", "num_labels": 1, "dtype": torch.float},
+        ]
+
+        for model_class in self.all_model_classes:
+            if (
+                model_class
+                not in [
+                    *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
+                    *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+                ]
+                or model_class.__name__ == "DeiTForImageClassificationWithTeacher"
+            ):
+                continue
+
+            for problem_type in problem_types:
+                with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
+
+                    config.problem_type = problem_type["title"]
+                    config.num_labels = problem_type["num_labels"]
+
+                    model = model_class(config)
+                    model.to(torch_device)
+                    model.train()
+
+                    inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+                    if problem_type["num_labels"] > 1:
+                        inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
+
+                    inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
+
+                    # This tests that we do not trigger the warning from PyTorch "Using a target size that is
+                    # different to the input size. This will likely lead to incorrect results due to broadcasting.
+                    # Please ensure they have the same size.", which is a symptom that something is wrong with the
+                    # regression problem. See https://github.com/huggingface/transformers/issues/11780
+                    with warnings.catch_warnings(record=True) as warning_list:
+                        loss = model(**inputs).loss
+                    for w in warning_list:
+                        if "Using a target size that is different to the input size" in str(w.message):
+                            raise ValueError(
+                                f"Something is going wrong in the regression problem: intercepted {w.message}"
+                            )
+
+                    loss.backward()
+
     def test_for_image_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
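
To make the shapes concrete, the `problem_types` table in the test above drives three label layouts, roughly as follows (illustrative values; the batch size of 4 is an assumption, the test takes whatever the model tester prepares):

import torch

batch_size = 4
# multi_label_classification: num_labels=2, float multi-hot matrix -> BCEWithLogitsLoss
multi_label = torch.zeros(batch_size, 2, dtype=torch.float)
# single_label_classification: num_labels=1, integer class indices -> CrossEntropyLoss
single_label = torch.zeros(batch_size, dtype=torch.long)
# regression: num_labels=1, float scalar targets -> MSELoss
regression = torch.zeros(batch_size, dtype=torch.float)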
5 changes: 4 additions & 1 deletion tests/test_modeling_common.py
@@ -1873,7 +1873,10 @@ def test_problem_types(self):
         ]

         for model_class in self.all_model_classes:
-            if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
+            if model_class not in [
+                *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
+                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+            ]:
                 continue

             for problem_type in problem_types:
