Skip to content

Commit

Permalink
Typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Emilio Castillo committed Aug 23, 2023
1 parent b640d4d commit a9ec5c7
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions pytorch_pfn_extras/profiler/_chrome_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import json
import time
import threading
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pytorch_pfn_extras.writing import Writer


class ChromeTracingEvent:
def __init__(
self,
emitter: Any,
name: str,
category_list: Optional[List[str]],
is_cuda_available: bool,
):
self.emitter = emitter
self.name = name
self.category_list = category_list or []
self.is_cuda_available = is_cuda_available

def __enter__(self) -> None:
self.begin_ns = time.perf_counter_ns()

def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]:
if self.is_cuda_available:
torch.cuda.synchronize() # Wait for process to complete

duration_ns = time.perf_counter_ns() - self.begin_ns
self.emitter.emit(
dict(
name=self.name,
cat=",".join(self.category_list),
ph="X",
ts=self.begin_ns / 1000, # nano sec -> micro sec
dur=duration_ns / 1000, # ditto
pid=0,
tid=0,
)
)
return False


class ChromeTracingEventDisabled:
def __enter__(self) -> None:
pass

def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]:
return False


class ChromeTracingSaveFunc:
def __call__(self, target: Dict[str, Any], file_o: Any) -> None:
log = json.dumps(target, indent=4)
file_o.write(bytes(log.encode("ascii")))


class ChromeTracingEmitter:
def __init__(
self,
out_trace_json_path: str,
writer: Writer = None,
max_event_count: Optional[int] = None,
enable: bool = True,
):
self.enable = enable
self.event_list: List[Dict[str, Union[str, int, float]]] = []
self.out_trace_json_path = out_trace_json_path
self.max_event_count = max_event_count or float("inf")
self.event_count = 0
self.is_cuda_available = torch.cuda.is_available()
self.writer = writer

def add_event(
self, name: str, category_list: Optional[List[str]] = None
) -> Union[ChromeTracingEvent, ChromeTracingEventDisabled]:
if not self.enable or self.event_count >= self.max_event_count:
return ChromeTracingEventDisabled()
self.event_count += 1
return ChromeTracingEvent(
self,
name=name,
category_list=category_list,
is_cuda_available=self.is_cuda_available,
)

def emit(self, event: Dict[str, Union[str, int, float]]) -> None:
assert self.enable
self.event_list.append(event)

def flush(self) -> None:
if not self.enable:
return
# TODO(ecastill): try to work on some append mode manipulating the
# file pointer and with json.dumps?
savefun = ChromeTracingSaveFunc()
self.writer(
"chrome_tracing.json",
self.out_trace_json_path,
self.event_list,
savefun=savefun,
)


_thread_local = threading.local()


def get_chrome_tracer(out_dir: str, writer: Writer) -> ChromeTracingEmitter:
if not hasattr(_thread_local, "chrome_tracer"):
_thread_local.chrome_tracer = ChromeTracingEmitter(
out_dir, writer, None, True
)
return _thread_local.chrome_tracer # type: ignore[no-any-return]

0 comments on commit a9ec5c7

Please sign in to comment.