10
10
from datetime import datetime
11
11
from threading import Lock
12
12
from typing import Optional , Callable , NamedTuple , Dict , List
13
+ import math
13
14
14
15
import PIL .Image
15
16
import numpy as np
@@ -356,6 +357,12 @@ def step(self, current: int, total: int = None, size: int = 1):
356
357
# nothing to do
357
358
return
358
359
360
+ if current < self .job_step :
361
+ # it was reset, new epoch/iteration basically
362
+ self .job_step = 0
363
+
364
+ steps_made = current - self .job_step
365
+
359
366
self .job_step = current
360
367
if total is not None :
361
368
self .job_steps = total
@@ -371,8 +378,12 @@ def step(self, current: int, total: int = None, size: int = 1):
371
378
speed_per_second = size / (now - self .last_batch_time ) if self .last_batch_time else size
372
379
373
380
if self .last_batch_time :
381
+ time_per_step = step_since_last_took = (now - self .last_batch_time )
382
+ if steps_made > 0 :
383
+ time_per_step = step_since_last_took / steps_made
384
+
374
385
self .seconds_per_iterations .append ({
375
- 'diff' : ( now - self . last_batch_time ) * total ,
386
+ 'diff' : time_per_step * total ,
376
387
'when' : now
377
388
})
378
389
@@ -712,7 +723,12 @@ def log_metric(self, name: str, *y, x=None):
712
723
if not isinstance (y , (list , tuple )):
713
724
y = [y ]
714
725
715
- y = [float (v ) if v is not None else 0 for v in y ]
726
+ def convert (v ):
727
+ if v is None : return 0
728
+ if math .isnan (v ): return 0
729
+ return float (v )
730
+
731
+ y = [convert (v ) for v in y ]
716
732
717
733
if x is None :
718
734
if self .job_steps > 0 :
0 commit comments