@@ -524,7 +524,7 @@ def output_transform_fn(x, y, y_pred, loss):
524
524
.. versionchanged:: 0.4.7
525
525
Added Gradient Accumulation argument for all supervised training methods.
526
526
.. versionchanged:: 0.4.11
527
- Added `model_transform` to transform model's output
527
+ Added ``model_transform`` to transform model's output
528
528
"""
529
529
530
530
device_type = device .type if isinstance (device , torch .device ) else device
@@ -593,6 +593,7 @@ def supervised_evaluation_step(
593
593
device : Optional [Union [str , torch .device ]] = None ,
594
594
non_blocking : bool = False ,
595
595
prepare_batch : Callable = _prepare_batch ,
596
+ model_transform : Callable [[Any ], Any ] = lambda output : output ,
596
597
output_transform : Callable [[Any , Any , Any ], Any ] = lambda x , y , y_pred : (y_pred , y ),
597
598
) -> Callable :
598
599
"""
@@ -606,6 +607,8 @@ def supervised_evaluation_step(
606
607
with respect to the host. For other cases, this argument has no effect.
607
608
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
608
609
tuple of tensors `(batch_x, batch_y)`.
610
+ model_transform: function that receives the output from the model and converts it into the predictions:
611
+ ``y_pred = model_transform(model(x))``.
609
612
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
610
613
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
611
614
output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -624,13 +627,16 @@ def supervised_evaluation_step(
624
627
The `model` should be moved by the user before creating an optimizer.
625
628
626
629
.. versionadded:: 0.4.5
630
+ .. versionchanged:: 0.4.12
631
+ Added ``model_transform`` to transform model's output
627
632
"""
628
633
629
634
def evaluate_step (engine : Engine , batch : Sequence [torch .Tensor ]) -> Union [Any , Tuple [torch .Tensor ]]:
630
635
model .eval ()
631
636
with torch .no_grad ():
632
637
x , y = prepare_batch (batch , device = device , non_blocking = non_blocking )
633
- y_pred = model (x )
638
+ output = model (x )
639
+ y_pred = model_transform (output )
634
640
return output_transform (x , y , y_pred )
635
641
636
642
return evaluate_step
@@ -641,6 +647,7 @@ def supervised_evaluation_step_amp(
641
647
device : Optional [Union [str , torch .device ]] = None ,
642
648
non_blocking : bool = False ,
643
649
prepare_batch : Callable = _prepare_batch ,
650
+ model_transform : Callable [[Any ], Any ] = lambda output : output ,
644
651
output_transform : Callable [[Any , Any , Any ], Any ] = lambda x , y , y_pred : (y_pred , y ),
645
652
) -> Callable :
646
653
"""
@@ -654,6 +661,8 @@ def supervised_evaluation_step_amp(
654
661
with respect to the host. For other cases, this argument has no effect.
655
662
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
656
663
tuple of tensors `(batch_x, batch_y)`.
664
+ model_transform: function that receives the output from the model and converts it into the predictions:
665
+ ``y_pred = model_transform(model(x))``.
657
666
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
658
667
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
659
668
output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -672,6 +681,8 @@ def supervised_evaluation_step_amp(
672
681
The `model` should be moved by the user before creating an optimizer.
673
682
674
683
.. versionadded:: 0.4.5
684
+ .. versionchanged:: 0.4.12
685
+ Added ``model_transform`` to transform model's output
675
686
"""
676
687
try :
677
688
from torch .cuda .amp import autocast
@@ -683,7 +694,8 @@ def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, T
683
694
with torch .no_grad ():
684
695
x , y = prepare_batch (batch , device = device , non_blocking = non_blocking )
685
696
with autocast (enabled = True ):
686
- y_pred = model (x )
697
+ output = model (x )
698
+ y_pred = model_transform (output )
687
699
return output_transform (x , y , y_pred )
688
700
689
701
return evaluate_step
@@ -711,6 +723,8 @@ def create_supervised_evaluator(
711
723
with respect to the host. For other cases, this argument has no effect.
712
724
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
713
725
tuple of tensors `(batch_x, batch_y)`.
726
+ model_transform: function that receives the output from the model and converts it into the predictions:
727
+ ``y_pred = model_transform(model(x))``.
714
728
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
715
729
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
716
730
output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -737,17 +751,33 @@ def create_supervised_evaluator(
737
751
- `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_
738
752
739
753
.. versionchanged:: 0.4.5
740
- - Added ``amp_mode`` argument for automatic mixed precision.
754
+ Added ``amp_mode`` argument for automatic mixed precision.
755
+ .. versionchanged:: 0.4.12
756
+ Added ``model_transform`` to transform model's output
741
757
"""
742
758
device_type = device .type if isinstance (device , torch .device ) else device
743
759
on_tpu = "xla" in device_type if device_type is not None else False
744
760
mode , _ = _check_arg (on_tpu , amp_mode , None )
745
761
746
762
metrics = metrics or {}
747
763
if mode == "amp" :
748
- evaluate_step = supervised_evaluation_step_amp (model , device , non_blocking , prepare_batch , output_transform )
764
+ evaluate_step = supervised_evaluation_step_amp (
765
+ model ,
766
+ device ,
767
+ non_blocking = non_blocking ,
768
+ prepare_batch = prepare_batch ,
769
+ model_transform = model_transform ,
770
+ output_transform = output_transform ,
771
+ )
749
772
else :
750
- evaluate_step = supervised_evaluation_step (model , device , non_blocking , prepare_batch , output_transform )
773
+ evaluate_step = supervised_evaluation_step (
774
+ model ,
775
+ device ,
776
+ non_blocking = non_blocking ,
777
+ prepare_batch = prepare_batch ,
778
+ model_transform = model_transform ,
779
+ output_transform = output_transform ,
780
+ )
751
781
752
782
evaluator = Engine (evaluate_step )
753
783
0 commit comments