Commit: lint

Emilio Castillo committed Aug 23, 2023
1 parent a9ec5c7 commit ab0d1d3
Showing 4 changed files with 67 additions and 32 deletions.
3 changes: 3 additions & 0 deletions pytorch_pfn_extras/profiler/__init__.py
@@ -1,3 +1,6 @@
+from pytorch_pfn_extras.profiler._chrome_tracing import (
+    get_chrome_tracer,  # NOQA
+)
 from pytorch_pfn_extras.profiler._record import record  # NOQA
 from pytorch_pfn_extras.profiler._record import record_function  # NOQA
 from pytorch_pfn_extras.profiler._record import record_iterable  # NOQA
34 changes: 21 additions & 13 deletions pytorch_pfn_extras/profiler/_chrome_tracing.py
@@ -1,6 +1,7 @@
 import json
-import time
+import os
 import threading
+import time
 from typing import Any, Dict, List, Literal, Optional, Union

 import torch
@@ -23,7 +24,9 @@ def __init__(
     def __enter__(self) -> None:
         self.begin_ns = time.perf_counter_ns()

-    def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]:
+    def __exit__(
+        self, exc_type: Any, exc_value: Any, tracebac: Any
+    ) -> Literal[False]:
         if self.is_cuda_available:
             torch.cuda.synchronize()  # Wait for process to complete
@@ -46,7 +49,9 @@ class ChromeTracingEventDisabled:
     def __enter__(self) -> None:
         pass

-    def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]:
+    def __exit__(
+        self, exc_type: Any, exc_value: Any, tracebac: Any
+    ) -> Literal[False]:
         return False


@@ -59,14 +64,14 @@ def __call__(self, target: Dict[str, Any], file_o: Any) -> None:
 class ChromeTracingEmitter:
     def __init__(
         self,
-        out_trace_json_path: str,
-        writer: Writer = None,
+        filename: str,
+        writer: Writer,
         max_event_count: Optional[int] = None,
         enable: bool = True,
     ):
         self.enable = enable
         self.event_list: List[Dict[str, Union[str, int, float]]] = []
-        self.out_trace_json_path = out_trace_json_path
+        self.filename = filename
         self.max_event_count = max_event_count or float("inf")
         self.event_count = 0
         self.is_cuda_available = torch.cuda.is_available()
@@ -96,8 +101,8 @@ def flush(self) -> None:
         # file pointer and with json.dumps?
         savefun = ChromeTracingSaveFunc()
         self.writer(
-            "chrome_tracing.json",
-            self.out_trace_json_path,
+            self.filename,
+            "",  # out_dir arg is ignored in the writer, uses the writer attr
             self.event_list,
             savefun=savefun,
         )
@@ -106,9 +111,12 @@ def flush(self) -> None:
 _thread_local = threading.local()


-def get_chrome_tracer(out_dir: str, writer: Writer) -> ChromeTracingEmitter:
+def get_chrome_tracer(filename: str, writer: Writer) -> ChromeTracingEmitter:
+    trace_path = os.path.join(writer.out_dir, filename)
     if not hasattr(_thread_local, "chrome_tracer"):
-        _thread_local.chrome_tracer = ChromeTracingEmitter(
-            out_dir, writer, None, True
-        )
-    return _thread_local.chrome_tracer  # type: ignore[no-any-return]
+        _thread_local.chrome_tracer = {}
+    tracer = _thread_local.chrome_tracer.get(trace_path, None)
+    if tracer is None:
+        tracer = ChromeTracingEmitter(filename, writer, None, True)
+        _thread_local.chrome_tracer[trace_path] = tracer
+    return tracer  # type: ignore[no-any-return]
41 changes: 32 additions & 9 deletions pytorch_pfn_extras/profiler/_record.py
@@ -1,6 +1,5 @@
 import inspect
 import types
-import time
 from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
@@ -15,6 +14,7 @@
 import torch
 from pytorch_pfn_extras.profiler import _chrome_tracing, _time_summary
 from pytorch_pfn_extras.runtime import runtime_registry
+from pytorch_pfn_extras.writing import Writer

 if TYPE_CHECKING:
     from pytorch_pfn_extras.runtime._runtime import DeviceLike
@@ -51,17 +51,21 @@ def dummy_tracer(tag: Optional[str]) -> Generator[None, None, None]:

 @contextmanager
 def tracer(
-    tag: Optional[str],
+    tag: str,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Generator[None, None, None]:
     runtime_cls = runtime_registry.get_runtime_class_for_device_spec(device)
     runtime_tracer = runtime_cls.trace

     if chrome_tracing_out is not None:
-        chrome_tracer = _chrome_tracing.get_chrome_tracer(chrome_tracing_out)
+        assert chrome_writer is not None
+        chrome_tracer = _chrome_tracing.get_chrome_tracer(
+            chrome_tracing_out, chrome_writer
+        ).add_event
     else:
-        chrome_tracer = dummy_tracer
+        chrome_tracer = dummy_tracer  # type: ignore[assignment]

     with runtime_tracer(tag, None), chrome_tracer(tag):
         yield
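
For orientation, a minimal sketch of how tracer() composes the two back ends after this change; tracer is an internal helper rather than public API, and the writer and paths here are hypothetical:

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.profiler._record import tracer

w = ppe.writing.SimpleWriter(out_dir="/tmp/traces")  # hypothetical out_dir
with tracer("forward", device="cpu",
            chrome_tracing_out="trace.json", chrome_writer=w):
    pass  # the runtime tracer and the chrome tracer both wrap this block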
@@ -75,6 +79,7 @@ def record(
     enable: bool = True,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Generator[_time_summary._ReportNotification, None, None]:
     if not enable:
         yield _DummyReportNotification()
@@ -89,13 +94,13 @@
     if use_cuda:
         torch.cuda.nvtx.range_push(tag)  # type: ignore[no-untyped-call]
     try:
-        with tracer(tag, device, chrome_tracing_out):
+        with tracer(tag, device, chrome_tracing_out, chrome_writer):
             time_summary = _time_summary.get_time_summary()
-            time_summary.report("start_time", time.time())
+            # time_summary.report("start_time", time.time())
             with time_summary.report(metric, use_cuda) as ntf:
                 yield ntf
     finally:
-        time_summary.report("end_time", time.time())
+        # time_summary.report("end_time", time.time())
         if use_cuda:
             torch.cuda.nvtx.range_pop()  # type: ignore[no-untyped-call]

@@ -109,11 +114,20 @@ def record_function(
     enable: bool = True,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
     def wrapper(f: Callable[..., _T]) -> Callable[..., _T]:
         def wrapped(*args: Any, **kwargs: Any) -> _T:
             name = tag or f.__name__
-            with record(name, None, use_cuda, device, chrome_tracing_out):
+            with record(
+                name,
+                None,
+                use_cuda,
+                enable,
+                device,
+                chrome_tracing_out,
+                chrome_writer,
+            ):
                 return f(*args, **kwargs)

         return wrapped
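
The decorator now forwards enable and the chrome writer as well. A hedged usage sketch; the function name, filename, and writer are illustrative:

import torch
import pytorch_pfn_extras as ppe

w = ppe.writing.SimpleWriter(out_dir="/tmp/traces")  # hypothetical out_dir

@ppe.profiler.record_function(
    "forward", chrome_tracing_out="fn_trace.json", chrome_writer=w
)
def forward(x):
    return x * 2

forward(torch.ones(2))  # every call runs inside a "forward" record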
@@ -129,6 +143,7 @@ def record_iterable(
     enable: bool = True,
     device: "DeviceLike" = "cpu",
     chrome_tracing_out: Optional[str] = None,
+    chrome_writer: Optional[Writer] = None,
 ) -> Iterable[_T]:
     if tag is None:
         tag = _infer_tag_name(inspect.currentframe(), depth=1)
@@ -137,7 +152,15 @@ def wrapped() -> Iterable[_T]:
         for i, x in enumerate(iter):
             name = f"{tag}-{i}"
             metric = name if divide_metric else tag
-            with record(name, metric, use_cuda, device, chrome_tracing_out):
+            with record(
+                name,
+                metric,
+                use_cuda,
+                enable,
+                device,
+                chrome_tracing_out,
+                chrome_writer,
+            ):
                 yield x

     return wrapped()
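
record_iterable gets the same plumbing; a minimal sketch under the same hypothetical writer and filename assumptions:

import torch
import pytorch_pfn_extras as ppe

w = ppe.writing.SimpleWriter(out_dir="/tmp/traces")  # hypothetical out_dir
data = [torch.ones(2), torch.ones(2), torch.ones(2)]

for x in ppe.profiler.record_iterable(
    "batch", data, chrome_tracing_out="iter_trace.json", chrome_writer=w
):
    pass  # elements are traced as "batch-0", "batch-1", "batch-2"

ppe.profiler.get_chrome_tracer("iter_trace.json", w).flush()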
21 changes: 11 additions & 10 deletions tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py
@@ -1,4 +1,5 @@
 import os
+import tempfile

 import pytest
 import pytorch_pfn_extras as ppe
@@ -138,14 +139,14 @@ def test_record_iterable_with_chrome_trace(device):
     model.to(device)

-    x = torch.arange(30, dtype=torch.float32).to(device)
-    iters = [x, x, x]
-
-    with torch.profiler.profile() as prof:
-        for x in ppe.profiler.record(None, iters):
-            model(x)
-
-    keys = [event.key for event in prof.key_averages()]
-    assert "aten::linear" in keys
-    assert any(k.endswith("test_record_iterable_without_tag-0") for k in keys)
-    assert any(k.endswith("test_record_iterable_without_tag-1") for k in keys)
-    assert any(k.endswith("test_record_iterable_without_tag-2") for k in keys)
+    x = torch.arange(30, dtype=torch.float32).to(device)
+    with tempfile.TemporaryDirectory() as t_path:
+        w = ppe.writing.SimpleWriter(out_dir=t_path)
+        with torch.profiler.profile():
+            with ppe.profiler.record(
+                "tag", chrome_tracing_out="trace.json", chrome_writer=w
+            ):
+                model(x)
+        ppe.profiler.get_chrome_tracer("trace.json", w).flush()
+        assert os.path.exists(os.path.join(t_path, "trace.json"))
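
As the test suggests, flush() must be called explicitly before the file appears on disk. The flushed file follows the Chrome tracing JSON event format, so it can then be loaded into chrome://tracing or Perfetto for inspection.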
