Skip to content

Commit

Permalink
Merge pull request #742 from emcastillo/chrome-tracing
Browse files Browse the repository at this point in the history
Support chrome tracing in ppe.profiler
  • Loading branch information
linshokaku authored Nov 17, 2023
2 parents 89a92c5 + 3fc39b4 commit 8500fe2
Show file tree
Hide file tree
Showing 15 changed files with 837 additions and 111 deletions.
7 changes: 7 additions & 0 deletions docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ Profiler

profiler.TimeSummary.report

profiler.clear_tracer
profiler.enable_global_trace
profiler.enable_thread_trace
profiler.get_tracer
profiler.ChromeTracer
profiler.TraceableDataset

Distributed Training
---------------------

Expand Down
58 changes: 37 additions & 21 deletions example/mnist_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,27 +129,49 @@ def main():
default=None,
help="output mode for profiler results",
)
parser.add_argument(
"--trace",
type=str,
default=None,
help="output trace for timeline in chrome format",
)
args = parser.parse_args()

torch.manual_seed(args.seed)
numpy.random.seed(args.seed)
torch.use_deterministic_algorithms(args.deterministic)

train_dataset = datasets.MNIST(
"../data",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
)
eval_dataset = datasets.MNIST(
"../data",
train=False,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
)
if args.trace is not None:
train_dataset = ppe.profiler.TraceableDataset(
train_dataset, "train_dataset_read"
)

use_cuda = args.device.startswith("cuda")

kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
train_dataset,
batch_size=args.batch_size,
shuffle=True,
collate_fn=ppe.dataloaders.utils.CollateAsDict(
Expand All @@ -158,16 +180,7 @@ def main():
**kwargs, # type: ignore[arg-type]
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"../data",
train=False,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
eval_dataset,
batch_size=args.test_batch_size,
shuffle=True,
collate_fn=ppe.dataloaders.utils.CollateAsDict(
Expand Down Expand Up @@ -204,6 +217,8 @@ def main():
),
extensions.snapshot(),
]
if args.trace is not None:
my_extensions.append(extensions.TimelineTrace(filename=args.trace))

# Custom stop triggers can be added to the manager and
# their status accessed through `manager.stop_trigger`
Expand Down Expand Up @@ -272,6 +287,7 @@ def callback(prof):
metrics=[ppe.training.metrics.AccuracyMetric("target", "output")],
options={"eval_report_keys": ["loss", "accuracy"]},
),
enable_trace=args.trace is not None,
options={"train_report_keys": ["loss"]},
profile=profile,
)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_pfn_extras/nn/parallel/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _synchronize() -> None:
]
groups = _group_by_type(grads)
with record(
"pytorch_pfn_extras.nn.parallel."
"ppe.nn.parallel."
"DistributedDataParallel:reduce_gradient",
use_cuda=torch.cuda.is_available(),
):
Expand Down
9 changes: 9 additions & 0 deletions pytorch_pfn_extras/profiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,12 @@
from pytorch_pfn_extras.profiler._record import record_iterable # NOQA
from pytorch_pfn_extras.profiler._time_summary import TimeSummary # NOQA
from pytorch_pfn_extras.profiler._time_summary import get_time_summary # NOQA
from pytorch_pfn_extras.profiler._tracing import ( # NOQA
ChromeTracer,
TraceableDataset,
Tracer,
clear_tracer,
enable_global_trace,
enable_thread_trace,
get_tracer,
)
70 changes: 59 additions & 11 deletions pytorch_pfn_extras/profiler/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
Iterable,
Optional,
TypeVar,
Union,
)

import torch
from pytorch_pfn_extras.profiler import _time_summary
from pytorch_pfn_extras.profiler import _time_summary, _tracing
from pytorch_pfn_extras.runtime import runtime_registry

if TYPE_CHECKING:
Expand Down Expand Up @@ -43,15 +44,43 @@ def complete(self) -> None:
pass


@contextmanager
def dummy_tracer(name: str) -> Generator[None, None, None]:
    """No-op tracing context manager.

    Args:
        name: Accepted but not used by this implementation.

    Yields:
        ``None``; the ``with`` body runs with no setup or teardown.
    """
    yield


@contextmanager
def tracer(
    tag: str,
    device: "DeviceLike" = "cpu",
    trace: Union[_tracing.Tracer, bool] = False,
) -> Generator[None, None, None]:
    """Run the ``with`` body inside the runtime tracer and a user tracer.

    Args:
        tag: Name passed to the runtime tracer and to the user tracer's
            ``add_event``.
        device: Device spec used to look up the runtime class whose
            ``trace`` context manager is entered.
        trace: Selects the user tracer. ``False`` uses a
            ``_tracing.DummyTracer``, ``True`` uses whatever
            ``_tracing.get_tracer()`` returns, and a ``_tracing.Tracer``
            instance is used as given.

    Raises:
        TypeError: If ``trace`` is neither a ``bool`` nor a
            ``_tracing.Tracer``.
    """
    # this uses the PyTorch autograd tracer or one for custom devices
    runtime_cls = runtime_registry.get_runtime_class_for_device_spec(device)
    runtime_tracer = runtime_cls.trace

    user_tracer: _tracing.Tracer
    if isinstance(trace, bool):
        if trace:
            user_tracer = _tracing.get_tracer()
        else:
            user_tracer = _tracing.DummyTracer()
    elif isinstance(trace, _tracing.Tracer):
        user_tracer = trace
    else:
        # Without this branch an unsupported value would leave
        # ``user_tracer`` unbound and fail below with UnboundLocalError.
        raise TypeError(
            "trace must be a bool or a Tracer instance, "
            f"got {type(trace).__name__}"
        )

    with runtime_tracer(tag, None), user_tracer.add_event(tag):
        yield


@contextmanager
def record(
tag: Optional[str],
metric: Optional[str] = None,
use_cuda: bool = False,
enable: bool = True,
device: "DeviceLike" = "cpu",
trace: Union[_tracing.Tracer, bool] = False,
) -> Generator[_time_summary._ReportNotification, None, None]:
if not enable:
if not enable and not trace:
yield _DummyReportNotification()
return

Expand All @@ -61,16 +90,16 @@ def record(
if metric is None:
metric = tag

runtime_cls = runtime_registry.get_runtime_class_for_device_spec(device)
runtime_tracer = runtime_cls.trace

if use_cuda:
torch.cuda.nvtx.range_push(tag) # type: ignore[no-untyped-call]
try:
with runtime_tracer(tag, None):
time_summary = _time_summary.get_time_summary()
with time_summary.report(metric, use_cuda) as ntf:
yield ntf
with tracer(tag, device, trace):
if not enable:
time_summary = _time_summary.get_time_summary()
with time_summary.report(metric, use_cuda) as ntf:
yield ntf
else:
yield _DummyReportNotification()
finally:
if use_cuda:
torch.cuda.nvtx.range_pop() # type: ignore[no-untyped-call]
Expand All @@ -83,10 +112,20 @@ def record_function(
tag: Optional[str],
use_cuda: bool = False,
enable: bool = True,
device: "DeviceLike" = "cpu",
trace: Union[_tracing.Tracer, bool] = False,
) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
def wrapper(f: Callable[..., _T]) -> Callable[..., _T]:
def wrapped(*args: Any, **kwargs: Any) -> _T:
with record(tag or f.__name__, use_cuda=use_cuda, enable=enable):
name = tag or f.__name__
with record(
name,
None,
use_cuda,
enable,
device,
trace,
):
return f(*args, **kwargs)

return wrapped
Expand All @@ -100,6 +139,8 @@ def record_iterable(
divide_metric: bool = False,
use_cuda: bool = False,
enable: bool = True,
device: "DeviceLike" = "cpu",
trace: Union[_tracing.Tracer, bool] = False,
) -> Iterable[_T]:
if tag is None:
tag = _infer_tag_name(inspect.currentframe(), depth=1)
Expand All @@ -108,7 +149,14 @@ def wrapped() -> Iterable[_T]:
for i, x in enumerate(iter):
name = f"{tag}-{i}"
metric = name if divide_metric else tag
with record(name, metric, use_cuda=use_cuda, enable=enable):
with record(
name,
metric,
use_cuda,
enable,
device,
trace,
):
yield x

return wrapped()
72 changes: 4 additions & 68 deletions pytorch_pfn_extras/profiler/_time_summary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import atexit
import multiprocessing as mp
import os
import queue
import threading
Expand All @@ -9,6 +8,7 @@
from typing import Callable, Dict, Generator, Optional, Tuple

import torch
from pytorch_pfn_extras.profiler import _util
from pytorch_pfn_extras.reporting import DictSummary

Events = Tuple[torch.cuda.Event, torch.cuda.Event]
Expand Down Expand Up @@ -39,72 +39,6 @@ def complete(self) -> None:
)


class _CPUWorker:
    """Forward ``(name, value)`` pairs to a callback from a worker thread.

    Items handed to :meth:`put` are placed on a bounded joinable queue and
    consumed by a daemon thread that calls ``add(name, value)`` for each.
    The queue and thread are only created by :meth:`initialize`.
    """

    def __init__(
        self,
        add: Callable[[str, float], None],
        max_queue_size: int,
    ) -> None:
        """Store the callback and queue size; no queue or thread is created.

        Args:
            add: Callback invoked by the worker thread for every queued item.
            max_queue_size: Size passed to ``mp.JoinableQueue`` in
                :meth:`initialize`.
        """
        self._add = add
        self._max_queue_size = max_queue_size
        self._initialized = False
        self._queue: Optional[
            mp.JoinableQueue[Optional[Tuple[str, float]]]
        ] = None
        self._thread: Optional[threading.Thread] = None
        # Set by the worker thread when it stops because of an EOFError.
        self._thread_exited = False

    def initialize(self) -> None:
        """Create the queue and start the worker thread.

        Does nothing if already initialized.
        """
        if self._initialized:
            return
        self._queue = mp.JoinableQueue(self._max_queue_size)
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()
        self._initialized = True
        self._thread_exited = False

    def finalize(self) -> None:
        """Stop the worker and release the queue.

        Does nothing if not initialized.
        """
        if not self._initialized:
            return
        assert self._queue is not None
        assert self._thread is not None
        # In some situations, (when this runs in a subprocess), the queue might have
        # been cut in the worker thread before this function is called
        # due to the non-deterministic shutdown process.
        # NOTE(review): nesting of the four queue calls below was
        # reconstructed; confirm they all belong under this guard.
        if not self._thread_exited:
            # ``None`` is the sentinel that makes ``_worker`` leave its loop.
            self._queue.put(None)
            self._queue.join()
            self._queue.close()
            self._queue.join_thread()
        self._initialized = False

    def synchronize(self) -> None:
        """Block until every queued item has been marked done."""
        assert self._queue is not None
        self._queue.join()

    def put(self, name: str, value: float) -> None:
        """Queue ``(name, value)`` for the worker thread.

        Asserts that the queue exists and the worker has not exited.
        """
        assert self._queue is not None
        assert not self._thread_exited
        self._queue.put((name, value))

    def _worker(self) -> None:
        """Worker loop: pass queued items to ``self._add`` until told to stop.

        The loop ends on the ``None`` sentinel or when ``get`` raises
        ``EOFError``.
        """
        assert self._queue is not None
        while True:
            try:
                v = self._queue.get()
            # If this runs in a subprocess, the cleanup may throw an EOF here
            # before the queue cleanup code is executed
            except EOFError:
                self._thread_exited = True
                break
            if v is None:
                # Acknowledge the sentinel so a pending ``join()`` can return.
                self._queue.task_done()
                break
            name, value = v
            self._add(name, value)
            self._queue.task_done()


_QueueElem = Tuple[str, Tuple[torch.cuda.Event, torch.cuda.Event]]


Expand Down Expand Up @@ -226,7 +160,9 @@ def __init__(
self._summary = DictSummary()
self._additional_stats: Dict[str, float] = {}

self._cpu_worker = _CPUWorker(self._add_from_worker, max_queue_size)
self._cpu_worker = _util.QueueWorker(
self._add_from_worker, max_queue_size
)
self._cuda_worker: Optional[_CUDAWorker] = None
if torch.cuda.is_available():
self._cuda_worker = _CUDAWorker(
Expand Down
Loading

0 comments on commit 8500fe2

Please sign in to comment.