125 changes: 78 additions & 47 deletions tests/ignite/engine/test_create_supervised.py
@@ -15,6 +15,7 @@
import ignite.distributed as idist
from ignite.engine import (
Engine,
Events,
_check_arg,
create_supervised_evaluator,
create_supervised_trainer,
@@ -26,20 +27,19 @@


def _default_create_supervised_trainer(
gradient_accumulation_steps=1,
gradient_accumulation_steps: int = 1,
model_device: Optional[str] = None,
trainer_device: Optional[str] = None,
trace: bool = False,
amp_mode: str = None,
scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
):
model = Linear(1, 1)
model = Linear(1, 1, bias=False)
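# note: without a bias term the model has a single scalar parameter, which keeps
# the manual reference update in _test_create_supervised_trainer one-dimensional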

if model_device:
model.to(model_device)

model.weight.data.zero_()
model.bias.data.zero_()
optimizer = SGD(model.parameters(), 0.1)

if trace:
@@ -62,16 +62,12 @@ def _default_create_supervised_trainer(
gradient_accumulation_steps=gradient_accumulation_steps,
)
assert model.weight.data[0, 0].item() == approx(0.0)
assert model.bias.item() == approx(0.0)

return trainer, model


def _test_create_supervised_trainer(
gradient_accumulation_steps=3,
loss=0.0045,
weight=0.0540,
bias=0.1133,
gradient_accumulation_steps: int = 1,
model_device: Optional[str] = None,
trainer_device: Optional[str] = None,
trace: bool = False,
@@ -87,16 +83,32 @@ def _test_create_supervised_trainer(
scaler=scaler,
)

x = torch.tensor([[0.1], [0.3], [0.7], [0.9], [1.3]])
y = torch.tensor([[0.3], [0.5], [0.9], [1.3], [0.3]])
x = torch.tensor([[0.01], [0.02], [0.03], [0.04], [0.05]])
y = torch.tensor([[0.015], [0.025], [0.035], [0.045], [0.055]])
data = [(_x, _y) for _x, _y in zip(x, y)]

theta = [0.0]
accumulation = [0.0]
loss = [0.0]
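# reference values tracked by hand: for the setup above (1-D bias-free linear
# model, MSE loss, SGD with lr=0.1) the gradient of (theta * x - y) ** 2 w.r.t.
# theta is 2 * x * (theta * x - y), so each iteration adds
# 0.2 * x * (theta * x - y) to the accumulated update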

@trainer.on(Events.ITERATION_COMPLETED)
def _():
_x, _y = trainer.state.batch
_x, _y = _x.to(model_device), _y.to(model_device)
accumulation[0] += 0.2 * _x.item() * (theta[0] * _x.item() - _y.item())
# loss is not accumulated!
loss[0] = mse_loss(model(_x), _y).item() / gradient_accumulation_steps
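# once per accumulation window: apply the averaged reference update and compare
# it with the model weight and the loss reported by the trainer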

@trainer.on(Events.ITERATION_COMPLETED(every=gradient_accumulation_steps))
def _():
theta[0] -= accumulation[0] / gradient_accumulation_steps
assert pytest.approx(model.weight.data[0, 0].item(), abs=1.e-5) == theta[0]
assert pytest.approx(trainer.state.output[-1], abs=1e-5) == loss[0]
accumulation[0] = loss[0] = 0.0

if model_device == trainer_device or ((model_device == "cpu") ^ (trainer_device == "cpu")):
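# run only for valid placements: same device for model and trainer, or exactly
# one of the two on CPU; the remaining combinations are handled in the else
# branch below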
state = trainer.run(data)

assert round(state.output[-1], 4) == loss, state.output[-1]
assert round(model.weight.data[0, 0].item(), 4) == weight, model.weight.item()
assert round(model.bias.item(), 4) == bias, model.bias.item()
state = trainer.run(data)

if amp_mode == "amp":
assert state.output[0].dtype is torch.half
@@ -105,25 +117,6 @@ def _test_create_supervised_trainer(
else:
assert not hasattr(state, "scaler")

# Test for Gradient Accumulation Turned Off
trainer, model = _default_create_supervised_trainer(
model_device=model_device, trainer_device=trainer_device, trace=trace, amp_mode=amp_mode, scaler=scaler,
)
x = torch.tensor([[1.0], [1.0], [1.0], [1.0], [1.0]])
data = [(_x, _y) for _x, _y in zip(x, x)]

for i in range(len(data)):
original_weights = model.weight.data[0, 0].item()
original_bias = model.bias.item()
state = trainer.run([data[i]])
assert state.output[-1] == pytest.approx((1 - (original_weights + original_bias)) ** 2), state.output[-1]
assert model.weight.data[0, 0].item() == pytest.approx(
original_weights + 2 * 0.1 * (1 - (original_weights + original_bias))
), model.weight.item()
assert model.bias.item() == pytest.approx(
original_bias + 2 * 0.1 * (1 - (original_weights + original_bias))
), model.bias.item()

else:
if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"):
# This is broken in 1.6.0 but will probably be fixed in 1.7.0
@@ -349,19 +342,22 @@ def _test_create_evaluation_step(

def test_create_supervised_trainer():
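# the parametrised check runs both without accumulation (steps=1) and with
# accumulation over 3 iterations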
_test_create_supervised_trainer_wrong_accumulation()
_test_create_supervised_trainer()
_test_create_supervised_trainer(gradient_accumulation_steps=1)
_test_create_supervised_trainer(gradient_accumulation_steps=3)
_test_create_mocked_supervised_trainer()


def test_create_supervised_trainer_with_cpu():
_test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
_test_create_supervised_trainer(trainer_device="cpu")
_test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu")
_test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu")
_test_create_mocked_supervised_trainer(trainer_device="cpu")


def test_create_supervised_trainer_traced_with_cpu():
_test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
_test_create_supervised_trainer(trainer_device="cpu", trace=True)
_test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu", trace=True)
_test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu", trace=True)
_test_create_mocked_supervised_trainer(trainer_device="cpu", trace=True)


@@ -412,7 +408,12 @@ def test_create_supervised_trainer_scaler_not_amp():
def test_create_supervised_trainer_on_cuda():
model_device = trainer_device = "cuda"
_test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
_test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
_test_create_supervised_trainer(
gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
)
_test_create_supervised_trainer(
gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
)
_test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


@@ -424,7 +425,10 @@ def test_create_supervised_trainer_on_cuda_amp():
model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
)
_test_create_supervised_trainer(
model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
)
_test_create_supervised_trainer(
gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="amp",
)
_test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")

@@ -436,17 +440,37 @@ def test_create_supervised_trainer_on_cuda_amp_scaler():
_test_create_supervised_trainer_wrong_accumulation(
model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
)

_test_create_supervised_trainer(
model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True,
gradient_accumulation_steps=1,
model_device=model_device,
trainer_device=trainer_device,
amp_mode="amp",
scaler=True,
)
_test_create_supervised_trainer(
gradient_accumulation_steps=3,
model_device=model_device,
trainer_device=trainer_device,
amp_mode="amp",
scaler=True,
)
_test_create_mocked_supervised_trainer(
model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True
)

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
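# repeat the same checks with an explicitly constructed GradScaler instance
# rather than scaler=True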
_test_create_supervised_trainer(
model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler,
gradient_accumulation_steps=1,
model_device=model_device,
trainer_device=trainer_device,
amp_mode="amp",
scaler=scaler,
)
_test_create_supervised_trainer(
gradient_accumulation_steps=3,
model_device=model_device,
trainer_device=trainer_device,
amp_mode="amp",
scaler=scaler,
)
_test_create_mocked_supervised_trainer(
model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler
@@ -460,11 +484,12 @@ def test_create_supervised_trainer_on_cuda_apex():
_test_create_supervised_trainer_wrong_accumulation(
model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
)

_test_create_supervised_trainer(
model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
)
_test_create_supervised_trainer(
gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="apex",
)

_test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="apex")


@@ -488,7 +513,12 @@ def test_create_supervised_trainer_on_tpu_no_xla():
def test_create_supervised_trainer_on_tpu():
model_device = trainer_device = "xla"
_test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
_test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
_test_create_supervised_trainer(
gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
)
_test_create_supervised_trainer(
gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
)
_test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


@@ -503,7 +533,8 @@ def test_create_supervised_trainer_on_tpu_amp():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_with_model_on_cpu():
_test_create_supervised_trainer_wrong_accumulation(trainer_device="cuda")
_test_create_supervised_trainer(trainer_device="cuda")
_test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cuda")
_test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cuda")
_test_create_mocked_supervised_trainer(trainer_device="cuda")

