Add step closures (pytorch#84300)
antoniojkim authored and pytorchmergebot committed Sep 6, 2022
1 parent 02da943 commit bab1304
Showing 6 changed files with 267 additions and 2 deletions.
91 changes: 91 additions & 0 deletions test/lazy/test_step_closures.py
@@ -0,0 +1,91 @@
# Owner(s): ["oncall: jit"]

from threading import Event
from time import sleep

import torch._lazy
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, TestCase

torch._lazy.ts_backend.init()


class ClosuresTest(TestCase):
def test_synchronous(self):
flag = Event()
assert not flag.is_set()

def closure():
sleep(1)
assert not flag.is_set()
flag.set()

torch._lazy.add_step_closure(closure)
torch._lazy.mark_step()

# should not get to this part before closure is finished running
assert flag.is_set()

def test_asynchronous(self):
flag = Event()
assert not flag.is_set()

def closure():
sleep(1)
assert flag.is_set()

torch._lazy.add_step_closure(closure, run_async=True)
torch._lazy.mark_step()

# should get to this part and complete before closure is finished running
assert not flag.is_set()
flag.set()

def test_synchronous_exception(self):
flag = Event()
assert not flag.is_set()

try:

def closure():
flag.set()
raise RuntimeError("Simulating exception in closure")

torch._lazy.add_step_closure(closure)
torch._lazy.mark_step()

raise AssertionError() # Should not reach here
        except RuntimeError:
assert flag.is_set(), "Should have caught exception from closure"

def test_asynchronous_exception(self):
flag = Event()
assert not flag.is_set()

def closure1():
flag.set()
raise RuntimeError("Simulating exception in closure1")

torch._lazy.add_step_closure(closure1, run_async=True)
torch._lazy.mark_step()

flag.wait(timeout=5)

try:

def closure2(): # Should never execute
flag.clear()

torch._lazy.add_step_closure(closure2, run_async=True)
torch._lazy.mark_step()

raise AssertionError() # Should not reach here
        except RuntimeError:
# Should have caught exception from closure1
pass

assert flag.is_set()


if __name__ == "__main__":
run_tests()
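
For reference, a typical use of the API exercised above is reporting per-step values from a training loop. A minimal sketch, not part of this commit, assuming the same TorchScript lazy backend setup as the tests (`log_loss` and the toy tensors are illustrative only):

import torch
import torch._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()

def log_loss(step, loss):
    # Runs after the step barrier, so `loss` is already materialized.
    print(f"step {step}: loss = {loss.item()}")

for step in range(3):
    x = torch.randn(4, 4, device="lazy")
    loss = (x * x).sum()
    # `args` captures this step's tensors; the closure runs at the end of the step.
    torch._lazy.add_step_closure(log_loss, args=(step, loss))
    torch._lazy.mark_step()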
10 changes: 8 additions & 2 deletions torch/_C/_lazy.pyi
@@ -1,15 +1,20 @@
from typing import List
from torch import Tensor

#defined in torch/csrc/lazy/python/init.cpp
# defined in torch/csrc/lazy/python/init.cpp
def _mark_step(device: str, devices: List[str], wait: bool): ...
def _wait_device_ops(devices: List[str]): ...
def _reset_metrics(): ...
def _counter_names() -> List[str]: ...
def _counter_value(name: str) -> int: ...
def _metrics_report() -> str: ...
def _get_graph_hash(tensors: List[Tensor]) -> str: ...
def _sync_multi(tensors: List[Tensor], devices: List[str], wait: bool = True, sync_ltc_data: bool = True): ...
def _sync_multi(
tensors: List[Tensor],
devices: List[str],
wait: bool = True,
sync_ltc_data: bool = True,
): ...
def _get_tensor_id(tensor: Tensor) -> int: ...
def _get_tensors_text(tensors: List[Tensor]) -> str: ...
def _get_tensors_dot(tensors: List[Tensor]) -> str: ...
@@ -19,3 +24,4 @@ def _set_force_fallback(newval: str): ...
def _clear_ir_cache(): ...
def _dump_ir_cache(filename: str): ...
def _set_reuse_ir(val: bool): ...
def _get_default_device_type(): ...
6 changes: 6 additions & 0 deletions torch/_lazy/__init__.py
@@ -1,6 +1,10 @@
import threading

import torch._C._lazy
from torch.utils._pytree import tree_flatten, tree_unflatten

from .closure import add_step_closure, run_step_closures


def mark_step(device: str = "", wait=False):
"""Triggers a mark step, which amounts to
@@ -12,6 +16,8 @@ def mark_step(device: str = "", wait=False):
# TODO(whc) expand this to include backend hooks and align with XLA backend needs
torch._C._lazy._mark_step(device, [], wait=wait)

run_step_closures()


def wait_device_ops(devices=None):
"""Waits for all the async operations on the given devices to complete.
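
With this hook in place, every `mark_step()` first runs the barrier via `torch._C._lazy._mark_step` and then drains the queued closures. A minimal sketch of that ordering, assuming an initialized TorchScript backend and the `lazy` device it provides:

import torch
import torch._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()

t = torch.zeros(2, device="lazy") + 1

# Queued now, but only executed inside mark_step(), after the barrier.
torch._lazy.add_step_closure(lambda: print(t.cpu()))
torch._lazy.mark_step()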
134 changes: 134 additions & 0 deletions torch/_lazy/closure.py
@@ -0,0 +1,134 @@
import os
import threading
from queue import Empty as EmptyQueue, Queue

from torch._lazy.device_context import get_device_context


class ClosureHandler:
def __init__(self):
pass

def run(self, closure):
"""Run closure function
Args:
closure: callable function to run
"""
closure()

def __call__(self, closures):
for closure in closures:
self.run(closure)


class AsyncClosureHandler(ClosureHandler):
"""Handler for Asynchronous Step Closures
Args:
max_queue_size: The maximum length of the closure queue after which
the training loop will block until closures are evaluated.
By default, a reasonable limit of a maximum of 100 on the queue.
This value can be set using the `XLA_MAX_ASYNC_QUEUE` environment
variable.
"""

def __init__(self, max_queue_size=100):
super().__init__()
self._closure_queue: Queue = Queue(
int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size))
)
self._closure_exception: Queue = Queue()
self._closure_lock = threading.Lock()
self._closure_event_loop_finished = threading.Event()
self._closure_event_loop = None

def start_event_loop(self):
"""Start closure event loop if not started"""
if self._closure_event_loop is None:

def event_loop():
            # Drain closures until the queue stays empty for a full poll
            # interval, or until a closure raises.
while True:
try:
closure = self._closure_queue.get(block=True, timeout=3)
closure()
self._closure_queue.task_done()
except EmptyQueue:
with self._closure_lock:
if self._closure_queue.empty():
self._closure_event_loop_finished.set()
return
except Exception as e:
self._closure_exception.put(e)
return

self._closure_event_loop = threading.Thread(target=event_loop)
self._closure_event_loop.start()

def run(self, closure):
with self._closure_lock:
self._closure_queue.put(closure, block=True)
if (
self._closure_event_loop is None
or not self._closure_event_loop.is_alive()
):
try:
e = self._closure_exception.get(block=False)
raise RuntimeError(
"Cannot run asynchronous closure due to previously raised exception"
) from e
except EmptyQueue:
self._closure_event_loop = None
self.start_event_loop()


def add_step_closure(closure, args=(), run_async=False):
"""Adds a closure to the list of the ones to be run at the end of the step.
Many times during model training there is the need to print/report (print to
console, post to tensorboard, etc...) information which require the content of
intermediary tensors to be inspected.
Inspecting different tensors content in different points of the model code
requires many executions and typically causes performance issues.
Adding a step closure will ensure that it will be run after the barrier, when
all the live tensors will be already materialized to device data.
Live tensors which will include the ones captured by the closure arguments.
So using `add_step_closure()` will ensure a single execution will be
performed, even when multiple closures are queued, requiring multiple tensors
to be inspected.
Step closures will be run sequentially in the order they have been queued.
Note that even though using this API the execution will be optimized, it is
advised to throttle the printing/reporting events once every N steps.
Args:
closure (callable): The function to be called.
args (tuple): The arguments to be passed to the closure.
run_async: If True, run the closure asynchronously.
"""
devctx = get_device_context()
closures_type = "async_step_closures" if run_async else "step_closures"
step_closures = getattr(devctx, closures_type, None)
if step_closures is None:
step_closures = []
setattr(devctx, closures_type, step_closures)
step_closures.append(lambda a=args: closure(*a))


def run_step_closures():
devctx = get_device_context()
async_step_closures = getattr(devctx, "async_step_closures", None)
if async_step_closures is not None:
devctx.async_step_closures = []
async_closure_handler = getattr(devctx, "async_closure_handler", None)
if async_closure_handler is None:
async_closure_handler = AsyncClosureHandler()
devctx.async_closure_handler = async_closure_handler
async_closure_handler(async_step_closures)

step_closures = getattr(devctx, "step_closures", None)
if step_closures is not None:
devctx.step_closures = []
closure_handler = getattr(devctx, "closure_handler", None)
if closure_handler is None:
closure_handler = ClosureHandler()
devctx.closure_handler = closure_handler
closure_handler(step_closures)
return devctx
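
The asynchronous path can also be exercised directly. A small, illustrative-only sketch: the handler drains its queue on a background thread, and a closure exception recorded by the event loop is re-raised on the next `run()` call, which is the contract `test_asynchronous_exception` verifies through the public API:

from torch._lazy.closure import AsyncClosureHandler

handler = AsyncClosureHandler(max_queue_size=4)
# Each submitted closure executes on the handler's event-loop thread.
handler([lambda: print("ran asynchronously")])
# The loop thread exits on its own once the queue stays empty for ~3 seconds.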
25 changes: 25 additions & 0 deletions torch/_lazy/device_context.py
@@ -0,0 +1,25 @@
import threading
from typing import Any, Dict

import torch._C._lazy


class DeviceContext:
_CONTEXTS: Dict[str, Any] = dict()
_CONTEXTS_LOCK = threading.Lock()

def __init__(self, device):
self.device = device


def get_device_context(device=None):
if device is None:
device = torch._C._lazy._get_default_device_type()
else:
device = str(device)
with DeviceContext._CONTEXTS_LOCK:
devctx = DeviceContext._CONTEXTS.get(device, None)
if devctx is None:
devctx = DeviceContext(device)
DeviceContext._CONTEXTS[device] = devctx
return devctx
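
`get_device_context` returns one cached, process-wide context per device string; `closure.py` relies on this to hang per-device state (closure lists and handlers) off the context as plain attributes. A quick sketch, assuming a backend has been initialized so the default device type resolves:

import torch._lazy.ts_backend
from torch._lazy.device_context import get_device_context

torch._lazy.ts_backend.init()

ctx_a = get_device_context()  # keyed by the default backend device type
ctx_b = get_device_context()
assert ctx_a is ctx_b  # the same cached DeviceContext instance
ctx_a.step_closures = []  # arbitrary attributes can be attached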
3 changes: 3 additions & 0 deletions torch/csrc/lazy/python/init.cpp
@@ -201,6 +201,9 @@ void initLazyBindings(PyObject* module) {
lazy.def("_get_symbolic_shape_mode", []() {
return FLAGS_ltc_enable_symbolic_shapes;
});
lazy.def("_get_default_device_type", []() {
return getBackend()->GetDefaultDeviceType()->toString();
});

lazy_ts_backend.def("_init", []() {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
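
Once a backend is registered, the new binding is reachable from Python; the exact string it returns depends on the backend in use:

import torch._C._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()
print(torch._C._lazy._get_default_device_type())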
