Skip to content

Commit 44d27e6

Browse files
Myle Ott authored and facebook-github-bot committed
Add Tensorboard support (#530)
Summary: Enable with the `--tensorboard-logdir` option. Pull Request resolved: #530 Differential Revision: D14218430 Pulled By: myleott fbshipit-source-id: e7a54f66f928e3bb02ae03fda09b22fa4fa7d053
1 parent 65c1903 commit 44d27e6

File tree

4 files changed

+115
-55
lines changed

4 files changed

+115
-55
lines changed

fairseq/options.py

+3
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ def get_parser(desc, default_task='translation'):
138138
help='log progress every N batches (when progress bar is disabled)')
139139
parser.add_argument('--log-format', default=None, help='log format to use',
140140
choices=['json', 'none', 'simple', 'tqdm'])
141+
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
142+
help='path to save logs for tensorboard, should match --logdir '
143+
'of running tensorboard (default: no tensorboard logging)')
141144
parser.add_argument('--seed', default=1, type=int, metavar='N',
142145
help='pseudo random number generator seed')
143146
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')

fairseq/progress_bar.py

+86-29
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
from collections import OrderedDict
1313
import json
1414
from numbers import Number
15+
import os
16+
import re
1517
import sys
1618

1719
from tqdm import tqdm
1820

19-
from fairseq.meters import AverageMeter
21+
from fairseq import distributed_utils
22+
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
2023

2124

2225
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
@@ -36,9 +39,25 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm',
3639
bar = tqdm_progress_bar(iterator, epoch, prefix)
3740
else:
3841
raise ValueError('Unknown log format: {}'.format(args.log_format))
42+
43+
if args.tensorboard_logdir and distributed_utils.is_master(args):
44+
bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir)
45+
3946
return bar
4047

4148

49+
def format_stat(stat):
50+
if isinstance(stat, Number):
51+
stat = '{:g}'.format(stat)
52+
elif isinstance(stat, AverageMeter):
53+
stat = '{:.3f}'.format(stat.avg)
54+
elif isinstance(stat, TimeMeter):
55+
stat = '{:g}'.format(round(stat.avg))
56+
elif isinstance(stat, StopwatchMeter):
57+
stat = '{:g}'.format(round(stat.sum))
58+
return stat
59+
60+
4261
class progress_bar(object):
4362
"""Abstract class for progress bars."""
4463
def __init__(self, iterable, epoch=None, prefix=None):
@@ -59,11 +78,11 @@ def __exit__(self, *exc):
5978
def __iter__(self):
6079
raise NotImplementedError
6180

62-
def log(self, stats):
81+
def log(self, stats, tag='', step=None):
6382
"""Log intermediate stats according to log_interval."""
6483
raise NotImplementedError
6584

66-
def print(self, stats):
85+
def print(self, stats, tag='', step=None):
6786
"""Print end-of-epoch stats."""
6887
raise NotImplementedError
6988

@@ -79,17 +98,7 @@ def _format_stats(self, stats):
7998
postfix = OrderedDict(stats)
8099
# Preprocess stats according to datatype
81100
for key in postfix.keys():
82-
# Number: limit the length of the string
83-
if isinstance(postfix[key], Number):
84-
postfix[key] = '{:g}'.format(postfix[key])
85-
# Meter: display both current and average value
86-
elif isinstance(postfix[key], AverageMeter):
87-
postfix[key] = '{:.2f} ({:.2f})'.format(
88-
postfix[key].val, postfix[key].avg)
89-
# Else for any other type, try to get the string conversion
90-
elif not isinstance(postfix[key], str):
91-
postfix[key] = str(postfix[key])
92-
# Else if it's a string, don't need to preprocess anything
101+
postfix[key] = str(format_stat(postfix[key]))
93102
return postfix
94103

95104

@@ -111,13 +120,15 @@ def __iter__(self):
111120
stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
112121
print(json.dumps(stats), flush=True)
113122

114-
def log(self, stats):
123+
def log(self, stats, tag='', step=None):
115124
"""Log intermediate stats according to log_interval."""
116125
self.stats = stats
117126

118-
def print(self, stats):
127+
def print(self, stats, tag='', step=None):
    """Print end-of-epoch stats as a single JSON line.

    If *tag* is given (e.g. 'train' or 'valid'), every stat key is
    prefixed with '<tag>_' so stats from different phases don't collide
    in the JSON output. *step* is accepted for interface compatibility
    with the tensorboard wrapper but is not used here.
    """
    self.stats = stats
    if tag != '':
        self.stats = OrderedDict([(tag + '_' + k, v) for k, v in self.stats.items()])
    stats = self._format_stats(self.stats, epoch=self.epoch)
    print(json.dumps(stats), flush=True)
123134

@@ -126,15 +137,10 @@ def _format_stats(self, stats, epoch=None, update=None):
126137
if epoch is not None:
127138
postfix['epoch'] = epoch
128139
if update is not None:
129-
postfix['update'] = update
140+
postfix['update'] = round(update, 3)
130141
# Preprocess stats according to datatype
131142
for key in stats.keys():
132-
# Meter: display both current and average value
133-
if isinstance(stats[key], AverageMeter):
134-
postfix[key] = stats[key].val
135-
postfix[key + '_avg'] = stats[key].avg
136-
else:
137-
postfix[key] = stats[key]
143+
postfix[key] = format_stat(stats[key])
138144
return postfix
139145

140146

@@ -148,11 +154,11 @@ def __iter__(self):
148154
for obj in self.iterable:
149155
yield obj
150156

151-
def log(self, stats):
157+
def log(self, stats, tag='', step=None):
152158
"""Log intermediate stats according to log_interval."""
153159
pass
154160

155-
def print(self, stats):
161+
def print(self, stats, tag='', step=None):
156162
"""Print end-of-epoch stats."""
157163
pass
158164

@@ -175,11 +181,11 @@ def __iter__(self):
175181
print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix),
176182
flush=True)
177183

178-
def log(self, stats):
184+
def log(self, stats, tag='', step=None):
179185
"""Log intermediate stats according to log_interval."""
180186
self.stats = self._format_stats(stats)
181187

182-
def print(self, stats):
188+
def print(self, stats, tag='', step=None):
183189
"""Print end-of-epoch stats."""
184190
postfix = self._str_pipes(self._format_stats(stats))
185191
print('{} | {}'.format(self.prefix, postfix), flush=True)
@@ -195,11 +201,62 @@ def __init__(self, iterable, epoch=None, prefix=None):
195201
def __iter__(self):
196202
return iter(self.tqdm)
197203

198-
def log(self, stats):
204+
def log(self, stats, tag='', step=None):
199205
"""Log intermediate stats according to log_interval."""
200206
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
201207

202-
def print(self, stats):
208+
def print(self, stats, tag='', step=None):
203209
"""Print end-of-epoch stats."""
204210
postfix = self._str_pipes(self._format_stats(stats))
205211
self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix))
212+
213+
214+
class tensorboard_log_wrapper(progress_bar):
    """Wrap another progress bar and mirror its stats to tensorboard.

    One SummaryWriter is created lazily per tag (e.g. 'train', 'valid'),
    each writing under ``<tensorboard_logdir>/<tag>``. If tensorboardX is
    not installed, logging silently degrades to the wrapped bar only.
    """

    def __init__(self, wrapped_bar, tensorboard_logdir):
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir
        # Initialize outside the try so instance state is consistent even
        # when the tensorboardX import below fails.
        self._writers = {}
        try:
            from tensorboardX import SummaryWriter
            self.SummaryWriter = SummaryWriter
        except ImportError:
            print("tensorboard or required dependencies not found, "
                  "please see README for using tensorboard.")
            self.SummaryWriter = None

    def _writer(self, key):
        """Return the (lazily created) SummaryWriter for *key*, or None."""
        if self.SummaryWriter is None:
            return None
        if key not in self._writers:
            self._writers[key] = self.SummaryWriter(
                log_dir=os.path.join(self.tensorboard_logdir, key),
            )
        return self._writers[key]

    def __iter__(self):
        return iter(self.wrapped_bar)

    def __exit__(self, *exc):
        # Close writers on exit so buffered events are flushed to disk;
        # otherwise trailing scalars can be lost when training ends.
        for writer in getattr(self, '_writers', {}).values():
            writer.close()
        return False

    def log(self, stats, tag='', step=None):
        """Log intermediate stats to tensorboard and the wrapped bar."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats via the wrapped bar, logging to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def _log_to_tensorboard(self, stats, tag='', step=None):
        writer = self._writer(tag)
        if writer is None:
            return
        if step is None:
            # assumes callers always include 'num_updates' when step is
            # omitted — TODO confirm against train.py call sites
            step = stats['num_updates']
        # 'num_updates' is the x-axis, not a scalar worth plotting itself.
        for key in stats.keys() - {'num_updates'}:
            if isinstance(stats[key], AverageMeter):
                writer.add_scalar(key, stats[key].val, step)
            elif isinstance(stats[key], Number):
                writer.add_scalar(key, stats[key], step)

tests/test_reproducibility.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ def _test_reproducibility(self, name, extra_flags=None):
6161
def cast(s):
6262
return round(float(s), 3)
6363

64-
for k in ['loss', 'ppl', 'num_updates', 'gnorm']:
64+
for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']:
6565
self.assertEqual(cast(train_log[k]), cast(train_res_log[k]))
66-
for k in ['valid_loss', 'valid_ppl', 'num_updates', 'best']:
66+
for k in ['valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss']:
6767
self.assertEqual(cast(valid_log[k]), cast(valid_res_log[k]))
6868

6969
def test_reproducibility(self):

train.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def train(args, trainer, task, epoch_itr):
150150
else:
151151
extra_meters[k].update(v)
152152
stats[k] = extra_meters[k].avg
153-
progress.log(stats)
153+
progress.log(stats, tag='train', step=stats['num_updates'])
154154

155155
# ignore the first mini-batch in words-per-second calculation
156156
if i == 0:
@@ -168,7 +168,7 @@ def train(args, trainer, task, epoch_itr):
168168
stats = get_training_stats(trainer)
169169
for k, meter in extra_meters.items():
170170
stats[k] = meter.avg
171-
progress.print(stats)
171+
progress.print(stats, tag='train', step=stats['num_updates'])
172172

173173
# reset training meters
174174
for k in [
@@ -181,26 +181,26 @@ def train(args, trainer, task, epoch_itr):
181181

182182
def get_training_stats(trainer):
    """Collect training stats for progress logging.

    Meters are stored unformatted (as meter objects) so downstream
    consumers such as the tensorboard wrapper can read raw values;
    display formatting is deferred to progress_bar.format_stat.
    """
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('train_loss')
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        # no separate nll meter updated; fall back to the training loss
        nll_loss = trainer.get_meter('train_loss')
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['wps'] = trainer.get_meter('wps')
    stats['ups'] = trainer.get_meter('ups')
    stats['wpb'] = trainer.get_meter('wpb')
    stats['bsz'] = trainer.get_meter('bsz')
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = trainer.get_meter('gnorm')
    stats['clip'] = trainer.get_meter('clip')
    stats['oom'] = trainer.get_meter('oom')
    if trainer.get_meter('loss_scale') is not None:
        # only present when training with dynamic loss scaling (fp16)
        stats['loss_scale'] = trainer.get_meter('loss_scale')
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    stats['train_wall'] = trainer.get_meter('train_wall')
    return stats
205205

206206

@@ -249,24 +249,24 @@ def validate(args, trainer, task, epoch_itr, subsets):
249249
stats = get_valid_stats(trainer)
250250
for k, meter in extra_meters.items():
251251
stats[k] = meter.avg
252-
progress.print(stats)
252+
progress.print(stats, tag=subset, step=trainer.get_num_updates())
253253

254-
valid_losses.append(stats['valid_loss'])
254+
valid_losses.append(stats['loss'].avg)
255255
return valid_losses
256256

257257

258258
def get_valid_stats(trainer):
    """Collect validation stats for progress logging.

    Like get_training_stats, meters are stored unformatted so the
    tensorboard wrapper can read raw values; formatting happens in
    progress_bar.format_stat.
    """
    stats = collections.OrderedDict()
    stats['loss'] = trainer.get_meter('valid_loss')
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss')
        stats['nll_loss'] = nll_loss
    else:
        # no separate nll meter updated; fall back to the validation loss
        nll_loss = stats['loss']
    stats['ppl'] = get_perplexity(nll_loss.avg)
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(save_checkpoint, 'best'):
        # best validation loss seen so far across epochs
        stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
    return stats
271271

272272

0 commit comments

Comments
 (0)