Skip to content

Commit 1dd7abe

Browse files
authored
Merge pull request #112 from JetBrains-Research/reset-metrics
Reset metrics after each epoch
2 parents be05daf + f1a11b8 commit 1dd7abe

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

code2seq/model/code2class.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def _shared_epoch_end(self, outputs: EPOCH_OUTPUT, step: str):
8484
mean_loss = torch.stack([out[f"{step}/loss"] for out in outputs]).mean()
8585
accuracy = self.__metrics[f"{step}_acc"].compute()
8686
log = {f"{step}/loss": mean_loss, f"{step}/accuracy": accuracy}
87+
self.__metrics[f"{step}_acc"].reset()
8788
self.log_dict(log, on_step=False, on_epoch=True)
8889

8990
def training_epoch_end(self, outputs: EPOCH_OUTPUT):

code2seq/model/code2seq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
140140
f"{step}/precision": metric.precision,
141141
f"{step}/recall": metric.recall,
142142
}
143+
self.__metrics[f"{step}_f1"].reset()
143144
self.log_dict(log, on_step=False, on_epoch=True)
144145

145146
def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import setup, find_packages
22

3-
VERSION = "1.0.2"
3+
VERSION = "1.0.3"
44

55
with open("README.md") as readme_file:
66
readme = readme_file.read()

0 commit comments

Comments
 (0)