18
18
Metric ,
19
19
reinit__is_reduced ,
20
20
RunningBatchWise ,
21
+ SingleEpochRunningBatchWise ,
21
22
RunningEpochWise ,
22
23
sync_all_reduce ,
23
24
)
@@ -855,25 +856,26 @@ def test_usage_exception():
855
856
m .attach (engine , "dummy" , usage = "fake" )
856
857
857
858
858
- def test_epochwise_usage ():
859
- class MyMetric (Metric ):
860
- def __init__ (self ):
861
- super (MyMetric , self ).__init__ ()
862
- self .value = []
859
+ class DummyAccumulateInListMetric (Metric ):
860
+ def __init__ (self ):
861
+ super (DummyAccumulateInListMetric , self ).__init__ ()
862
+ self .value = []
863
863
864
- def reset (self ):
865
- self .value = []
864
+ def reset (self ):
865
+ self .value = []
866
866
867
- def compute (self ):
868
- return self .value
867
+ def compute (self ):
868
+ return self .value
869
+
870
+ def update (self , output ):
871
+ self .value .append (output )
869
872
870
- def update (self , output ):
871
- self .value .append (output )
872
873
874
+ def test_epochwise_usage ():
873
875
def test (usage ):
874
876
engine = Engine (lambda e , b : b )
875
877
876
- m = MyMetric ()
878
+ m = DummyAccumulateInListMetric ()
877
879
878
880
m .attach (engine , "ewm" , usage = usage )
879
881
@@ -891,20 +893,22 @@ def _():
891
893
test (EpochWise ())
892
894
893
895
894
- def test_running_epochwise_usage ():
895
- class MyMetric (Metric ):
896
- def __init__ (self ):
897
- super (MyMetric , self ).__init__ ()
898
- self .value = 0
896
+ class DummyAccumulateMetric (Metric ):
897
+ def __init__ (self ):
898
+ super (DummyAccumulateMetric , self ).__init__ ()
899
+ self .value = 0
899
900
900
- def reset (self ):
901
- self .value = 0
901
+ def reset (self ):
902
+ self .value = 0
902
903
903
- def compute (self ):
904
- return self .value
904
+ def compute (self ):
905
+ return self .value
905
906
906
- def update (self , output ):
907
- self .value += output
907
+ def update (self , output ):
908
+ self .value += output
909
+
910
+
911
+ def test_running_epochwise_usage ():
908
912
909
913
def test (usage ):
910
914
engine = Engine (lambda e , b : e .state .metrics ["ewm" ])
@@ -915,12 +919,12 @@ def test(usage):
915
919
def _ ():
916
920
engine .state .metrics ["ewm" ] += 1
917
921
918
- m = MyMetric ()
922
+ m = DummyAccumulateMetric ()
919
923
m .attach (engine , "rewm" , usage = usage )
920
924
921
925
@engine .on (Events .EPOCH_COMPLETED )
922
926
def _ ():
923
- assert engine .state .metrics ["rewm" ] == sum (range (engine .state .epoch + 1 ))
927
+ assert engine .state .metrics ["rewm" ] == 3 * sum (range (engine .state .epoch + 1 ))
924
928
925
929
engine .run ([0 , 1 , 2 ], max_epochs = 10 )
926
930
@@ -932,24 +936,10 @@ def _():
932
936
933
937
934
938
def test_batchwise_usage ():
935
- class MyMetric (Metric ):
936
- def __init__ (self ):
937
- super (MyMetric , self ).__init__ ()
938
- self .value = []
939
-
940
- def reset (self ):
941
- self .value = []
942
-
943
- def compute (self ):
944
- return self .value
945
-
946
- def update (self , output ):
947
- self .value .append (output )
948
-
949
939
def test (usage ):
950
940
engine = Engine (lambda e , b : b )
951
941
952
- m = MyMetric ()
942
+ m = DummyAccumulateInListMetric ()
953
943
954
944
m .attach (engine , "bwm" , usage = usage )
955
945
@@ -968,37 +958,43 @@ def _():
968
958
969
959
970
960
def test_running_batchwise_usage ():
971
- class MyMetric (Metric ):
972
- def __init__ (self ):
973
- super (MyMetric , self ).__init__ ()
974
- self .value = 0
961
+ def test (usage ):
962
+ engine = Engine (lambda e , b : b )
975
963
976
- def reset ( self ):
977
- self . value = 0
964
+ m = DummyAccumulateMetric ()
965
+ m . attach ( engine , "rbwm" , usage = usage )
978
966
979
- def compute (self ):
980
- return self .value
967
+ @engine .on (Events .EPOCH_COMPLETED )
968
+ def _ ():
969
+ assert engine .state .metrics ["rbwm" ] == 6 * engine .state .epoch
970
+
971
+ engine .run ([0 , 1 , 2 , 3 ], max_epochs = 10 )
972
+
973
+ m .detach (engine , usage = usage )
974
+
975
+ test ("running_batch_wise" )
976
+ test (RunningBatchWise .usage_name )
977
+ test (RunningBatchWise ())
981
978
982
- def update (self , output ):
983
- self .value += output
984
979
980
+ def test_single_epoch_running_batchwise_usage ():
985
981
def test (usage ):
986
982
engine = Engine (lambda e , b : b )
987
983
988
- m = MyMetric ()
984
+ m = DummyAccumulateMetric ()
989
985
m .attach (engine , "rbwm" , usage = usage )
990
986
991
987
@engine .on (Events .EPOCH_COMPLETED )
992
988
def _ ():
993
- assert engine .state .metrics ["rbwm" ] == 3 * engine . state . epoch
989
+ assert engine .state .metrics ["rbwm" ] == 6
994
990
995
- engine .run ([0 , 1 , 2 ], max_epochs = 10 )
991
+ engine .run ([0 , 1 , 2 , 3 ], max_epochs = 10 )
996
992
997
993
m .detach (engine , usage = usage )
998
994
999
- test ("running_batch_wise " )
1000
- test (RunningBatchWise .usage_name )
1001
- test (RunningBatchWise ())
995
+ test ("single_epoch_running_batch_wise " )
996
+ test (SingleEpochRunningBatchWise .usage_name )
997
+ test (SingleEpochRunningBatchWise ())
1002
998
1003
999
1004
1000
def test_batchfiltered_usage ():
0 commit comments