Skip to content

Commit 084b8c3

Browse files
committed
fixed ETA in steps calculation, fixed sending NaN
bumped version
1 parent 06ed3b4 commit 084b8c3

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

deepkit/experiment.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from datetime import datetime
1111
from threading import Lock
1212
from typing import Optional, Callable, NamedTuple, Dict, List
13+
import math
1314

1415
import PIL.Image
1516
import numpy as np
@@ -356,6 +357,12 @@ def step(self, current: int, total: int = None, size: int = 1):
356357
# nothing to do
357358
return
358359

360+
if current < self.job_step:
361+
# it was reset, new epoch/iteration basically
362+
self.job_step = 0
363+
364+
steps_made = current - self.job_step
365+
359366
self.job_step = current
360367
if total is not None:
361368
self.job_steps = total
@@ -371,8 +378,12 @@ def step(self, current: int, total: int = None, size: int = 1):
371378
speed_per_second = size / (now - self.last_batch_time) if self.last_batch_time else size
372379

373380
if self.last_batch_time:
381+
time_per_step = step_since_last_took = (now - self.last_batch_time)
382+
if steps_made > 0:
383+
time_per_step = step_since_last_took / steps_made
384+
374385
self.seconds_per_iterations.append({
375-
'diff': (now - self.last_batch_time) * total,
386+
'diff': time_per_step * total,
376387
'when': now
377388
})
378389

@@ -712,7 +723,12 @@ def log_metric(self, name: str, *y, x=None):
712723
if not isinstance(y, (list, tuple)):
713724
y = [y]
714725

715-
y = [float(v) if v is not None else 0 for v in y]
726+
def convert(v):
727+
if v is None: return 0
728+
if math.isnan(v): return 0
729+
return float(v)
730+
731+
y = [convert(v) for v in y]
716732

717733
if x is None:
718734
if self.job_steps > 0:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import setup
22
from setuptools import find_packages
3-
__version__ = '1.0.8'
3+
__version__ = '1.0.9'
44

55
setup(name='deepkit',
66
version=__version__,

0 commit comments

Comments
 (0)