@@ -169,52 +169,52 @@ def test_opt_params_handler_on_non_torch_optimizers():
169
169
assert "lr/group_0" in res and res ["lr/group_0" ] == 0.1234
170
170
171
171
172
- def test_attach ():
172
+ @pytest .mark .parametrize (
173
+ "event, n_calls, kwargs" ,
174
+ [
175
+ (Events .ITERATION_STARTED , 50 * 5 , {"a" : 0 }),
176
+ (Events .ITERATION_COMPLETED , 50 * 5 , {}),
177
+ (Events .EPOCH_STARTED , 5 , {}),
178
+ (Events .EPOCH_COMPLETED , 5 , {}),
179
+ (Events .STARTED , 1 , {}),
180
+ (Events .COMPLETED , 1 , {}),
181
+ (Events .ITERATION_STARTED (every = 10 ), 50 // 10 * 5 , {}),
182
+ (Events .STARTED | Events .COMPLETED , 2 , {}),
183
+ ],
184
+ )
185
+ def test_attach (event , n_calls , kwargs ):
173
186
174
187
n_epochs = 5
175
188
data = list (range (50 ))
176
189
177
- def _test (event , n_calls , kwargs = {}):
178
-
179
- losses = torch .rand (n_epochs * len (data ))
180
- losses_iter = iter (losses )
190
+ losses = torch .rand (n_epochs * len (data ))
191
+ losses_iter = iter (losses )
181
192
182
- def update_fn (engine , batch ):
183
- return next (losses_iter )
193
+ def update_fn (engine , batch ):
194
+ return next (losses_iter )
184
195
185
- trainer = Engine (update_fn )
196
+ trainer = Engine (update_fn )
186
197
187
- logger = DummyLogger ()
188
-
189
- mock_log_handler = MagicMock ()
190
-
191
- logger .attach (trainer , log_handler = mock_log_handler , event_name = event , ** kwargs )
198
+ logger = DummyLogger ()
192
199
193
- trainer . run ( data , max_epochs = n_epochs )
200
+ mock_log_handler = MagicMock ( )
194
201
195
- if isinstance (event , EventsList ):
196
- events = [e for e in event ]
197
- else :
198
- events = [event ]
202
+ logger .attach (trainer , log_handler = mock_log_handler , event_name = event , ** kwargs )
199
203
200
- if len (kwargs ) > 0 :
201
- calls = [call (trainer , logger , e , ** kwargs ) for e in events ]
202
- else :
203
- calls = [call (trainer , logger , e ) for e in events ]
204
+ trainer .run (data , max_epochs = n_epochs )
204
205
205
- mock_log_handler .assert_has_calls (calls )
206
- assert mock_log_handler .call_count == n_calls
206
+ if isinstance (event , EventsList ):
207
+ events = [e for e in event ]
208
+ else :
209
+ events = [event ]
207
210
208
- _test (Events .ITERATION_STARTED , len (data ) * n_epochs , kwargs = {"a" : 0 })
209
- _test (Events .ITERATION_COMPLETED , len (data ) * n_epochs )
210
- _test (Events .EPOCH_STARTED , n_epochs )
211
- _test (Events .EPOCH_COMPLETED , n_epochs )
212
- _test (Events .STARTED , 1 )
213
- _test (Events .COMPLETED , 1 )
211
+ if len (kwargs ) > 0 :
212
+ calls = [call (trainer , logger , e , ** kwargs ) for e in events ]
213
+ else :
214
+ calls = [call (trainer , logger , e ) for e in events ]
214
215
215
- _test (Events .ITERATION_STARTED (every = 10 ), len (data ) // 10 * n_epochs )
216
-
217
- _test (Events .STARTED | Events .COMPLETED , 2 )
216
+ mock_log_handler .assert_has_calls (calls )
217
+ assert mock_log_handler .call_count == n_calls
218
218
219
219
220
220
def test_attach_wrong_event_name ():
@@ -260,7 +260,19 @@ def update_fn(engine, batch):
260
260
assert mock_log_handler .call_count == n_calls
261
261
262
262
263
- def test_as_context_manager ():
263
+ @pytest .mark .parametrize (
264
+ "event, n_calls" ,
265
+ [
266
+ (Events .ITERATION_STARTED , 50 * 5 ),
267
+ (Events .ITERATION_COMPLETED , 50 * 5 ),
268
+ (Events .EPOCH_STARTED , 5 ),
269
+ (Events .EPOCH_COMPLETED , 5 ),
270
+ (Events .STARTED , 1 ),
271
+ (Events .COMPLETED , 1 ),
272
+ (Events .ITERATION_STARTED (every = 10 ), 50 // 10 * 5 ),
273
+ ],
274
+ )
275
+ def test_as_context_manager (event , n_calls ):
264
276
265
277
n_epochs = 5
266
278
data = list (range (50 ))
@@ -272,42 +284,32 @@ def __init__(self, writer):
272
284
def close (self ):
273
285
self .writer .close ()
274
286
275
- def _test (event , n_calls ):
276
- global close_counter
277
- close_counter = 0
278
-
279
- losses = torch .rand (n_epochs * len (data ))
280
- losses_iter = iter (losses )
281
-
282
- def update_fn (engine , batch ):
283
- return next (losses_iter )
287
+ global close_counter
288
+ close_counter = 0
284
289
285
- writer = MagicMock ( )
286
- writer . close = MagicMock ( )
290
+ losses = torch . rand ( n_epochs * len ( data ) )
291
+ losses_iter = iter ( losses )
287
292
288
- with _DummyLogger ( writer ) as logger :
289
- assert isinstance ( logger , _DummyLogger )
293
+ def update_fn ( engine , batch ) :
294
+ return next ( losses_iter )
290
295
291
- trainer = Engine ( update_fn )
292
- mock_log_handler = MagicMock ()
296
+ writer = MagicMock ( )
297
+ writer . close = MagicMock ()
293
298
294
- logger .attach (trainer , log_handler = mock_log_handler , event_name = event )
299
+ with _DummyLogger (writer ) as logger :
300
+ assert isinstance (logger , _DummyLogger )
295
301
296
- trainer .run (data , max_epochs = n_epochs )
302
+ trainer = Engine (update_fn )
303
+ mock_log_handler = MagicMock ()
297
304
298
- mock_log_handler .assert_called_with (trainer , logger , event )
299
- assert mock_log_handler .call_count == n_calls
305
+ logger .attach (trainer , log_handler = mock_log_handler , event_name = event )
300
306
301
- writer . close . assert_called_once_with ( )
307
+ trainer . run ( data , max_epochs = n_epochs )
302
308
303
- _test (Events .ITERATION_STARTED , len (data ) * n_epochs )
304
- _test (Events .ITERATION_COMPLETED , len (data ) * n_epochs )
305
- _test (Events .EPOCH_STARTED , n_epochs )
306
- _test (Events .EPOCH_COMPLETED , n_epochs )
307
- _test (Events .STARTED , 1 )
308
- _test (Events .COMPLETED , 1 )
309
+ mock_log_handler .assert_called_with (trainer , logger , event )
310
+ assert mock_log_handler .call_count == n_calls
309
311
310
- _test ( Events . ITERATION_STARTED ( every = 10 ), len ( data ) // 10 * n_epochs )
312
+ writer . close . assert_called_once_with ( )
311
313
312
314
313
315
def test_base_weights_handler_wrong_setup ():
0 commit comments