Skip to content

Commit 635e054

Browse files
Fix a bug and do a refactor in test_metric
and add test for SingleEpochRunningBatchWise in test_metric
1 parent 368b170 commit 635e054

File tree

1 file changed

+52
-56
lines changed

1 file changed

+52
-56
lines changed

tests/ignite/metrics/test_metric.py

Lines changed: 52 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Metric,
1919
reinit__is_reduced,
2020
RunningBatchWise,
21+
SingleEpochRunningBatchWise,
2122
RunningEpochWise,
2223
sync_all_reduce,
2324
)
@@ -855,25 +856,26 @@ def test_usage_exception():
855856
m.attach(engine, "dummy", usage="fake")
856857

857858

858-
def test_epochwise_usage():
859-
class MyMetric(Metric):
860-
def __init__(self):
861-
super(MyMetric, self).__init__()
862-
self.value = []
859+
class DummyAccumulateInListMetric(Metric):
860+
def __init__(self):
861+
super(DummyAccumulateInListMetric, self).__init__()
862+
self.value = []
863863

864-
def reset(self):
865-
self.value = []
864+
def reset(self):
865+
self.value = []
866866

867-
def compute(self):
868-
return self.value
867+
def compute(self):
868+
return self.value
869+
870+
def update(self, output):
871+
self.value.append(output)
869872

870-
def update(self, output):
871-
self.value.append(output)
872873

874+
def test_epochwise_usage():
873875
def test(usage):
874876
engine = Engine(lambda e, b: b)
875877

876-
m = MyMetric()
878+
m = DummyAccumulateInListMetric()
877879

878880
m.attach(engine, "ewm", usage=usage)
879881

@@ -891,20 +893,22 @@ def _():
891893
test(EpochWise())
892894

893895

894-
def test_running_epochwise_usage():
895-
class MyMetric(Metric):
896-
def __init__(self):
897-
super(MyMetric, self).__init__()
898-
self.value = 0
896+
class DummyAccumulateMetric(Metric):
897+
def __init__(self):
898+
super(DummyAccumulateMetric, self).__init__()
899+
self.value = 0
899900

900-
def reset(self):
901-
self.value = 0
901+
def reset(self):
902+
self.value = 0
902903

903-
def compute(self):
904-
return self.value
904+
def compute(self):
905+
return self.value
905906

906-
def update(self, output):
907-
self.value += output
907+
def update(self, output):
908+
self.value += output
909+
910+
911+
def test_running_epochwise_usage():
908912

909913
def test(usage):
910914
engine = Engine(lambda e, b: e.state.metrics["ewm"])
@@ -915,12 +919,12 @@ def test(usage):
915919
def _():
916920
engine.state.metrics["ewm"] += 1
917921

918-
m = MyMetric()
922+
m = DummyAccumulateMetric()
919923
m.attach(engine, "rewm", usage=usage)
920924

921925
@engine.on(Events.EPOCH_COMPLETED)
922926
def _():
923-
assert engine.state.metrics["rewm"] == sum(range(engine.state.epoch + 1))
927+
assert engine.state.metrics["rewm"] == 3 * sum(range(engine.state.epoch + 1))
924928

925929
engine.run([0, 1, 2], max_epochs=10)
926930

@@ -932,24 +936,10 @@ def _():
932936

933937

934938
def test_batchwise_usage():
935-
class MyMetric(Metric):
936-
def __init__(self):
937-
super(MyMetric, self).__init__()
938-
self.value = []
939-
940-
def reset(self):
941-
self.value = []
942-
943-
def compute(self):
944-
return self.value
945-
946-
def update(self, output):
947-
self.value.append(output)
948-
949939
def test(usage):
950940
engine = Engine(lambda e, b: b)
951941

952-
m = MyMetric()
942+
m = DummyAccumulateInListMetric()
953943

954944
m.attach(engine, "bwm", usage=usage)
955945

@@ -968,37 +958,43 @@ def _():
968958

969959

970960
def test_running_batchwise_usage():
971-
class MyMetric(Metric):
972-
def __init__(self):
973-
super(MyMetric, self).__init__()
974-
self.value = 0
961+
def test(usage):
962+
engine = Engine(lambda e, b: b)
975963

976-
def reset(self):
977-
self.value = 0
964+
m = DummyAccumulateMetric()
965+
m.attach(engine, "rbwm", usage=usage)
978966

979-
def compute(self):
980-
return self.value
967+
@engine.on(Events.EPOCH_COMPLETED)
968+
def _():
969+
assert engine.state.metrics["rbwm"] == 6 * engine.state.epoch
970+
971+
engine.run([0, 1, 2, 3], max_epochs=10)
972+
973+
m.detach(engine, usage=usage)
974+
975+
test("running_batch_wise")
976+
test(RunningBatchWise.usage_name)
977+
test(RunningBatchWise())
981978

982-
def update(self, output):
983-
self.value += output
984979

980+
def test_single_epoch_running_batchwise_usage():
985981
def test(usage):
986982
engine = Engine(lambda e, b: b)
987983

988-
m = MyMetric()
984+
m = DummyAccumulateMetric()
989985
m.attach(engine, "rbwm", usage=usage)
990986

991987
@engine.on(Events.EPOCH_COMPLETED)
992988
def _():
993-
assert engine.state.metrics["rbwm"] == 3 * engine.state.epoch
989+
assert engine.state.metrics["rbwm"] == 6
994990

995-
engine.run([0, 1, 2], max_epochs=10)
991+
engine.run([0, 1, 2, 3], max_epochs=10)
996992

997993
m.detach(engine, usage=usage)
998994

999-
test("running_batch_wise")
1000-
test(RunningBatchWise.usage_name)
1001-
test(RunningBatchWise())
995+
test("single_epoch_running_batch_wise")
996+
test(SingleEpochRunningBatchWise.usage_name)
997+
test(SingleEpochRunningBatchWise())
1002998

1003999

10041000
def test_batchfiltered_usage():

0 commit comments

Comments
 (0)