Commit 5ccf0d7

Reintroduced model_transform into supervised_evaluation (#2896)
Fixes #2894
1 parent 65ac304 commit 5ccf0d7
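
For context, a minimal usage sketch of the reintroduced argument (not part of the commit; the toy TupleModel, the random data, and the metric choice are illustrative assumptions — only create_supervised_evaluator and MeanSquaredError come from ignite):

import torch
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import create_supervised_evaluator
from ignite.metrics import MeanSquaredError


class TupleModel(torch.nn.Module):
    """Hypothetical model whose forward returns more than just the predictions."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.fc(x)
        return y_pred, {"features": x}  # extra output that the metrics should not see


data = DataLoader(TensorDataset(torch.randn(8, 1), torch.randn(8, 1)), batch_size=4)

# model_transform extracts y_pred from the raw model output: y_pred = model_transform(model(x))
evaluator = create_supervised_evaluator(
    TupleModel(),
    metrics={"mse": MeanSquaredError()},
    model_transform=lambda output: output[0],
)
state = evaluator.run(data)
print(state.metrics["mse"])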

File tree

2 files changed: +59 -9 lines changed

ignite/engine/__init__.py

Lines changed: 36 additions & 6 deletions
@@ -524,7 +524,7 @@ def output_transform_fn(x, y, y_pred, loss):
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation argument for all supervised training methods.
     .. versionchanged:: 0.4.11
-        Added `model_transform` to transform model's output
+        Added ``model_transform`` to transform model's output
     """

     device_type = device.type if isinstance(device, torch.device) else device
@@ -593,6 +593,7 @@ def supervised_evaluation_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
@@ -606,6 +607,8 @@ def supervised_evaluation_step(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the predictions:
+            ``y_pred = model_transform(model(x))``.
         output_transform: function that receives 'x', 'y', 'y_pred' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
             output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -624,13 +627,16 @@ def supervised_evaluation_step(
         The `model` should be moved by the user before creating an optimizer.

     .. versionadded:: 0.4.5
+    .. versionchanged:: 0.4.12
+        Added ``model_transform`` to transform model's output
     """

     def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
         model.eval()
         with torch.no_grad():
             x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-            y_pred = model(x)
+            output = model(x)
+            y_pred = model_transform(output)
             return output_transform(x, y, y_pred)

     return evaluate_step
@@ -641,6 +647,7 @@ def supervised_evaluation_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
@@ -654,6 +661,8 @@ def supervised_evaluation_step_amp(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the predictions:
+            ``y_pred = model_transform(model(x))``.
         output_transform: function that receives 'x', 'y', 'y_pred' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
             output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -672,6 +681,8 @@ def supervised_evaluation_step_amp(
         The `model` should be moved by the user before creating an optimizer.

     .. versionadded:: 0.4.5
+    .. versionchanged:: 0.4.12
+        Added ``model_transform`` to transform model's output
     """
     try:
         from torch.cuda.amp import autocast
@@ -683,7 +694,8 @@ def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, T
         with torch.no_grad():
             x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
             with autocast(enabled=True):
-                y_pred = model(x)
+                output = model(x)
+                y_pred = model_transform(output)
             return output_transform(x, y, y_pred)

     return evaluate_step
@@ -711,6 +723,8 @@ def create_supervised_evaluator(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and convert it into the predictions:
+            ``y_pred = model_transform(model(x))``.
         output_transform: function that receives 'x', 'y', 'y_pred' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
             output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -737,17 +751,33 @@ def create_supervised_evaluator(
         - `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

     .. versionchanged:: 0.4.5
-        - Added ``amp_mode`` argument for automatic mixed precision.
+        Added ``amp_mode`` argument for automatic mixed precision.
+    .. versionchanged:: 0.4.12
+        Added ``model_transform`` to transform model's output
     """
     device_type = device.type if isinstance(device, torch.device) else device
     on_tpu = "xla" in device_type if device_type is not None else False
     mode, _ = _check_arg(on_tpu, amp_mode, None)

     metrics = metrics or {}
     if mode == "amp":
-        evaluate_step = supervised_evaluation_step_amp(model, device, non_blocking, prepare_batch, output_transform)
+        evaluate_step = supervised_evaluation_step_amp(
+            model,
+            device,
+            non_blocking=non_blocking,
+            prepare_batch=prepare_batch,
+            model_transform=model_transform,
+            output_transform=output_transform,
+        )
     else:
-        evaluate_step = supervised_evaluation_step(model, device, non_blocking, prepare_batch, output_transform)
+        evaluate_step = supervised_evaluation_step(
+            model,
+            device,
+            non_blocking=non_blocking,
+            prepare_batch=prepare_batch,
+            model_transform=model_transform,
+            output_transform=output_transform,
+        )

     evaluator = Engine(evaluate_step)

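The step factories above can also be used on their own. A brief sketch (the toy torch.nn.Linear model and the dict-style output_transform are assumptions, not taken from the commit) of wiring supervised_evaluation_step with model_transform into a bare Engine, mirroring the evaluate_step shown in the diff:

import torch
from ignite.engine import Engine, supervised_evaluation_step

model = torch.nn.Linear(1, 1)  # assumed toy model
step = supervised_evaluation_step(
    model,
    model_transform=lambda output: output,  # default shown in the diff: pass the output through unchanged
    output_transform=lambda x, y, y_pred: {"y_pred": y_pred, "y": y},
)
evaluator = Engine(step)  # same wiring create_supervised_evaluator performs internally
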
tests/ignite/engine/test_create_supervised.py

Lines changed: 23 additions & 3 deletions
@@ -220,8 +220,18 @@ def _default_create_supervised_evaluator(
     evaluator_device: Optional[str] = None,
     trace: bool = False,
     amp_mode: str = None,
+    with_model_transform: bool = False,
 ):
-    model = DummyModel()
+    if with_model_transform:
+
+        def get_first_element(output):
+            return output[0]
+
+        model = DummyModel(output_as_list=True)
+        model_transform = get_first_element
+    else:
+        model = DummyModel()
+        model_transform = None

     if model_device:
         model.to(model_device)
@@ -232,7 +242,12 @@ def _default_create_supervised_evaluator(
         example_input = torch.randn(1, 1)
         model = torch.jit.trace(model, example_input)

-    evaluator = create_supervised_evaluator(model, device=evaluator_device, amp_mode=amp_mode)
+    evaluator = create_supervised_evaluator(
+        model,
+        device=evaluator_device,
+        amp_mode=amp_mode,
+        model_transform=model_transform if model_transform is not None else lambda x: x,
+    )

     assert model.fc.weight.data[0, 0].item() == approx(0.0)

@@ -244,9 +259,14 @@ def _test_create_supervised_evaluator(
     evaluator_device: Optional[str] = None,
     trace: bool = False,
     amp_mode: str = None,
+    with_model_transform: bool = False,
 ):
     model, evaluator = _default_create_supervised_evaluator(
-        model_device=model_device, evaluator_device=evaluator_device, trace=trace, amp_mode=amp_mode
+        model_device=model_device,
+        evaluator_device=evaluator_device,
+        trace=trace,
+        amp_mode=amp_mode,
+        with_model_transform=with_model_transform,
     )
     x = torch.tensor([[1.0], [2.0]])
     y = torch.tensor([[3.0], [5.0]])
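
A standalone illustration of the behavior the new with_model_transform path exercises (ListOutputModel below is a hypothetical stand-in for the test suite's DummyModel(output_as_list=True), which is not shown in this diff): the model wraps its prediction in a list, and model_transform unwraps it before the default output_transform returns (y_pred, y):

import torch
from ignite.engine import create_supervised_evaluator


class ListOutputModel(torch.nn.Module):
    """Hypothetical stand-in for DummyModel(output_as_list=True)."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(1, 1, bias=False)
        self.fc.weight.data.zero_()

    def forward(self, x):
        return [self.fc(x)]  # prediction wrapped in a list


def get_first_element(output):
    return output[0]


evaluator = create_supervised_evaluator(ListOutputModel(), model_transform=get_first_element)
x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[3.0], [5.0]])
state = evaluator.run([(x, y)])
y_pred, y_out = state.output  # default output_transform returns (y_pred, y)
assert torch.equal(y_pred, torch.zeros(2, 1))  # zero weights, no bias -> zero predictions
assert torch.equal(y_out, y)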
