@@ -2,6 +2,7 @@
 import torch
 from torchvision.utils import make_grid
 from base import BaseTrainer
+from utils import inf_loop


 class Trainer(BaseTrainer):
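
For context: the `inf_loop` helper imported above wraps a finite DataLoader into an endless iterator, so that iteration-based training can keep drawing batches past the end of the dataset. A minimal sketch of such a helper (the actual implementation in `utils` may differ in detail):

    from itertools import repeat

    def inf_loop(data_loader):
        # Restart the loader every time it is exhausted, yielding batches forever.
        for loader in repeat(data_loader):
            yield from loader
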
@@ -11,11 +12,18 @@ class Trainer(BaseTrainer):
     Note:
         Inherited from BaseTrainer.
     """
-    def __init__(self, model, loss, metrics, optimizer, config,
-                 data_loader, valid_data_loader=None, lr_scheduler=None):
+    def __init__(self, model, loss, metrics, optimizer, config, data_loader,
+                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
         super().__init__(model, loss, metrics, optimizer, config)
         self.config = config
         self.data_loader = data_loader
+        if len_epoch is None:
+            # epoch-based training
+            self.len_epoch = len(self.data_loader)
+        else:
+            # iteration-based training
+            self.data_loader = inf_loop(data_loader)
+            self.len_epoch = len_epoch
         self.valid_data_loader = valid_data_loader
         self.do_validation = self.valid_data_loader is not None
         self.lr_scheduler = lr_scheduler
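
With `len_epoch` added, the same constructor covers both training modes. A hypothetical call site (`model`, `loss`, `metrics`, `optimizer`, `config`, and `train_loader` are placeholder names, not part of this diff):

    # epoch-based: one epoch is one full pass over train_loader
    trainer = Trainer(model, loss, metrics, optimizer, config, train_loader)

    # iteration-based: one "epoch" is a fixed budget of 100 batches
    # drawn from the endless inf_loop-wrapped loader
    trainer = Trainer(model, loss, metrics, optimizer, config, train_loader,
                      len_epoch=100)
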
@@ -57,28 +65,29 @@ def _train_epoch(self, epoch):
             loss.backward()
             self.optimizer.step()

-            self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)
+            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
             self.writer.add_scalar('loss', loss.item())
             total_loss += loss.item()
             total_metrics += self._eval_metrics(output, target)

             if batch_idx % self.log_step == 0:
-                self.logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
+                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                     epoch,
-                    batch_idx * self.data_loader.batch_size,
-                    self.data_loader.n_samples,
-                    100.0 * batch_idx / len(self.data_loader),
+                    self._progress(batch_idx),
                     loss.item()))
                 self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

+            if batch_idx == self.len_epoch:
+                break
+
         log = {
-            'loss': total_loss / len(self.data_loader),
-            'metrics': (total_metrics / len(self.data_loader)).tolist()
+            'loss': total_loss / self.len_epoch,
+            'metrics': (total_metrics / self.len_epoch).tolist()
         }

         if self.do_validation:
             val_log = self._valid_epoch(epoch)
-            log = {**log, **val_log}
+            log.update(val_log)

         if self.lr_scheduler is not None:
             self.lr_scheduler.step()
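
Note that computing the TensorBoard step from `self.len_epoch` keeps the global step monotonic in both modes: with `len_epoch = 100`, batch 25 of epoch 3 is logged at step (3 - 1) * 100 + 25 = 225, exactly 100 steps after batch 25 of epoch 2.
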
@@ -118,3 +127,13 @@ def _valid_epoch(self, epoch):
             'val_loss': total_val_loss / len(self.valid_data_loader),
             'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
         }
+
+    def _progress(self, batch_idx):
+        base = '[{}/{} ({:.0f}%)]'
+        if hasattr(self.data_loader, 'n_samples'):
+            current = batch_idx * self.data_loader.batch_size
+            total = self.data_loader.n_samples
+        else:
+            current = batch_idx
+            total = self.len_epoch
+        return base.format(current, total, 100.0 * current / total)
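
As a quick sanity check of `_progress`: in iteration-based mode (the loader no longer exposes `n_samples`) with `len_epoch=100`, `_progress(25)` returns '[25/100 (25%)]'; for an epoch-based loader exposing `batch_size=128` and `n_samples=60000` (e.g. MNIST), `_progress(25)` returns '[3200/60000 (5%)]'.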