Skip to content

Commit d3adb53

Browse files
committed
fix bug for num_iters in fit/evaluate
1 parent 2c94573 commit d3adb53

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

python/paddle/hapi/model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,8 @@ def fit(self,
17091709

17101710
steps = self._len_data_loader(train_loader)
17111711
self.num_iters = num_iters
1712-
if num_iters is not None and isinstance(num_iters, int):
1712+
if num_iters is not None and isinstance(num_iters, int) and isinstance(
1713+
steps, int):
17131714
assert num_iters > 0, "num_iters must be greater than 0!"
17141715
epochs = (num_iters // steps) + 1
17151716
steps = min(num_iters, steps)
@@ -1744,8 +1745,8 @@ def fit(self,
17441745
eval_logs = self._run_one_epoch(eval_loader, cbks, 'eval')
17451746

17461747
cbks.on_end('eval', eval_logs)
1747-
if self.stop_training:
1748-
break
1748+
if self.stop_training:
1749+
break
17491750

17501751
cbks.on_end('train', logs)
17511752
self._test_dataloader = None
@@ -1832,7 +1833,8 @@ def evaluate(self,
18321833

18331834
eval_steps = self._len_data_loader(eval_loader)
18341835
self.num_iters = num_iters
1835-
if num_iters is not None and isinstance(num_iters, int):
1836+
if num_iters is not None and isinstance(num_iters, int) and isinstance(
1837+
eval_steps, int):
18361838
assert num_iters > 0, "num_iters must be greater than 0!"
18371839
eval_steps = min(num_iters, eval_steps)
18381840
self.num_iters = eval_steps
@@ -2094,7 +2096,9 @@ def _run_one_epoch(
20942096
callbacks.on_batch_end(mode, step, logs)
20952097
if hasattr(self, 'num_iters') and self.num_iters is not None:
20962098
self.num_iters -= 1
2097-
if self.num_iters == 0:
2099+
if self.num_iters <= 0:
2100+
self.stop_training = True
2101+
del self.num_iters
20982102
break
20992103
self._reset_metrics()
21002104

0 commit comments

Comments
 (0)