@@ -1709,7 +1709,8 @@ def fit(self,
17091709
17101710 steps = self ._len_data_loader (train_loader )
17111711 self .num_iters = num_iters
1712- if num_iters is not None and isinstance (num_iters , int ):
1712+ if num_iters is not None and isinstance (num_iters , int ) and isinstance (
1713+ steps , int ):
17131714 assert num_iters > 0 , "num_iters must be greater than 0!"
17141715 epochs = (num_iters // steps ) + 1
17151716 steps = min (num_iters , steps )
@@ -1744,8 +1745,8 @@ def fit(self,
17441745 eval_logs = self ._run_one_epoch (eval_loader , cbks , 'eval' )
17451746
17461747 cbks .on_end ('eval' , eval_logs )
1747- if self .stop_training :
1748- break
1748+ if self .stop_training :
1749+ break
17491750
17501751 cbks .on_end ('train' , logs )
17511752 self ._test_dataloader = None
@@ -1832,7 +1833,8 @@ def evaluate(self,
18321833
18331834 eval_steps = self ._len_data_loader (eval_loader )
18341835 self .num_iters = num_iters
1835- if num_iters is not None and isinstance (num_iters , int ):
1836+ if num_iters is not None and isinstance (num_iters , int ) and isinstance (
1837+ eval_steps , int ):
18361838 assert num_iters > 0 , "num_iters must be greater than 0!"
18371839 eval_steps = min (num_iters , eval_steps )
18381840 self .num_iters = eval_steps
@@ -2094,7 +2096,9 @@ def _run_one_epoch(
20942096 callbacks .on_batch_end (mode , step , logs )
20952097 if hasattr (self , 'num_iters' ) and self .num_iters is not None :
20962098 self .num_iters -= 1
2097- if self .num_iters == 0 :
2099+ if self .num_iters <= 0 :
2100+ self .stop_training = True
2101+ del self .num_iters
20982102 break
20992103 self ._reset_metrics ()
21002104
0 commit comments