@@ -15,6 +15,7 @@
 import ignite.distributed as idist
 from ignite.engine import (
     Engine,
+    Events,
     _check_arg,
     create_supervised_evaluator,
     create_supervised_trainer,
@@ -26,20 +27,19 @@
 
 
 def _default_create_supervised_trainer(
-    gradient_accumulation_steps=1,
+    gradient_accumulation_steps: int = 1,
     model_device: Optional[str] = None,
     trainer_device: Optional[str] = None,
     trace: bool = False,
     amp_mode: str = None,
     scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
 ):
-    model = Linear(1, 1)
+    model = Linear(1, 1, bias=False)
 
     if model_device:
         model.to(model_device)
 
     model.weight.data.zero_()
-    model.bias.data.zero_()
     optimizer = SGD(model.parameters(), 0.1)
 
     if trace:
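Note on the `bias=False` change (my reading, the diff itself does not say): with the bias gone and the weight zeroed, the model under test is a single scalar parameter, which is what lets the rewritten test further down track the expected SGD trajectory by hand. A minimal illustration of the trimmed-down model, assuming the same `Linear` setup as above:

```python
import torch
from torch.nn import Linear

model = Linear(1, 1, bias=False)
model.weight.data.zero_()       # theta = 0.0, as in _default_create_supervised_trainer
x = torch.tensor([[0.01]])
assert model(x).item() == 0.0   # the forward pass is just theta * x
```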
@@ -62,16 +62,12 @@ def _default_create_supervised_trainer(
         gradient_accumulation_steps=gradient_accumulation_steps,
     )
     assert model.weight.data[0, 0].item() == approx(0.0)
-    assert model.bias.item() == approx(0.0)
 
     return trainer, model
 
 
 def _test_create_supervised_trainer(
-    gradient_accumulation_steps=3,
-    loss=0.0045,
-    weight=0.0540,
-    bias=0.1133,
+    gradient_accumulation_steps: int = 1,
     model_device: Optional[str] = None,
     trainer_device: Optional[str] = None,
     trace: bool = False,
@@ -87,16 +83,32 @@ def _test_create_supervised_trainer(
         scaler=scaler,
     )
 
-    x = torch.tensor([[0.1], [0.3], [0.7], [0.9], [1.3]])
-    y = torch.tensor([[0.3], [0.5], [0.9], [1.3], [0.3]])
+    x = torch.tensor([[0.01], [0.02], [0.03], [0.04], [0.05]])
+    y = torch.tensor([[0.015], [0.025], [0.035], [0.045], [0.055]])
     data = [(_x, _y) for _x, _y in zip(x, y)]
 
+    theta = [0.0]
+    accumulation = [0.0]
+    loss = [0.0]
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def _():
+        _x, _y = trainer.state.batch
+        _x, _y = _x.to(model_device), _y.to(model_device)
+        accumulation[0] += 0.2 * _x.item() * (theta[0] * _x.item() - _y.item())
+        # loss is not accumulated!
+        loss[0] = mse_loss(model(_x), _y).item() / gradient_accumulation_steps
+
+    @trainer.on(Events.ITERATION_COMPLETED(every=gradient_accumulation_steps))
+    def _():
+        theta[0] -= accumulation[0] / gradient_accumulation_steps
+        assert pytest.approx(model.weight.data[0, 0].item(), abs=1.e-5) == theta[0]
+        assert pytest.approx(trainer.state.output[-1], abs=1e-5) == loss[0]
+        accumulation[0] = loss[0] = 0.0
+
     if model_device == trainer_device or ((model_device == "cpu") ^ (trainer_device == "cpu")):
-        state = trainer.run(data)
 
-        assert round(state.output[-1], 4) == loss, state.output[-1]
-        assert round(model.weight.data[0, 0].item(), 4) == weight, model.weight.item()
-        assert round(model.bias.item(), 4) == bias, model.bias.item()
+        state = trainer.run(data)
 
         if amp_mode == "amp":
             assert state.output[0].dtype is torch.half
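For reviewers, the arithmetic the two new handlers encode: for a bias-free 1-d linear model under MSE, L = (theta * x - y)**2, so dL/dtheta = 2 * x * (theta * x - y); with lr = 0.1 this gives the 0.2 factor in `accumulation[0]`, and the division by `gradient_accumulation_steps` at step time matches `create_supervised_trainer` scaling the loss by 1/N before `backward()`. A standalone sketch of my own (the helper name `expected_weight` is made up, not part of the PR), assuming plain SGD with lr = 0.1:

```python
def expected_weight(xs, ys, lr=0.1, n_accum=3, theta=0.0):
    """Replay SGD with gradient accumulation on the model y = theta * x."""
    grad_sum = 0.0
    for i, (x, y) in enumerate(zip(xs, ys), start=1):
        # gradient of (theta * x - y)**2, taken at the *pre-step* theta,
        # since the model weight only changes when optimizer.step() runs
        grad_sum += 2.0 * x * (theta * x - y)
        if i % n_accum == 0:  # optimizer.step() fires every n_accum iterations
            theta -= lr * grad_sum / n_accum
            grad_sum = 0.0
    return theta
```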
@@ -105,25 +117,6 @@ def _test_create_supervised_trainer(
         else:
             assert not hasattr(state, "scaler")
 
-        # Test for Gradient Accumulation Turned Off
-        trainer, model = _default_create_supervised_trainer(
-            model_device=model_device, trainer_device=trainer_device, trace=trace, amp_mode=amp_mode, scaler=scaler,
-        )
-        x = torch.tensor([[1.0], [1.0], [1.0], [1.0], [1.0]])
-        data = [(_x, _y) for _x, _y in zip(x, x)]
-
-        for i in range(len(data)):
-            original_weights = model.weight.data[0, 0].item()
-            original_bias = model.bias.item()
-            state = trainer.run([data[i]])
-            assert state.output[-1] == pytest.approx((1 - (original_weights + original_bias)) ** 2), state.output[-1]
-            assert model.weight.data[0, 0].item() == pytest.approx(
-                original_weights + 2 * 0.1 * (1 - (original_weights + original_bias))
-            ), model.weight.item()
-            assert model.bias.item() == pytest.approx(
-                original_bias + 2 * 0.1 * (1 - (original_weights + original_bias))
-            ), model.bias.item()
-
     else:
         if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"):
             # This is broken in 1.6.0 but will be probably fixed with 1.7.0
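The removed "Gradient Accumulation Turned Off" block is superseded rather than lost: every caller below now runs the same parametrized check with `gradient_accumulation_steps=1` as well as `3`. The `every=` event filter that drives the new assertion handler is standard ignite usage; a minimal self-contained example of its firing pattern:

```python
from ignite.engine import Engine, Events

engine = Engine(lambda engine, batch: None)
fired = []

@engine.on(Events.ITERATION_COMPLETED(every=3))
def record():
    fired.append(engine.state.iteration)

engine.run(range(7), max_epochs=1)
assert fired == [3, 6]  # the handler runs on iterations 3, 6, 9, ...
```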
@@ -349,19 +342,22 @@ def _test_create_evaluation_step(
 
 def test_create_supervised_trainer():
     _test_create_supervised_trainer_wrong_accumulation()
-    _test_create_supervised_trainer()
+    _test_create_supervised_trainer(gradient_accumulation_steps=1)
+    _test_create_supervised_trainer(gradient_accumulation_steps=3)
     _test_create_mocked_supervised_trainer()
 
 
 def test_create_supervised_trainer_with_cpu():
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(trainer_device="cpu")
+    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu")
+    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu")
     _test_create_mocked_supervised_trainer(trainer_device="cpu")
 
 
 def test_create_supervised_trainer_traced_with_cpu():
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(trainer_device="cpu", trace=True)
+    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu", trace=True)
+    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu", trace=True)
     _test_create_mocked_supervised_trainer(trainer_device="cpu", trace=True)
 
 
@@ -412,7 +408,12 @@ def test_create_supervised_trainer_scaler_not_amp():
 def test_create_supervised_trainer_on_cuda():
     model_device = trainer_device = "cuda"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
+    )
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
 
 
@@ -424,7 +425,10 @@ def test_create_supervised_trainer_on_cuda_amp():
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
     )
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
     )
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")
 
@@ -436,17 +440,37 @@ def test_create_supervised_trainer_on_cuda_amp_scaler():
     _test_create_supervised_trainer_wrong_accumulation(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
     )
-
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True,
+        gradient_accumulation_steps=1,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=True,
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=True,
     )
     _test_create_mocked_supervised_trainer(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True
     )
-
     scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler,
+        gradient_accumulation_steps=1,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=scaler,
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=scaler,
     )
     _test_create_mocked_supervised_trainer(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler
@@ -460,11 +484,12 @@ def test_create_supervised_trainer_on_cuda_apex():
     _test_create_supervised_trainer_wrong_accumulation(
         model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
     )
-
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
     )
-
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="apex")
 
 
@@ -488,7 +513,12 @@ def test_create_supervised_trainer_on_tpu_no_xla():
 def test_create_supervised_trainer_on_tpu():
     model_device = trainer_device = "xla"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
+    )
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
 
 
@@ -503,7 +533,8 @@ def test_create_supervised_trainer_on_tpu_amp():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
 def test_create_supervised_trainer_on_cuda_with_model_on_cpu():
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cuda")
-    _test_create_supervised_trainer(trainer_device="cuda")
+    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cuda")
+    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cuda")
     _test_create_mocked_supervised_trainer(trainer_device="cuda")
 
 