12
12
from collections import OrderedDict
13
13
import json
14
14
from numbers import Number
15
+ import os
16
+ import re
15
17
import sys
16
18
17
19
from tqdm import tqdm
18
20
19
- from fairseq .meters import AverageMeter
21
+ from fairseq import distributed_utils
22
+ from fairseq .meters import AverageMeter , StopwatchMeter , TimeMeter
20
23
21
24
22
25
def build_progress_bar (args , iterator , epoch = None , prefix = None , default = 'tqdm' , no_progress_bar = 'none' ):
@@ -36,9 +39,25 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm',
36
39
bar = tqdm_progress_bar (iterator , epoch , prefix )
37
40
else :
38
41
raise ValueError ('Unknown log format: {}' .format (args .log_format ))
42
+
43
+ if args .tensorboard_logdir and distributed_utils .is_master (args ):
44
+ bar = tensorboard_log_wrapper (bar , args .tensorboard_logdir )
45
+
39
46
return bar
40
47
41
48
49
def format_stat(stat):
    """Convert a single stat value into its display form.

    Args:
        stat: the value to format. Plain numbers and the meter types
            imported from ``fairseq.meters`` are recognised.

    Returns:
        A string for recognised inputs:

        * ``Number``: formatted with ``'{:g}'``.
        * ``AverageMeter``: ``stat.avg`` with three decimal places.
        * ``TimeMeter``: ``round(stat.avg)`` formatted with ``'{:g}'``.
        * ``StopwatchMeter``: ``round(stat.sum)`` formatted with ``'{:g}'``.

        Any other value (including an existing string) is returned unchanged.
    """
    # The Number check comes first so plain numeric stats never need the
    # meter classes to be consulted.
    if isinstance(stat, Number):
        stat = '{:g}'.format(stat)
    elif isinstance(stat, AverageMeter):
        stat = '{:.3f}'.format(stat.avg)
    elif isinstance(stat, TimeMeter):
        stat = '{:g}'.format(round(stat.avg))
    elif isinstance(stat, StopwatchMeter):
        stat = '{:g}'.format(round(stat.sum))
    return stat
59
+
60
+
42
61
class progress_bar (object ):
43
62
"""Abstract class for progress bars."""
44
63
def __init__ (self , iterable , epoch = None , prefix = None ):
@@ -59,11 +78,11 @@ def __exit__(self, *exc):
59
78
def __iter__ (self ):
60
79
raise NotImplementedError
61
80
62
- def log (self , stats ):
81
+ def log (self , stats , tag = '' , step = None ):
63
82
"""Log intermediate stats according to log_interval."""
64
83
raise NotImplementedError
65
84
66
- def print (self , stats ):
85
+ def print (self , stats , tag = '' , step = None ):
67
86
"""Print end-of-epoch stats."""
68
87
raise NotImplementedError
69
88
@@ -79,17 +98,7 @@ def _format_stats(self, stats):
79
98
postfix = OrderedDict (stats )
80
99
# Preprocess stats according to datatype
81
100
for key in postfix .keys ():
82
- # Number: limit the length of the string
83
- if isinstance (postfix [key ], Number ):
84
- postfix [key ] = '{:g}' .format (postfix [key ])
85
- # Meter: display both current and average value
86
- elif isinstance (postfix [key ], AverageMeter ):
87
- postfix [key ] = '{:.2f} ({:.2f})' .format (
88
- postfix [key ].val , postfix [key ].avg )
89
- # Else for any other type, try to get the string conversion
90
- elif not isinstance (postfix [key ], str ):
91
- postfix [key ] = str (postfix [key ])
92
- # Else if it's a string, don't need to preprocess anything
101
+ postfix [key ] = str (format_stat (postfix [key ]))
93
102
return postfix
94
103
95
104
@@ -111,13 +120,15 @@ def __iter__(self):
111
120
stats = self ._format_stats (self .stats , epoch = self .epoch , update = update )
112
121
print (json .dumps (stats ), flush = True )
113
122
114
- def log (self , stats ):
123
+ def log (self , stats , tag = '' , step = None ):
115
124
"""Log intermediate stats according to log_interval."""
116
125
self .stats = stats
117
126
118
- def print (self , stats ):
127
+ def print (self , stats , tag = '' , step = None ):
119
128
"""Print end-of-epoch stats."""
120
129
self .stats = stats
130
+ if tag != '' :
131
+ self .stats = OrderedDict ([(tag + '_' + k , v ) for k , v in self .stats .items ()])
121
132
stats = self ._format_stats (self .stats , epoch = self .epoch )
122
133
print (json .dumps (stats ), flush = True )
123
134
@@ -126,15 +137,10 @@ def _format_stats(self, stats, epoch=None, update=None):
126
137
if epoch is not None :
127
138
postfix ['epoch' ] = epoch
128
139
if update is not None :
129
- postfix ['update' ] = update
140
+ postfix ['update' ] = round ( update , 3 )
130
141
# Preprocess stats according to datatype
131
142
for key in stats .keys ():
132
- # Meter: display both current and average value
133
- if isinstance (stats [key ], AverageMeter ):
134
- postfix [key ] = stats [key ].val
135
- postfix [key + '_avg' ] = stats [key ].avg
136
- else :
137
- postfix [key ] = stats [key ]
143
+ postfix [key ] = format_stat (stats [key ])
138
144
return postfix
139
145
140
146
@@ -148,11 +154,11 @@ def __iter__(self):
148
154
for obj in self .iterable :
149
155
yield obj
150
156
151
- def log (self , stats ):
157
+ def log (self , stats , tag = '' , step = None ):
152
158
"""Log intermediate stats according to log_interval."""
153
159
pass
154
160
155
- def print (self , stats ):
161
+ def print (self , stats , tag = '' , step = None ):
156
162
"""Print end-of-epoch stats."""
157
163
pass
158
164
@@ -175,11 +181,11 @@ def __iter__(self):
175
181
print ('{}: {:5d} / {:d} {}' .format (self .prefix , i , size , postfix ),
176
182
flush = True )
177
183
178
- def log (self , stats ):
184
+ def log (self , stats , tag = '' , step = None ):
179
185
"""Log intermediate stats according to log_interval."""
180
186
self .stats = self ._format_stats (stats )
181
187
182
- def print (self , stats ):
188
+ def print (self , stats , tag = '' , step = None ):
183
189
"""Print end-of-epoch stats."""
184
190
postfix = self ._str_pipes (self ._format_stats (stats ))
185
191
print ('{} | {}' .format (self .prefix , postfix ), flush = True )
@@ -195,11 +201,62 @@ def __init__(self, iterable, epoch=None, prefix=None):
195
201
def __iter__ (self ):
196
202
return iter (self .tqdm )
197
203
198
- def log (self , stats ):
204
+ def log (self , stats , tag = '' , step = None ):
199
205
"""Log intermediate stats according to log_interval."""
200
206
self .tqdm .set_postfix (self ._format_stats (stats ), refresh = False )
201
207
202
- def print (self , stats ):
208
+ def print (self , stats , tag = '' , step = None ):
203
209
"""Print end-of-epoch stats."""
204
210
postfix = self ._str_pipes (self ._format_stats (stats ))
205
211
self .tqdm .write ('{} | {}' .format (self .tqdm .desc , postfix ))
212
+
213
+
214
class tensorboard_log_wrapper(progress_bar):
    """Wrap another progress bar and mirror its scalar stats to tensorboard.

    Iteration, ``log`` and ``print`` are all forwarded to the wrapped bar.
    In addition, numeric stats are written through tensorboardX
    ``SummaryWriter`` objects, one per ``tag``, each writing to the
    sub-directory ``<tensorboard_logdir>/<tag>``.

    If tensorboardX cannot be imported, a notice is printed once at
    construction time and the wrapper only forwards to the wrapped bar.
    """

    def __init__(self, wrapped_bar, tensorboard_logdir):
        """Store the wrapped bar and try to import tensorboardX.

        Args:
            wrapped_bar: the progress bar that calls are forwarded to.
            tensorboard_logdir: base directory for the per-tag writers.
        """
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir
        # Created unconditionally so the attribute exists even when the
        # tensorboardX import below fails.
        self._writers = {}

        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            print("tensorboard or required dependencies not found, "
                  "please see README for using tensorboard.")
            SummaryWriter = None
        # None marks tensorboard logging as disabled.
        self.SummaryWriter = SummaryWriter

    def _writer(self, key):
        """Return the cached writer for *key*, creating it on first use.

        Returns ``None`` when tensorboardX is unavailable.
        """
        if self.SummaryWriter is None:
            return None
        if key not in self._writers:
            self._writers[key] = self.SummaryWriter(
                log_dir=os.path.join(self.tensorboard_logdir, key),
            )
        return self._writers[key]

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag='', step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def _log_to_tensorboard(self, stats, tag='', step=None):
        """Write the numeric entries of *stats* as tensorboard scalars.

        Args:
            stats: mapping from stat name to value. ``AverageMeter`` values
                are logged via their ``val`` attribute and ``Number`` values
                are logged as-is; every other value type is skipped.
            tag: selects which writer (and log sub-directory) to use.
            step: global step for the scalars. When ``None`` it falls back
                to ``stats['num_updates']`` if that key is present.
        """
        writer = self._writer(tag)
        if writer is None:
            return
        if step is None:
            # Use .get() so a stats dict without 'num_updates' does not abort
            # the caller; add_scalar's global_step argument is optional.
            step = stats.get('num_updates')
        # Walk the mapping directly to keep the caller's key order.
        for key, value in stats.items():
            if key == 'num_updates':
                # Used as the step above, not logged as a scalar itself.
                continue
            if isinstance(value, AverageMeter):
                writer.add_scalar(key, value.val, step)
            elif isinstance(value, Number):
                writer.add_scalar(key, value, step)
0 commit comments