
Commit 191f88a
WIP getting equivalence on pipelines
1 parent f3903bb

File tree: 4 files changed, +82 -103 lines changed

src/transformers/tokenization_roberta_fast.py

Lines changed: 3 additions & 3 deletions
@@ -190,10 +190,10 @@ def mask_token(self) -> str:

     @mask_token.setter
     def mask_token(self, value):
-        """ Overriding the default behavior of the mask token to have it eat the space before it.
+        """Overriding the default behavior of the mask token to have it eat the space before it.

-            This is needed to preserve backward compatibility with all the previously used models
-            based on Roberta.
+        This is needed to preserve backward compatibility with all the previously used models
+        based on Roberta.
         """
         # Mask token behave like a normal word, i.e. include the space before it
         # So we set lstrip to True
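For readers without the surrounding file: the behavior the docstring describes comes from registering the mask token with lstrip=True. A minimal sketch of the effect (the checkpoint name and the example sentence are illustrative assumptions, not part of this diff):

from transformers import RobertaTokenizerFast

tok = RobertaTokenizerFast.from_pretrained("roberta-base")
# "<mask>" is registered with lstrip=True, so the space preceding it is
# absorbed by the mask token itself rather than tokenized separately.
ids = tok("Paris is the <mask> of France.")["input_ids"]
assert tok.mask_token_id in ids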

tests/test_pipelines_common.py

Lines changed: 66 additions & 100 deletions
@@ -1,22 +1,24 @@
 import unittest
-from unittest import mock
 from typing import List, Optional
+from unittest import mock

 from transformers import is_tf_available, is_torch_available, pipeline
-from transformers.tokenization_utils_base import to_py_obj
 from transformers.pipelines import DefaultArgumentHandler, Pipeline
 from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
+from transformers.tokenization_utils_base import to_py_obj


 VALID_INPUTS = ["A simple string", ["list of strings"]]


-@is_pipeline_test
+# @is_pipeline_test
 class CustomInputPipelineCommonMixin:
     pipeline_task = None
-    pipeline_loading_kwargs = {}
-    small_models = None  # Models tested without the @slow decorator
-    large_models = None  # Models tested with the @slow decorator
+    pipeline_loading_kwargs = {}  # Additional kwargs to load the pipeline with
+    pipeline_running_kwargs = {}  # Additional kwargs to run the pipeline with
+    small_models = []  # Models tested without the @slow decorator
+    large_models = []  # Models tested with the @slow decorator
+    valid_inputs = VALID_INPUTS  # Some inputs which are valid to compare fast and slow tokenizers

     def setUp(self) -> None:
         if not is_tf_available() and not is_torch_available():
@@ -48,78 +50,41 @@ def setUp(self) -> None:
     @require_torch
     @slow
     def test_pt_defaults(self):
-        pipeline(self.pipeline_task, framework="pt")
+        pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)

     @require_tf
     @slow
     def test_tf_defaults(self):
-        pipeline(self.pipeline_task, framework="tf")
+        pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)

     @require_torch
     def test_torch_small(self):
         for model_name in self.small_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
+            nlp = pipeline(
+                task=self.pipeline_task,
+                model=model_name,
+                tokenizer=model_name,
+                framework="pt",
+                **self.pipeline_loading_kwargs,
+            )
             self._test_pipeline(nlp)

     @require_tf
     def test_tf_small(self):
         for model_name in self.small_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
+            nlp = pipeline(
+                task=self.pipeline_task,
+                model=model_name,
+                tokenizer=model_name,
+                framework="tf",
+                **self.pipeline_loading_kwargs,
+            )
             self._test_pipeline(nlp)

     @require_torch
     @slow
     def test_torch_large(self):
         for model_name in self.large_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
-            self._test_pipeline(nlp)
-
-    @require_tf
-    @slow
-    def test_tf_large(self):
-        for model_name in self.large_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
-            self._test_pipeline(nlp)
-
-    def _test_pipeline(self, nlp: Pipeline):
-        raise NotImplementedError
-
-
-# @is_pipeline_test
-class MonoInputPipelineCommonMixin:
-    pipeline_task = None
-    pipeline_loading_kwargs = {}  # Additional kwargs to load the pipeline with
-    pipeline_running_kwargs = {}  # Additional kwargs to run the pipeline with
-    small_models = []  # Models tested without the @slow decorator
-    large_models = []  # Models tested with the @slow decorator
-    mandatory_keys = {}  # Keys which should be in the output
-    valid_inputs = VALID_INPUTS  # inputs which are valid
-    invalid_inputs = [None]  # inputs which are not allowed
-    expected_multi_result: Optional[List] = None
-    expected_check_keys: Optional[List[str]] = None
-
-    def setUp(self) -> None:
-        if not is_tf_available() and not is_torch_available():
-            return  # Currently no JAX pipelines
-
-        for model_name in self.small_models:
-            pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs)
-        for model_name in self.large_models:
-            pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs)
-
-    @require_torch
-    @slow
-    def test_pt_defaults_loads(self):
-        pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)
-
-    @require_tf
-    @slow
-    def test_tf_defaults_loads(self):
-        pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)
-
-    @require_torch
-    def test_torch_small(self):
-        for model_name in self.small_models:
             nlp = pipeline(
                 task=self.pipeline_task,
                 model=model_name,
@@ -130,8 +95,9 @@ def test_torch_small(self):
             self._test_pipeline(nlp)

     @require_tf
-    def test_tf_small(self):
-        for model_name in self.small_models:
+    @slow
+    def test_tf_large(self):
+        for model_name in self.large_models:
             nlp = pipeline(
                 task=self.pipeline_task,
                 model=model_name,
@@ -141,6 +107,9 @@ def test_tf_small(self):
             )
             self._test_pipeline(nlp)

+    def _test_pipeline(self, nlp: Pipeline):
+        raise NotImplementedError
+
     @require_torch
     def test_compare_slow_fast_torch(self):
         for model_name in self.small_models:
@@ -160,7 +129,7 @@ def test_compare_slow_fast_torch(self):
                 use_fast=True,
                 **self.pipeline_loading_kwargs,
             )
-            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast)
+            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="forward")

     @require_tf
     def test_compare_slow_fast_tf(self):
@@ -181,54 +150,51 @@ def test_compare_slow_fast_tf(self):
                 use_fast=True,
                 **self.pipeline_loading_kwargs,
             )
-            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast)
-
-    def _compare_slow_fast_pipelines(self, nlp_slow: Pipeline, nlp_fast: Pipeline):
-        with mock.patch.object(nlp_slow.model, 'forward', wraps=nlp_slow.model.forward) as mock_slow,\
-            mock.patch.object(nlp_fast.model, 'forward', wraps=nlp_fast.model.forward) as mock_fast:
+            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="call")
+
+    def _compare_slow_fast_pipelines(self, nlp_slow: Pipeline, nlp_fast: Pipeline, method: str):
+        """We check that the inputs to the models forward passes are identical for
+        slow and fast tokenizers.
+        """
+        with mock.patch.object(
+            nlp_slow.model, method, wraps=getattr(nlp_slow.model, method)
+        ) as mock_slow, mock.patch.object(nlp_fast.model, method, wraps=getattr(nlp_fast.model, method)) as mock_fast:
             for inputs in self.valid_inputs:
-                outputs_slow = nlp_slow(inputs, **self.pipeline_running_kwargs)
-                outputs_fast = nlp_fast(inputs, **self.pipeline_running_kwargs)
+                if isinstance(inputs, dict):
+                    inputs.update(self.pipeline_running_kwargs)
+                    _ = nlp_slow(**inputs)
+                    _ = nlp_fast(**inputs)
+                else:
+                    _ = nlp_slow(inputs, **self.pipeline_running_kwargs)
+                    _ = nlp_fast(inputs, **self.pipeline_running_kwargs)

             mock_slow.assert_called()
             mock_fast.assert_called()

-            slow_call_args, slow_call_kwargs = mock_slow.call_args
-            fast_call_args, fast_call_kwargs = mock_fast.call_args
+            self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
+            for mock_slow_call_args, mock_fast_call_args in zip(
+                mock_slow.call_args_list, mock_fast.call_args_list
+            ):
+                slow_call_args, slow_call_kwargs = mock_slow_call_args
+                fast_call_args, fast_call_kwargs = mock_fast_call_args

-            slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
-            fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)
+                slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
+                fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)

-            self.assertEqual(slow_call_args, fast_call_args)
-            self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)
+                self.assertEqual(slow_call_args, fast_call_args)
+                self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)

-            self.assertEqual(outputs_slow, outputs_fast)

-    @require_torch
-    @slow
-    def test_torch_large(self):
-        for model_name in self.large_models:
-            nlp = pipeline(
-                task=self.pipeline_task,
-                model=model_name,
-                tokenizer=model_name,
-                framework="pt",
-                **self.pipeline_loading_kwargs,
-            )
-            self._test_pipeline(nlp)
+@is_pipeline_test
+class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
+    """A version of the CustomInputPipelineCommonMixin
+    with a predefined `_test_pipeline` method.
+    """

-    @require_tf
-    @slow
-    def test_tf_large(self):
-        for model_name in self.large_models:
-            nlp = pipeline(
-                task=self.pipeline_task,
-                model=model_name,
-                tokenizer=model_name,
-                framework="tf",
-                **self.pipeline_loading_kwargs,
-            )
-            self._test_pipeline(nlp)
+    mandatory_keys = {}  # Keys which should be in the output
+    invalid_inputs = [None]  # inputs which are not allowed
+    expected_multi_result: Optional[List] = None
+    expected_check_keys: Optional[List[str]] = None

     def _test_pipeline(self, nlp: Pipeline):
         self.assertIsNotNone(nlp)
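The equivalence check above hinges on mock.patch.object(..., wraps=...): the patched method still executes normally while the mock records every call for later comparison. A self-contained sketch of that pattern with a toy class (no transformers dependency; all names here are illustrative):

import unittest
from unittest import mock


class ToyModel:
    def forward(self, x, scale=1):
        # Stands in for a model's forward pass.
        return x * scale


class WrapsPatternTest(unittest.TestCase):
    def test_call_args_match(self):
        slow, fast = ToyModel(), ToyModel()
        with mock.patch.object(slow, "forward", wraps=slow.forward) as mock_slow, mock.patch.object(
            fast, "forward", wraps=fast.forward
        ) as mock_fast:
            # The wrapped methods still return their real results.
            self.assertEqual(slow.forward(3, scale=2), 6)
            self.assertEqual(fast.forward(3, scale=2), 6)

        # Every call was recorded, so the two call sequences can be compared
        # pairwise, exactly as _compare_slow_fast_pipelines does above.
        self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
        for slow_call, fast_call in zip(mock_slow.call_args_list, mock_fast.call_args_list):
            self.assertEqual(slow_call, fast_call)


if __name__ == "__main__":
    unittest.main()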

tests/test_pipelines_dialog.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ class DialoguePipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "conversational"
     small_models = []  # Default model - Models tested without the @slow decorator
     large_models = ["microsoft/DialoGPT-medium"]  # Models tested with the @slow decorator
+    valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]

     def _test_pipeline(self, nlp: Pipeline):
         valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
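A hedged usage sketch of the valid_inputs listed above; the conversational pipeline appends the model's reply to the Conversation object it is given (the DialoGPT download and the exact generated text depend on the environment):

from transformers import Conversation, pipeline

nlp = pipeline("conversational", model="microsoft/DialoGPT-medium")
conversation = Conversation("Hi there!")
result = nlp(conversation)
# The generated reply is stored on the conversation itself.
print(result.generated_responses[-1])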

tests/test_pipelines_zero_shot.py

Lines changed: 12 additions & 0 deletions
@@ -11,6 +11,18 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
         "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
     ]  # Models tested without the @slow decorator
     large_models = ["roberta-large-mnli"]  # Models tested with the @slow decorator
+    valid_inputs = [
+        {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
+        {"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]},
+        {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"},
+        {"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]},
+        {"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"},
+        {
+            "sequences": "Who are you voting for in 2020?",
+            "candidate_labels": "politics",
+            "hypothesis_template": "This text is about {}",
+        },
+    ]

     def _test_scores_sum_to_one(self, result):
         sum = 0.0
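A hedged usage sketch of one of the valid_inputs dictionaries above, which is how `_compare_slow_fast_pipelines` invokes the pipeline with dict inputs; the tiny sshleifer checkpoint is chosen for speed and its scores are illustrative only:

from transformers import pipeline

nlp = pipeline(
    "zero-shot-classification",
    model="sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english",
)
result = nlp(
    sequences="Who are you voting for in 2020?",
    candidate_labels=["politics", "public health"],
    hypothesis_template="This text is about {}",
)
# `labels` is sorted by descending score; in the default single-label mode
# the scores sum to one, which _test_scores_sum_to_one checks.
print(result["labels"], result["scores"])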
