Skip to content

Commit ae2e1e8

Browse files
committed
add iteration-based training
1 parent b99e576 commit ae2e1e8

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

trainer/trainer.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
from torchvision.utils import make_grid
44
from base import BaseTrainer
5+
from utils import inf_loop
56

67

78
class Trainer(BaseTrainer):
@@ -11,11 +12,18 @@ class Trainer(BaseTrainer):
1112
Note:
1213
Inherited from BaseTrainer.
1314
"""
14-
def __init__(self, model, loss, metrics, optimizer, config,
15-
data_loader, valid_data_loader=None, lr_scheduler=None):
15+
def __init__(self, model, loss, metrics, optimizer, config, data_loader,
16+
valid_data_loader=None, lr_scheduler=None, len_epoch=None):
1617
super().__init__(model, loss, metrics, optimizer, config)
1718
self.config = config
1819
self.data_loader = data_loader
20+
if len_epoch is None:
21+
# epoch-based training
22+
self.len_epoch = len(self.data_loader)
23+
else:
24+
# iteration-based training
25+
self.data_loader = inf_loop(data_loader)
26+
self.len_epoch = len_epoch
1927
self.valid_data_loader = valid_data_loader
2028
self.do_validation = self.valid_data_loader is not None
2129
self.lr_scheduler = lr_scheduler
@@ -57,28 +65,29 @@ def _train_epoch(self, epoch):
5765
loss.backward()
5866
self.optimizer.step()
5967

60-
self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)
68+
self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
6169
self.writer.add_scalar('loss', loss.item())
6270
total_loss += loss.item()
6371
total_metrics += self._eval_metrics(output, target)
6472

6573
if batch_idx % self.log_step == 0:
66-
self.logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
74+
self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
6775
epoch,
68-
batch_idx * self.data_loader.batch_size,
69-
self.data_loader.n_samples,
70-
100.0 * batch_idx / len(self.data_loader),
76+
self._progress(batch_idx),
7177
loss.item()))
7278
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
7379

80+
if batch_idx == self.len_epoch:
81+
break
82+
7483
log = {
75-
'loss': total_loss / len(self.data_loader),
76-
'metrics': (total_metrics / len(self.data_loader)).tolist()
84+
'loss': total_loss / self.len_epoch,
85+
'metrics': (total_metrics / self.len_epoch).tolist()
7786
}
7887

7988
if self.do_validation:
8089
val_log = self._valid_epoch(epoch)
81-
log = {**log, **val_log}
90+
log.update(val_log)
8291

8392
if self.lr_scheduler is not None:
8493
self.lr_scheduler.step()
@@ -118,3 +127,13 @@ def _valid_epoch(self, epoch):
118127
'val_loss': total_val_loss / len(self.valid_data_loader),
119128
'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
120129
}
130+
131+
def _progress(self, batch_idx):
132+
base = '[{}/{} ({:.0f}%)]'
133+
if hasattr(self.data_loader, 'n_samples'):
134+
current = batch_idx * self.data_loader.batch_size
135+
total = self.data_loader.n_samples
136+
else:
137+
current = batch_idx
138+
total = self.len_epoch
139+
return base.format(current, total, 100.0 * current / total)

utils/util.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
from pathlib import Path
33
from datetime import datetime
4+
from itertools import repeat
45
from collections import OrderedDict
56

67

@@ -17,6 +18,14 @@ def write_json(content, fname):
1718
with fname.open('wt') as handle:
1819
json.dump(content, handle, indent=4, sort_keys=False)
1920

21+
def inf_loop(data_loader):
    """Wrap a re-iterable data loader so it can be iterated endlessly.

    Used for iteration-based training, where the trainer draws a fixed
    number of batches per "epoch" regardless of the loader's length.

    :param data_loader: any re-iterable (e.g. a PyTorch ``DataLoader``);
        it is restarted from the beginning each time it is exhausted.
    :yields: batches from ``data_loader``, forwarded unchanged.
    """
    # ``yield from`` forwards each batch as-is, so loaders whose batches
    # are not (data, target) pairs are supported too (the original
    # ``for data, target in loader`` unpacking required 2-tuples).
    for loader in repeat(data_loader):
        yield from loader
28+
2029
class Timer:
2130
def __init__(self):
2231
self.cache = datetime.now()
@@ -28,4 +37,5 @@ def check(self):
2837
return duration.total_seconds()
2938

3039
def reset(self):
    """Reset the timer's reference point to the current moment.

    Subsequent ``check()`` calls measure elapsed time from here.
    """
    self.cache = datetime.now()
41+

0 commit comments

Comments
 (0)