-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Emilio Castillo
committed
Aug 23, 2023
1 parent
b640d4d
commit a9ec5c7
Showing
1 changed file
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import json | ||
import time | ||
import threading | ||
from typing import Any, Dict, List, Literal, Optional, Union | ||
|
||
import torch | ||
from pytorch_pfn_extras.writing import Writer | ||
|
||
|
||
class ChromeTracingEvent: | ||
def __init__( | ||
self, | ||
emitter: Any, | ||
name: str, | ||
category_list: Optional[List[str]], | ||
is_cuda_available: bool, | ||
): | ||
self.emitter = emitter | ||
self.name = name | ||
self.category_list = category_list or [] | ||
self.is_cuda_available = is_cuda_available | ||
|
||
def __enter__(self) -> None: | ||
self.begin_ns = time.perf_counter_ns() | ||
|
||
def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]: | ||
if self.is_cuda_available: | ||
torch.cuda.synchronize() # Wait for process to complete | ||
|
||
duration_ns = time.perf_counter_ns() - self.begin_ns | ||
self.emitter.emit( | ||
dict( | ||
name=self.name, | ||
cat=",".join(self.category_list), | ||
ph="X", | ||
ts=self.begin_ns / 1000, # nano sec -> micro sec | ||
dur=duration_ns / 1000, # ditto | ||
pid=0, | ||
tid=0, | ||
) | ||
) | ||
return False | ||
|
||
|
||
class ChromeTracingEventDisabled: | ||
def __enter__(self) -> None: | ||
pass | ||
|
||
def __exit__(self, exc_type: Any, exc_value: Any, tracebac: Any) -> Literal[False]: | ||
return False | ||
|
||
|
||
class ChromeTracingSaveFunc: | ||
def __call__(self, target: Dict[str, Any], file_o: Any) -> None: | ||
log = json.dumps(target, indent=4) | ||
file_o.write(bytes(log.encode("ascii"))) | ||
|
||
|
||
class ChromeTracingEmitter: | ||
def __init__( | ||
self, | ||
out_trace_json_path: str, | ||
writer: Writer = None, | ||
max_event_count: Optional[int] = None, | ||
enable: bool = True, | ||
): | ||
self.enable = enable | ||
self.event_list: List[Dict[str, Union[str, int, float]]] = [] | ||
self.out_trace_json_path = out_trace_json_path | ||
self.max_event_count = max_event_count or float("inf") | ||
self.event_count = 0 | ||
self.is_cuda_available = torch.cuda.is_available() | ||
self.writer = writer | ||
|
||
def add_event( | ||
self, name: str, category_list: Optional[List[str]] = None | ||
) -> Union[ChromeTracingEvent, ChromeTracingEventDisabled]: | ||
if not self.enable or self.event_count >= self.max_event_count: | ||
return ChromeTracingEventDisabled() | ||
self.event_count += 1 | ||
return ChromeTracingEvent( | ||
self, | ||
name=name, | ||
category_list=category_list, | ||
is_cuda_available=self.is_cuda_available, | ||
) | ||
|
||
def emit(self, event: Dict[str, Union[str, int, float]]) -> None: | ||
assert self.enable | ||
self.event_list.append(event) | ||
|
||
def flush(self) -> None: | ||
if not self.enable: | ||
return | ||
# TODO(ecastill): try to work on some append mode manipulating the | ||
# file pointer and with json.dumps? | ||
savefun = ChromeTracingSaveFunc() | ||
self.writer( | ||
"chrome_tracing.json", | ||
self.out_trace_json_path, | ||
self.event_list, | ||
savefun=savefun, | ||
) | ||
|
||
|
||
_thread_local = threading.local() | ||
|
||
|
||
def get_chrome_tracer(out_dir: str, writer: Writer) -> ChromeTracingEmitter: | ||
if not hasattr(_thread_local, "chrome_tracer"): | ||
_thread_local.chrome_tracer = ChromeTracingEmitter( | ||
out_dir, writer, None, True | ||
) | ||
return _thread_local.chrome_tracer # type: ignore[no-any-return] |