diff --git a/pytorch_pfn_extras/profiler/__init__.py b/pytorch_pfn_extras/profiler/__init__.py
index 88abf507..92e96782 100644
--- a/pytorch_pfn_extras/profiler/__init__.py
+++ b/pytorch_pfn_extras/profiler/__init__.py
@@ -1,3 +1,6 @@
+from pytorch_pfn_extras.profiler._chrome_tracing import (
+    get_chrome_tracer,  # NOQA
+)
 from pytorch_pfn_extras.profiler._record import record  # NOQA
 from pytorch_pfn_extras.profiler._record import record_function  # NOQA
 from pytorch_pfn_extras.profiler._record import record_iterable  # NOQA
diff --git a/pytorch_pfn_extras/profiler/_chrome_tracing.py b/pytorch_pfn_extras/profiler/_chrome_tracing.py
index 17b482de..e912c249 100644
--- a/pytorch_pfn_extras/profiler/_chrome_tracing.py
+++ b/pytorch_pfn_extras/profiler/_chrome_tracing.py
@@ -1,6 +1,7 @@
 import json
-import time
+import os
 import threading
+import time
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
@@ -23,7 +24,9 @@ def __init__(
 
     def __enter__(self) -> None:
         self.begin_ns = time.perf_counter_ns()
 
-    def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]:
+    def __exit__(
+        self, exc_type: Any, exc_value: Any, tracebac: Any
+    ) -> Literal[False]:
         if self.is_cuda_available:
             torch.cuda.synchronize()  # Wait for process to complete
@@ -46,7 +49,9 @@ class ChromeTracingEventDisabled:
     def __enter__(self) -> None:
         pass
 
-    def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]:
+    def __exit__(
+        self, exc_type: Any, exc_value: Any, tracebac: Any
+    ) -> Literal[False]:
         return False
 
 
@@ -59,14 +64,14 @@ def __call__(self, target: Dict[str, Any], file_o: Any) -> None:
 class ChromeTracingEmitter:
     def __init__(
         self,
-        out_trace_json_path: str,
-        writer: Writer = None,
+        filename: str,
+        writer: Writer,
         max_event_count: Optional[int] = None,
         enable: bool = True,
     ):
         self.enable = enable
         self.event_list: List[Dict[str, Union[str, int, float]]] = []
-        self.out_trace_json_path = out_trace_json_path
+        self.filename = filename
         self.max_event_count = max_event_count or float("inf")
         self.event_count = 0
         self.is_cuda_available = torch.cuda.is_available()
@@ -96,8 +101,8 @@ def flush(self) -> None:
         # file pointer and with json.dumps?
         savefun = ChromeTracingSaveFunc()
         self.writer(
-            "chrome_tracing.json",
-            self.out_trace_json_path,
+            self.filename,
+            "",  # out_dir arg is ignored in the writer, uses the writer attr
             self.event_list,
             savefun=savefun,
         )
@@ -106,9 +111,12 @@ def flush(self) -> None:
 _thread_local = threading.local()
 
 
-def get_chrome_tracer(out_dir: str, writer: Writer) -> ChromeTracingEmitter:
+def get_chrome_tracer(filename: str, writer: Writer) -> ChromeTracingEmitter:
+    trace_path = os.path.join(writer.out_dir, filename)
     if not hasattr(_thread_local, "chrome_tracer"):
-        _thread_local.chrome_tracer = ChromeTracingEmitter(
-            out_dir, writer, None, True
-        )
-    return _thread_local.chrome_tracer  # type: ignore[no-any-return]
+        _thread_local.chrome_tracer = {}
+    tracer = _thread_local.chrome_tracer.get(trace_path, None)
+    if tracer is None:
+        tracer = ChromeTracingEmitter(filename, writer, None, True)
+        _thread_local.chrome_tracer[trace_path] = tracer
+    return tracer  # type: ignore[no-any-return]
diff --git a/pytorch_pfn_extras/profiler/_record.py b/pytorch_pfn_extras/profiler/_record.py
index f8048c63..d601073a 100644
--- a/pytorch_pfn_extras/profiler/_record.py
+++ b/pytorch_pfn_extras/profiler/_record.py
@@ -1,6 +1,5 @@
 import inspect
 import types
-import time
 from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
@@ -15,6 +14,7 @@
 import torch
 from pytorch_pfn_extras.profiler import _chrome_tracing, _time_summary
 from pytorch_pfn_extras.runtime import runtime_registry
+from pytorch_pfn_extras.writing import Writer
 
 if TYPE_CHECKING:
     from pytorch_pfn_extras.runtime._runtime import DeviceLike
@@ -51,17 +51,21 @@ def dummy_tracer(tag: Optional[str]) -> Generator[None, None, None]:
 
 @contextmanager
 def tracer(
-    tag: Optional[str],
+    tag: str,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Generator[None, None, None]:
     runtime_cls = runtime_registry.get_runtime_class_for_device_spec(device)
     runtime_tracer = runtime_cls.trace
     if chrome_tracing_out is not None:
-        chrome_tracer = _chrome_tracing.get_chrome_tracer(chrome_tracing_out)
+        assert chrome_writer is not None
+        chrome_tracer = _chrome_tracing.get_chrome_tracer(
+            chrome_tracing_out, chrome_writer
+        ).add_event
     else:
-        chrome_tracer = dummy_tracer
+        chrome_tracer = dummy_tracer  # type: ignore[assignment]
     with runtime_tracer(tag, None), chrome_tracer(tag):
         yield
 
 
@@ -75,6 +79,7 @@ def record(
     enable: bool = True,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Generator[_time_summary._ReportNotification, None, None]:
     if not enable:
         yield _DummyReportNotification()
@@ -89,13 +94,13 @@ def record(
     if use_cuda:
         torch.cuda.nvtx.range_push(tag)  # type: ignore[no-untyped-call]
     try:
-        with tracer(tag, device, chrome_tracing_out):
+        with tracer(tag, device, chrome_tracing_out, chrome_writer):
             time_summary = _time_summary.get_time_summary()
-            time_summary.report("start_time", time.time())
+            # time_summary.report("start_time", time.time())
             with time_summary.report(metric, use_cuda) as ntf:
                 yield ntf
     finally:
-        time_summary.report("end_time", time.time())
+        # time_summary.report("end_time", time.time())
         if use_cuda:
             torch.cuda.nvtx.range_pop()  # type: ignore[no-untyped-call]
 
@@ -109,11 +114,20 @@ def record_function(
     enable: bool = True,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
     def wrapper(f: Callable[..., _T]) -> Callable[..., _T]:
         def wrapped(*args: Any, **kwargs: Any) -> _T:
             name = tag or f.__name__
-            with record(name, None, use_cuda, device, chrome_tracing_out):
+            with record(
+                name,
+                None,
+                use_cuda,
+                enable,
+                device,
+                chrome_tracing_out,
+                chrome_writer,
+            ):
                 return f(*args, **kwargs)
 
         return wrapped
@@ -129,6 +143,7 @@ def record_iterable(
     enable: bool = True,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Iterable[_T]:
     if tag is None:
         tag = _infer_tag_name(inspect.currentframe(), depth=1)
@@ -137,7 +152,15 @@ def wrapped() -> Iterable[_T]:
         for i, x in enumerate(iter):
             name = f"{tag}-{i}"
             metric = name if divide_metric else tag
-            with record(name, metric, use_cuda, device, chrome_tracing_out):
+            with record(
+                name,
+                metric,
+                use_cuda,
+                enable,
+                device,
+                chrome_tracing_out,
+                chrome_writer,
+            ):
                 yield x
 
     return wrapped()
diff --git a/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py b/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py
index 5126cc1f..5f6e5cb0 100644
--- a/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py
+++ b/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py
@@ -1,4 +1,5 @@
 import os
+import tempfile
 
 import pytest
 import pytorch_pfn_extras as ppe
@@ -138,14 +139,14 @@ def test_record_iterable_with_chrome_trace(device):
 
     model.to(device)
     x = torch.arange(30, dtype=torch.float32).to(device)
 
-    iters = [x, x, x]
-    with torch.profiler.profile() as prof:
-        for x in ppe.profiler.record(None, iters):
-            model(x)
-
-    keys = [event.key for event in prof.key_averages()]
-    assert "aten::linear" in keys
-    assert any(k.endswith("test_record_iterable_without_tag-0") for k in keys)
-    assert any(k.endswith("test_record_iterable_without_tag-1") for k in keys)
-    assert any(k.endswith("test_record_iterable_without_tag-2") for k in keys)
+    x = torch.arange(30, dtype=torch.float32).to(device)
+    with tempfile.TemporaryDirectory() as t_path:
+        w = ppe.writing.SimpleWriter(out_dir=t_path)
+        with torch.profiler.profile():
+            with ppe.profiler.record(
+                "tag", chrome_tracing_out="trace.json", chrome_writer=w
+            ):
+                model(x)
+        ppe.profiler.get_chrome_tracer("trace.json", w).flush()
+        assert os.path.exists(os.path.join(t_path, "trace.json"))
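Usage sketch (not part of the patch): the snippet below mirrors the updated test and shows how the new `chrome_tracing_out`/`chrome_writer` arguments and the exported `get_chrome_tracer` fit together. The `"forward"` tag and the `result` output directory are illustrative; everything else comes from this diff.

```python
import torch
import pytorch_pfn_extras as ppe

# The writer owns the output directory; get_chrome_tracer joins
# writer.out_dir with the filename to key its thread-local tracer cache,
# so distinct trace files get distinct emitters.
writer = ppe.writing.SimpleWriter(out_dir="result")

model = torch.nn.Linear(30, 40)
x = torch.arange(30, dtype=torch.float32)

# record() forwards chrome_tracing_out/chrome_writer down to tracer(),
# which emits one chrome-tracing event for the tagged region.
with ppe.profiler.record(
    "forward", chrome_tracing_out="trace.json", chrome_writer=writer
):
    model(x)

# Events are buffered in memory until flush() hands them to the writer,
# which serializes them as JSON to result/trace.json.
ppe.profiler.get_chrome_tracer("trace.json", writer).flush()
```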