import unittest
-from unittest import mock
from typing import List, Optional
+from unittest import mock

from transformers import is_tf_available, is_torch_available, pipeline
-from transformers.tokenization_utils_base import to_py_obj
from transformers.pipelines import DefaultArgumentHandler, Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
+from transformers.tokenization_utils_base import to_py_obj


VALID_INPUTS = ["A simple string", ["list of strings"]]


-@is_pipeline_test
+# @is_pipeline_test
class CustomInputPipelineCommonMixin:
    pipeline_task = None
-    pipeline_loading_kwargs = {}
-    small_models = None  # Models tested without the @slow decorator
-    large_models = None  # Models tested with the @slow decorator
+    pipeline_loading_kwargs = {}  # Additional kwargs to load the pipeline with
+    pipeline_running_kwargs = {}  # Additional kwargs to run the pipeline with
+    small_models = []  # Models tested without the @slow decorator
+    large_models = []  # Models tested with the @slow decorator
+    valid_inputs = VALID_INPUTS  # Some inputs which are valid to compare fast and slow tokenizers

    def setUp(self) -> None:
        if not is_tf_available() and not is_torch_available():
@@ -48,78 +50,41 @@ def setUp(self) -> None:
    @require_torch
    @slow
    def test_pt_defaults(self):
-        pipeline(self.pipeline_task, framework="pt")
+        pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)

    @require_tf
    @slow
    def test_tf_defaults(self):
-        pipeline(self.pipeline_task, framework="tf")
+        pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)

    @require_torch
    def test_torch_small(self):
        for model_name in self.small_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
+            nlp = pipeline(
+                task=self.pipeline_task,
+                model=model_name,
+                tokenizer=model_name,
+                framework="pt",
+                **self.pipeline_loading_kwargs,
+            )
            self._test_pipeline(nlp)

    @require_tf
    def test_tf_small(self):
        for model_name in self.small_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
+            nlp = pipeline(
+                task=self.pipeline_task,
+                model=model_name,
+                tokenizer=model_name,
+                framework="tf",
+                **self.pipeline_loading_kwargs,
+            )
            self._test_pipeline(nlp)

    @require_torch
    @slow
    def test_torch_large(self):
        for model_name in self.large_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="pt")
-            self._test_pipeline(nlp)
-
-    @require_tf
-    @slow
-    def test_tf_large(self):
-        for model_name in self.large_models:
-            nlp = pipeline(task=self.pipeline_task, model=model_name, tokenizer=model_name, framework="tf")
-            self._test_pipeline(nlp)
-
-    def _test_pipeline(self, nlp: Pipeline):
-        raise NotImplementedError
-
-
-# @is_pipeline_test
-class MonoInputPipelineCommonMixin:
-    pipeline_task = None
-    pipeline_loading_kwargs = {}  # Additional kwargs to load the pipeline with
-    pipeline_running_kwargs = {}  # Additional kwargs to run the pipeline with
-    small_models = []  # Models tested without the @slow decorator
-    large_models = []  # Models tested with the @slow decorator
-    mandatory_keys = {}  # Keys which should be in the output
-    valid_inputs = VALID_INPUTS  # inputs which are valid
-    invalid_inputs = [None]  # inputs which are not allowed
-    expected_multi_result: Optional[List] = None
-    expected_check_keys: Optional[List[str]] = None
-
-    def setUp(self) -> None:
-        if not is_tf_available() and not is_torch_available():
-            return  # Currently no JAX pipelines
-
-        for model_name in self.small_models:
-            pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs)
-        for model_name in self.large_models:
-            pipeline(self.pipeline_task, model=model_name, tokenizer=model_name, **self.pipeline_loading_kwargs)
-
-    @require_torch
-    @slow
-    def test_pt_defaults_loads(self):
-        pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)
-
-    @require_tf
-    @slow
-    def test_tf_defaults_loads(self):
-        pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)
-
-    @require_torch
-    def test_torch_small(self):
-        for model_name in self.small_models:
            nlp = pipeline(
                task=self.pipeline_task,
                model=model_name,
@@ -130,8 +95,9 @@ def test_torch_small(self):
            self._test_pipeline(nlp)

    @require_tf
-    def test_tf_small(self):
-        for model_name in self.small_models:
+    @slow
+    def test_tf_large(self):
+        for model_name in self.large_models:
            nlp = pipeline(
                task=self.pipeline_task,
                model=model_name,
@@ -141,6 +107,9 @@ def test_tf_small(self):
            )
            self._test_pipeline(nlp)

+    def _test_pipeline(self, nlp: Pipeline):
+        raise NotImplementedError
+
    @require_torch
    def test_compare_slow_fast_torch(self):
        for model_name in self.small_models:
@@ -160,7 +129,7 @@ def test_compare_slow_fast_torch(self):
                use_fast=True,
                **self.pipeline_loading_kwargs,
            )
-            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast)
+            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="forward")

    @require_tf
    def test_compare_slow_fast_tf(self):
@@ -181,54 +150,51 @@ def test_compare_slow_fast_tf(self):
                use_fast=True,
                **self.pipeline_loading_kwargs,
            )
-            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast)
-
-    def _compare_slow_fast_pipelines(self, nlp_slow: Pipeline, nlp_fast: Pipeline):
-        with mock.patch.object(nlp_slow.model, 'forward', wraps=nlp_slow.model.forward) as mock_slow, \
-            mock.patch.object(nlp_fast.model, 'forward', wraps=nlp_fast.model.forward) as mock_fast:
+            self._compare_slow_fast_pipelines(nlp_slow, nlp_fast, method="call")
+
+    def _compare_slow_fast_pipelines(self, nlp_slow: Pipeline, nlp_fast: Pipeline, method: str):
+        """We check that the inputs to the models' forward passes are identical for
+        slow and fast tokenizers.
+        """
+        with mock.patch.object(
+            nlp_slow.model, method, wraps=getattr(nlp_slow.model, method)
+        ) as mock_slow, mock.patch.object(nlp_fast.model, method, wraps=getattr(nlp_fast.model, method)) as mock_fast:
            for inputs in self.valid_inputs:
-                outputs_slow = nlp_slow(inputs, **self.pipeline_running_kwargs)
-                outputs_fast = nlp_fast(inputs, **self.pipeline_running_kwargs)
+                if isinstance(inputs, dict):
+                    inputs.update(self.pipeline_running_kwargs)
+                    _ = nlp_slow(**inputs)
+                    _ = nlp_fast(**inputs)
+                else:
+                    _ = nlp_slow(inputs, **self.pipeline_running_kwargs)
+                    _ = nlp_fast(inputs, **self.pipeline_running_kwargs)

            mock_slow.assert_called()
            mock_fast.assert_called()

-            slow_call_args, slow_call_kwargs = mock_slow.call_args
-            fast_call_args, fast_call_kwargs = mock_fast.call_args
+            self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
+            for mock_slow_call_args, mock_fast_call_args in zip(
+                mock_slow.call_args_list, mock_fast.call_args_list
+            ):
+                slow_call_args, slow_call_kwargs = mock_slow_call_args
+                fast_call_args, fast_call_kwargs = mock_fast_call_args

-            slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
-            fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)
+                slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
+                fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)

-            self.assertEqual(slow_call_args, fast_call_args)
-            self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)
+                self.assertEqual(slow_call_args, fast_call_args)
+                self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)

-            self.assertEqual(outputs_slow, outputs_fast)

-    @require_torch
-    @slow
-    def test_torch_large(self):
-        for model_name in self.large_models:
-            nlp = pipeline(
-                task=self.pipeline_task,
-                model=model_name,
-                tokenizer=model_name,
-                framework="pt",
-                **self.pipeline_loading_kwargs,
-            )
-            self._test_pipeline(nlp)
+@is_pipeline_test
+class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
+    """A version of the CustomInputPipelineCommonMixin
+    with a predefined `_test_pipeline` method.
+    """

-    @require_tf
-    @slow
-    def test_tf_large(self):
-        for model_name in self.large_models:
-            nlp = pipeline(
-                task=self.pipeline_task,
-                model=model_name,
-                tokenizer=model_name,
-                framework="tf",
-                **self.pipeline_loading_kwargs,
-            )
-            self._test_pipeline(nlp)
+    mandatory_keys = {}  # Keys which should be in the output
+    invalid_inputs = [None]  # inputs which are not allowed
+    expected_multi_result: Optional[List] = None
+    expected_check_keys: Optional[List[str]] = None

    def _test_pipeline(self, nlp: Pipeline):
        self.assertIsNotNone(nlp)
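
The `wraps=` mocking pattern at the heart of `_compare_slow_fast_pipelines` is easier to see in isolation. Below is a minimal, self-contained sketch; the `Greeter` class is invented for illustration and is not part of transformers. A mock created with `wraps=` records every call while still delegating to the real method, and `call_args_list` exposes the recorded calls so two call sequences can be compared, which is how the diff above compares the tensors fed to the slow- and fast-tokenizer models (`model.forward` for PyTorch, `model.call` for TensorFlow Keras models).

import unittest
from unittest import mock


class Greeter:
    # Toy stand-in for the pipeline's model.
    def greet(self, name, punctuation="!"):
        return f"Hello, {name}{punctuation}"


class WrapsPatternTest(unittest.TestCase):
    def test_wrapped_method_records_calls_and_still_runs(self):
        a, b = Greeter(), Greeter()
        # wraps= keeps the real implementation running while the mock
        # records every call it receives.
        with mock.patch.object(a, "greet", wraps=a.greet) as mock_a, mock.patch.object(
            b, "greet", wraps=b.greet
        ) as mock_b:
            self.assertEqual(a.greet("Ada"), "Hello, Ada!")  # real method still runs
            b.greet("Ada")

            mock_a.assert_called()
            mock_b.assert_called()
            # call_args_list holds one entry per call, so the two call
            # sequences can be compared argument by argument.
            self.assertEqual(len(mock_a.call_args_list), len(mock_b.call_args_list))
            for call_a, call_b in zip(mock_a.call_args_list, mock_b.call_args_list):
                self.assertEqual(call_a, call_b)


if __name__ == "__main__":
    unittest.main()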
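For context, a concrete test case would subclass one of these mixins together with unittest.TestCase and fill in the class attributes. The sketch below is hypothetical and not part of the diff: the task name, model identifier, and import path are placeholders chosen for illustration.

import unittest

# Assumed import path; the mixins above would live in e.g. tests/test_pipelines.py.
from test_pipelines import MonoInputPipelineCommonMixin


class TextClassificationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
    pipeline_task = "sentiment-analysis"
    small_models = ["placeholder/tiny-sst2-model"]  # placeholder checkpoint id
    large_models = []  # no @slow checkpoints in this sketch
    mandatory_keys = {"label", "score"}  # keys each output dict should contain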