
Commit 76ec0a2

louaaron authored and facebook-github-bot committed
Added pytorch 1.8 profiler as hook with tensorboard visualization.
Summary: Added a new hook which uses pytorch's new profiler (available in versions 1.8.1+) to better log and visualize training details. In particular, the new hook supports tensorboard visualizations, which the previous AutogradProfiler hook did not.

Reviewed By: vaibhava0

Differential Revision: D29624951

fbshipit-source-id: 26e2b9cecf85ae2c545dc15a8103d6e1d983a94a
1 parent 87e9946 commit 76ec0a2

File tree

2 files changed, 84 insertions(+), 17 deletions(-)

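Before the diff, a minimal sketch of how the new hook is meant to be registered. This is illustrative, not part of the commit: ``trainer`` stands for any detectron2 trainer exposing ``register_hooks()`` and ``cfg`` for an existing config; only ``hooks.TorchProfiler`` and its arguments come from this change:

    from detectron2.engine import hooks

    # `trainer` and `cfg` are assumed to exist already (hypothetical setup).
    trainer.register_hooks(
        [
            hooks.TorchProfiler(
                # Profile a mid-training window; the first iterations are
                # typically slower and would skew the trace.
                lambda trainer: 10 < trainer.iter < 20,
                cfg.OUTPUT_DIR,
                save_tensorboard=True,  # dump TensorBoard traces under OUTPUT_DIR/log/
            )
        ]
    )

Per the new docstring, the resulting traces open in ``chrome://tracing`` or via ``tensorboard --logdir OUTPUT_DIR/log``.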

detectron2/engine/hooks.py

Lines changed: 77 additions & 16 deletions
@@ -7,6 +7,7 @@
 import os
 import tempfile
 import time
+import warnings
 from collections import Counter
 import torch
 from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
@@ -31,6 +32,7 @@
     "AutogradProfiler",
     "EvalHook",
     "PreciseBN",
+    "TorchProfiler",
 ]


@@ -268,45 +270,59 @@ def load_state_dict(self, state_dict):
         self.scheduler.load_state_dict(state_dict)


-class AutogradProfiler(HookBase):
+class TorchProfiler(HookBase):
     """
-    A hook which runs `torch.autograd.profiler.profile`.
+    A hook which runs `torch.profiler.profile`.

     Examples:
     ::
-        hooks.AutogradProfiler(
-            lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
+        hooks.TorchProfiler(
+            lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
         )

     The above example will run the profiler for iteration 10~20 and dump
     results to ``OUTPUT_DIR``. We did not profile the first few iterations
     because they are typically slower than the rest.
-    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
-
-    Note:
-        When used together with NCCL on older version of GPUs,
-        autograd profiler may cause deadlock because it unnecessarily allocates
-        memory on every device it sees. The memory management calls, if
-        interleaved with NCCL calls, lead to deadlock on GPUs that do not
-        support ``cudaLaunchCooperativeKernelMultiDevice``.
+    The result files can be loaded in the ``chrome://tracing`` page in chrome browser,
+    and the tensorboard visualizations can be visualized using
+    ``tensorboard --logdir OUTPUT_DIR/log``
     """

-    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
+    def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):
         """
         Args:
             enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                 and returns whether to enable the profiler.
                 It will be called once every step, and can be used to select which steps to profile.
             output_dir (str): the output directory to dump tracing files.
-            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
+            activities (iterable): same as in `torch.profiler.profile`.
+            save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/
         """
         self._enable_predicate = enable_predicate
-        self._use_cuda = use_cuda
+        self._activities = activities
         self._output_dir = output_dir
+        self._save_tensorboard = save_tensorboard

     def before_step(self):
         if self._enable_predicate(self.trainer):
-            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
+            if self._save_tensorboard:
+                on_trace_ready = torch.profiler.tensorboard_trace_handler(
+                    os.path.join(
+                        self._output_dir,
+                        "log",
+                        "profiler-tensorboard-iter{}".format(self.trainer.iter),
+                    )
+                )
+            else:
+                on_trace_ready = None
+            self._profiler = torch.profiler.profile(
+                activities=self._activities,
+                on_trace_ready=on_trace_ready,
+                record_shapes=True,
+                profile_memory=True,
+                with_stack=True,
+                with_flops=True,
+            )
             self._profiler.__enter__()
         else:
             self._profiler = None
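For readers unfamiliar with the 1.8.1+ API, the new ``before_step`` body above boils down to the following standalone sketch, using only public ``torch.profiler`` calls; the toy model, output directory, and iteration number are illustrative stand-ins, not part of the commit:

    import os
    import torch

    output_dir = "./output"  # stands in for the hook's output_dir argument
    iteration = 10           # stands in for self.trainer.iter

    prof = torch.profiler.profile(
        activities=None,  # None lets torch.profiler pick its default activities
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            os.path.join(output_dir, "log", "profiler-tensorboard-iter{}".format(iteration))
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    )
    prof.__enter__()  # the hook does this in before_step
    # one stand-in "training iteration":
    model = torch.nn.Linear(8, 8)
    model(torch.randn(4, 8)).sum().backward()
    prof.__exit__(None, None, None)  # roughly what the hook tears down in after_step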
@@ -332,6 +348,51 @@ def after_step(self):
             f.write(content)


+class AutogradProfiler(TorchProfiler):
+    """
+    A hook which runs `torch.autograd.profiler.profile`.
+
+    Examples:
+    ::
+        hooks.AutogradProfiler(
+            lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
+        )
+
+    The above example will run the profiler for iteration 10~20 and dump
+    results to ``OUTPUT_DIR``. We did not profile the first few iterations
+    because they are typically slower than the rest.
+    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
+
+    Note:
+        When used together with NCCL on older version of GPUs,
+        autograd profiler may cause deadlock because it unnecessarily allocates
+        memory on every device it sees. The memory management calls, if
+        interleaved with NCCL calls, lead to deadlock on GPUs that do not
+        support ``cudaLaunchCooperativeKernelMultiDevice``.
+    """
+
+    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
+        """
+        Args:
+            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
+                and returns whether to enable the profiler.
+                It will be called once every step, and can be used to select which steps to profile.
+            output_dir (str): the output directory to dump tracing files.
+            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
+        """
+        warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
+        self._enable_predicate = enable_predicate
+        self._use_cuda = use_cuda
+        self._output_dir = output_dir
+
+    def before_step(self):
+        if self._enable_predicate(self.trainer):
+            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
+            self._profiler.__enter__()
+        else:
+            self._profiler = None
+
+
 class EvalHook(HookBase):
     """
     Run an evaluation function periodically, and at the end of training.
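The old class thus survives as a deprecated subclass, so existing call sites keep working but now warn. A small sketch of the observable behavior, assuming a detectron2 build containing this commit is importable (the output path is a placeholder):

    import warnings
    from detectron2.engine import hooks

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Constructing the old hook still works...
        profiler_hook = hooks.AutogradProfiler(lambda trainer: True, "./output")
    # ...but now emits the deprecation message added above.
    assert any("deprecated" in str(w.message) for w in caught)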

tools/benchmark.py

Lines changed: 7 additions & 1 deletion
@@ -113,7 +113,13 @@ def f():
     max_iter = 400
     trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(model, f(), optimizer)
     trainer.register_hooks(
-        [hooks.IterationTimer(), hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])]
+        [
+            hooks.IterationTimer(),
+            hooks.PeriodicWriter([CommonMetricPrinter(max_iter)]),
+            hooks.TorchProfiler(
+                lambda trainer: trainer.iter == max_iter - 1, cfg.OUTPUT_DIR, save_tensorboard=True
+            ),
+        ]
     )
     trainer.train(1, max_iter)
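With ``max_iter = 400``, the predicate fires only on the final iteration, so a benchmark run now dumps a single trace directory under ``cfg.OUTPUT_DIR``. A hedged sketch of locating it after a run; the path is inferred from the hook code above, and the directory name is illustrative:

    import os

    output_dir = "./output"  # whatever cfg.OUTPUT_DIR was for the benchmark run
    trace_dir = os.path.join(output_dir, "log", "profiler-tensorboard-iter399")
    print(trace_dir, "exists:", os.path.isdir(trace_dir))

Viewing the trace in TensorBoard requires the PyTorch profiler plugin (``pip install torch-tb-profiler``; an assumption about the viewing setup, not part of this commit), then ``tensorboard --logdir ./output/log`` as in the docstring.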
