
Commit 47961cd

Ishan-Kumar2 and Desroziers authored
set amp_mode and scaler to false in grad_acc=1 test (#2243)
* set amp_mode and scaler to false in grad_acc=1 test
* removed bias from test
* refactor tests
* fix approx
* fix device
* fix tol

Co-authored-by: Desroziers <sylvain.desroziers@ifpen.fr>
1 parent 9cce2e8 commit 47961cd

File tree

1 file changed: +78 -47 lines changed

tests/ignite/engine/test_create_supervised.py

Lines changed: 78 additions & 47 deletions
@@ -15,6 +15,7 @@
 import ignite.distributed as idist
 from ignite.engine import (
     Engine,
+    Events,
     _check_arg,
     create_supervised_evaluator,
     create_supervised_trainer,
@@ -26,20 +27,19 @@
 
 
 def _default_create_supervised_trainer(
-    gradient_accumulation_steps=1,
+    gradient_accumulation_steps: int = 1,
     model_device: Optional[str] = None,
     trainer_device: Optional[str] = None,
     trace: bool = False,
     amp_mode: str = None,
     scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
 ):
-    model = Linear(1, 1)
+    model = Linear(1, 1, bias=False)
 
     if model_device:
         model.to(model_device)
 
     model.weight.data.zero_()
-    model.bias.data.zero_()
     optimizer = SGD(model.parameters(), 0.1)
 
     if trace:
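
The switch to bias=False is what makes the rest of the rework possible: the model collapses to a single trainable scalar, so every SGD update can be predicted in closed form. A minimal sketch of that property (illustration only, not part of the test file):

from torch.nn import Linear

# With bias=False the model reduces to one scalar parameter theta,
# so y_pred = theta * x and each SGD step has a closed-form value.
model = Linear(1, 1, bias=False)
model.weight.data.zero_()
assert sum(p.numel() for p in model.parameters()) == 1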
@@ -62,16 +62,12 @@ def _default_create_supervised_trainer(
         gradient_accumulation_steps=gradient_accumulation_steps,
     )
     assert model.weight.data[0, 0].item() == approx(0.0)
-    assert model.bias.item() == approx(0.0)
 
     return trainer, model
 
 
 def _test_create_supervised_trainer(
-    gradient_accumulation_steps=3,
-    loss=0.0045,
-    weight=0.0540,
-    bias=0.1133,
+    gradient_accumulation_steps: int = 1,
     model_device: Optional[str] = None,
     trainer_device: Optional[str] = None,
     trace: bool = False,
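
The top of this hunk is the tail of the helper's create_supervised_trainer(...) call; its head lies outside the diff. A hedged sketch of the presumed full call, using only documented parameters of create_supervised_trainer (the positional arguments and their order are assumptions):

# Presumed shape of the call whose closing lines appear above; everything
# except the keyword arguments visible in the diff is an assumption.
trainer = create_supervised_trainer(
    model,
    optimizer,
    mse_loss,
    device=trainer_device,
    amp_mode=amp_mode,
    scaler=scaler,
    gradient_accumulation_steps=gradient_accumulation_steps,
)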
@@ -87,16 +83,32 @@ def _test_create_supervised_trainer(
         scaler=scaler,
     )
 
-    x = torch.tensor([[0.1], [0.3], [0.7], [0.9], [1.3]])
-    y = torch.tensor([[0.3], [0.5], [0.9], [1.3], [0.3]])
+    x = torch.tensor([[0.01], [0.02], [0.03], [0.04], [0.05]])
+    y = torch.tensor([[0.015], [0.025], [0.035], [0.045], [0.055]])
     data = [(_x, _y) for _x, _y in zip(x, y)]
 
+    theta = [0.0]
+    accumulation = [0.0]
+    loss = [0.0]
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def _():
+        _x, _y = trainer.state.batch
+        _x, _y = _x.to(model_device), _y.to(model_device)
+        accumulation[0] += 0.2 * _x.item() * (theta[0] * _x.item() - _y.item())
+        # loss is not accumulated !
+        loss[0] = mse_loss(model(_x), _y).item() / gradient_accumulation_steps
+
+    @trainer.on(Events.ITERATION_COMPLETED(every=gradient_accumulation_steps))
+    def _():
+        theta[0] -= accumulation[0] / gradient_accumulation_steps
+        assert pytest.approx(model.weight.data[0, 0].item(), abs=1.e-5) == theta[0]
+        assert pytest.approx(trainer.state.output[-1], abs=1e-5) == loss[0]
+        accumulation[0] = loss[0] = 0.0
+
     if model_device == trainer_device or ((model_device == "cpu") ^ (trainer_device == "cpu")):
-        state = trainer.run(data)
 
-        assert round(state.output[-1], 4) == loss, state.output[-1]
-        assert round(model.weight.data[0, 0].item(), 4) == weight, model.weight.item()
-        assert round(model.bias.item(), 4) == bias, model.bias.item()
+        state = trainer.run(data)
 
         if amp_mode == "amp":
             assert state.output[0].dtype is torch.half
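
The two handlers added above replace the old hard-coded loss/weight/bias expectations with a closed-form recurrence: for MSE loss on a bias-free scalar weight theta, the gradient is 2x(theta*x - y), and with lr = 0.1 folded in, each iteration contributes 0.2 * x * (theta*x - y); the optimizer then steps once per window with the mean contribution. A standalone sketch of the first window with gradient_accumulation_steps=3, assuming plain SGD and the data above:

# First accumulation window, N = 3; theta stays fixed until the step fires.
lr, N = 0.1, 3
xs, ys = [0.01, 0.02, 0.03], [0.015, 0.025, 0.035]

theta = 0.0
acc = sum(0.2 * x * (theta * x - y) for x, y in zip(xs, ys))  # lr * grad, summed
theta -= acc / N
print(theta)  # ~1.1333e-04, the value the handler asserts against model.weight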
@@ -105,25 +117,6 @@ def _test_create_supervised_trainer(
         else:
             assert not hasattr(state, "scaler")
 
-        # Test for Gradient Accumulation Turned Off
-        trainer, model = _default_create_supervised_trainer(
-            model_device=model_device, trainer_device=trainer_device, trace=trace, amp_mode=amp_mode, scaler=scaler,
-        )
-        x = torch.tensor([[1.0], [1.0], [1.0], [1.0], [1.0]])
-        data = [(_x, _y) for _x, _y in zip(x, x)]
-
-        for i in range(len(data)):
-            original_weights = model.weight.data[0, 0].item()
-            original_bias = model.bias.item()
-            state = trainer.run([data[i]])
-            assert state.output[-1] == pytest.approx((1 - (original_weights + original_bias)) ** 2), state.output[-1]
-            assert model.weight.data[0, 0].item() == pytest.approx(
-                original_weights + 2 * 0.1 * (1 - (original_weights + original_bias))
-            ), model.weight.item()
-            assert model.bias.item() == pytest.approx(
-                original_bias + 2 * 0.1 * (1 - (original_weights + original_bias))
-            ), model.bias.item()
-
     else:
         if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"):
             # This is broken in 1.6.0 but will be probably fixed with 1.7.0
@@ -349,19 +342,22 @@ def _test_create_evaluation_step(
 
 def test_create_supervised_trainer():
     _test_create_supervised_trainer_wrong_accumulation()
-    _test_create_supervised_trainer()
+    _test_create_supervised_trainer(gradient_accumulation_steps=1)
+    _test_create_supervised_trainer(gradient_accumulation_steps=3)
     _test_create_mocked_supervised_trainer()
 
 
 def test_create_supervised_trainer_with_cpu():
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(trainer_device="cpu")
+    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu")
+    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu")
     _test_create_mocked_supervised_trainer(trainer_device="cpu")
 
 
 def test_create_supervised_trainer_traced_with_cpu():
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(trainer_device="cpu", trace=True)
+    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu", trace=True)
+    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu", trace=True)
     _test_create_mocked_supervised_trainer(trainer_device="cpu", trace=True)
@@ -412,7 +408,12 @@ def test_create_supervised_trainer_scaler_not_amp():
 def test_create_supervised_trainer_on_cuda():
     model_device = trainer_device = "cuda"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
+    )
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
@@ -424,7 +425,10 @@ def test_create_supervised_trainer_on_cuda_amp():
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
     )
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
     )
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")
@@ -436,17 +440,37 @@ def test_create_supervised_trainer_on_cuda_amp_scaler():
     _test_create_supervised_trainer_wrong_accumulation(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
     )
-
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True,
+        gradient_accumulation_steps=1,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=True,
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=True,
     )
     _test_create_mocked_supervised_trainer(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True
     )
-
     scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler,
+        gradient_accumulation_steps=1,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=scaler,
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        amp_mode="amp",
+        scaler=scaler,
     )
     _test_create_mocked_supervised_trainer(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler
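
This hunk exercises both GradScaler flavors that create_supervised_trainer accepts: scaler=True (the factory creates its own scaler, valid only together with amp_mode="amp") and a caller-owned torch.cuda.amp.GradScaler instance. A hedged usage sketch mirroring the test, with placeholder model/optimizer/loss:

import torch
from torch.nn import Linear
from torch.nn.functional import mse_loss
from torch.optim import SGD
from ignite.engine import create_supervised_trainer

model = Linear(1, 1, bias=False)
optimizer = SGD(model.parameters(), 0.1)

# scaler=True: an internal GradScaler is created for the trainer
trainer_a = create_supervised_trainer(
    model, optimizer, mse_loss, device="cuda", amp_mode="amp", scaler=True
)

# scaler=<instance>: the caller owns the GradScaler, as the test does above
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
trainer_b = create_supervised_trainer(
    model, optimizer, mse_loss, device="cuda", amp_mode="amp", scaler=scaler
)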
@@ -460,11 +484,12 @@ def test_create_supervised_trainer_on_cuda_apex():
     _test_create_supervised_trainer_wrong_accumulation(
        model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
     )
-
     _test_create_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
     )
-
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="apex")
@@ -488,7 +513,12 @@ def test_create_supervised_trainer_on_tpu_no_xla():
 def test_create_supervised_trainer_on_tpu():
     model_device = trainer_device = "xla"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
+    )
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
@@ -503,7 +533,8 @@ def test_create_supervised_trainer_on_tpu_amp():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
 def test_create_supervised_trainer_on_cuda_with_model_on_cpu():
     _test_create_supervised_trainer_wrong_accumulation(trainer_device="cuda")
-    _test_create_supervised_trainer(trainer_device="cuda")
+    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cuda")
+    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cuda")
     _test_create_mocked_supervised_trainer(trainer_device="cuda")