42 changes: 36 additions & 6 deletions ignite/engine/__init__.py
@@ -524,7 +524,7 @@ def output_transform_fn(x, y, y_pred, loss):
.. versionchanged:: 0.4.7
Added Gradient Accumulation argument for all supervised training methods.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
Added ``model_transform`` to transform model's output
"""

device_type = device.type if isinstance(device, torch.device) else device
@@ -593,6 +593,7 @@ def supervised_evaluation_step(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
) -> Callable:
"""
@@ -606,6 +607,8 @@
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and converts it into the predictions:
``y_pred = model_transform(model(x))``.
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -624,13 +627,16 @@
The `model` should be moved by the user before creating an optimizer.

.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
"""

def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
model.eval()
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
output = model(x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

return evaluate_step
@@ -641,6 +647,7 @@ def supervised_evaluation_step_amp(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
) -> Callable:
"""
@@ -654,6 +661,8 @@
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and converts it into the predictions:
``y_pred = model_transform(model(x))``.
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -672,6 +681,8 @@
The `model` should be moved by the user before creating an optimizer.

.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
"""
try:
from torch.cuda.amp import autocast
@@ -683,7 +694,8 @@ def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, T
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
y_pred = model(x)
output = model(x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

return evaluate_step
@@ -711,6 +723,8 @@ def create_supervised_evaluator(
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and converts it into the predictions:
``y_pred = model_transform(model(x))``.
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -737,17 +751,33 @@
- `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

.. versionchanged:: 0.4.5
- Added ``amp_mode`` argument for automatic mixed precision.
Added ``amp_mode`` argument for automatic mixed precision.
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
"""
device_type = device.type if isinstance(device, torch.device) else device
on_tpu = "xla" in device_type if device_type is not None else False
mode, _ = _check_arg(on_tpu, amp_mode, None)

metrics = metrics or {}
if mode == "amp":
evaluate_step = supervised_evaluation_step_amp(model, device, non_blocking, prepare_batch, output_transform)
evaluate_step = supervised_evaluation_step_amp(
model,
device,
non_blocking=non_blocking,
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
)
else:
evaluate_step = supervised_evaluation_step(model, device, non_blocking, prepare_batch, output_transform)
evaluate_step = supervised_evaluation_step(
model,
device,
non_blocking=non_blocking,
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
)

evaluator = Engine(evaluate_step)

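The hunks above add a ``model_transform`` argument to both evaluation step factories and thread it through ``create_supervised_evaluator``. As an illustration only (not part of this diff), here is a minimal sketch of how the new argument can be used once it is available (ignite >= 0.4.12); the ``DictOutputModel`` class and the choice of metric are assumptions made for the example:

```python
import torch
import torch.nn as nn

from ignite.engine import create_supervised_evaluator
from ignite.metrics import MeanSquaredError


class DictOutputModel(nn.Module):
    """Hypothetical toy model that wraps its prediction in a dict."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        return {"y_pred": self.fc(x), "hidden": x}


model = DictOutputModel()

# model_transform extracts the prediction tensor from the raw model output
# before it reaches output_transform and the attached metrics.
evaluator = create_supervised_evaluator(
    model,
    metrics={"mse": MeanSquaredError()},
    model_transform=lambda output: output["y_pred"],
)

data = [(torch.tensor([[1.0], [2.0]]), torch.tensor([[3.0], [5.0]]))]
state = evaluator.run(data)
print(state.metrics["mse"])
```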
26 changes: 23 additions & 3 deletions tests/ignite/engine/test_create_supervised.py
@@ -220,8 +220,18 @@ def _default_create_supervised_evaluator(
evaluator_device: Optional[str] = None,
trace: bool = False,
amp_mode: str = None,
with_model_transform: bool = False,
):
model = DummyModel()
if with_model_transform:

def get_first_element(output):
return output[0]

model = DummyModel(output_as_list=True)
model_transform = get_first_element
else:
model = DummyModel()
model_transform = None

if model_device:
model.to(model_device)
@@ -232,7 +242,12 @@ def _default_create_supervised_evaluator(
example_input = torch.randn(1, 1)
model = torch.jit.trace(model, example_input)

evaluator = create_supervised_evaluator(model, device=evaluator_device, amp_mode=amp_mode)
evaluator = create_supervised_evaluator(
model,
device=evaluator_device,
amp_mode=amp_mode,
model_transform=model_transform if model_transform is not None else lambda x: x,
)

assert model.fc.weight.data[0, 0].item() == approx(0.0)

@@ -244,9 +259,14 @@ def _test_create_supervised_evaluator(
evaluator_device: Optional[str] = None,
trace: bool = False,
amp_mode: str = None,
with_model_transform: bool = False,
):
model, evaluator = _default_create_supervised_evaluator(
model_device=model_device, evaluator_device=evaluator_device, trace=trace, amp_mode=amp_mode
model_device=model_device,
evaluator_device=evaluator_device,
trace=trace,
amp_mode=amp_mode,
with_model_transform=with_model_transform,
)
x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[3.0], [5.0]])
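The test helper above relies on a ``DummyModel(output_as_list=True)`` fixture defined elsewhere in the test suite, outside this diff. Below is a hedged sketch of the assumed behaviour, showing how ``get_first_element`` plays the role of ``model_transform``; the class name, the zero-initialised weight, and the list-wrapping logic are assumptions for illustration:

```python
import torch
import torch.nn as nn


class DummyModelSketch(nn.Module):
    """Assumed shape of the test fixture: optionally wraps its output in a list."""

    def __init__(self, output_as_list: bool = False):
        super().__init__()
        self.output_as_list = output_as_list
        self.fc = nn.Linear(1, 1, bias=False)
        # Assumed zero init, mirroring the ``weight.data[0, 0] == approx(0.0)`` check above.
        nn.init.zeros_(self.fc.weight)

    def forward(self, x):
        y_pred = self.fc(x)
        return [y_pred] if self.output_as_list else y_pred


def get_first_element(output):
    # Used as model_transform: unwrap the prediction tensor from the list.
    return output[0]


model = DummyModelSketch(output_as_list=True)
x = torch.tensor([[1.0], [2.0]])
assert torch.equal(get_first_element(model(x)), model(x)[0])
```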