Skip to content

Commit

Permalink
Replace (TF)CommonTestCases for modeling with a mixin.
Browse files Browse the repository at this point in the history
I suspect the wrapper classes were created in order to prevent the
abstract base class (TF)CommonModelTester from being included in test
discovery and running, because that would fail.

I solved this by replacing the abstract base class with a mixin.

Code changes are just de-indenting and automatic reformattings
performed by black to use the extra line space.
  • Loading branch information
aaugustin committed Dec 22, 2019
1 parent 7e98e21 commit 345c23a
Show file tree
Hide file tree
Showing 26 changed files with 988 additions and 956 deletions.
6 changes: 4 additions & 2 deletions templates/adding_a_new_model/tests/test_modeling_tf_xxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function

import unittest

from transformers import XxxConfig, is_tf_available

from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow


Expand All @@ -32,7 +34,7 @@


@require_tf
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):

all_model_classes = (
(
Expand Down
6 changes: 4 additions & 2 deletions templates/adding_a_new_model/tests/test_modeling_xxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function

import unittest

from transformers import is_torch_available

from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device


Expand All @@ -34,7 +36,7 @@


@require_torch
class XxxModelTest(CommonTestCases.CommonModelTester):
class XxxModelTest(ModelTesterMixin, unittest.TestCase):

all_model_classes = (
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
Expand Down
6 changes: 4 additions & 2 deletions tests/test_modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function

import unittest

from transformers import is_torch_available

from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device


Expand All @@ -33,7 +35,7 @@


@require_torch
class AlbertModelTest(CommonTestCases.CommonModelTester):
class AlbertModelTest(ModelTesterMixin, unittest.TestCase):

all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()

Expand Down
6 changes: 4 additions & 2 deletions tests/test_modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function

import unittest

from transformers import is_torch_available

from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, floats_tensor, ids_tensor
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device


Expand All @@ -37,7 +39,7 @@


@require_torch
class BertModelTest(CommonTestCases.CommonModelTester):
class BertModelTest(ModelTesterMixin, unittest.TestCase):

all_model_classes = (
(
Expand Down
Loading

0 comments on commit 345c23a

Please sign in to comment.