forked from pytorch/pytorch
Commit
Ports over the step closure functionality from PyTorch/XLA to Lazy Tensor Core.

References:
https://github.com/pytorch/xla/blob/205ae574c0a24e092899ea8610c360f93f5d8142/torch_xla/core/xla_model.py#L852-L900
https://github.com/pytorch/xla/blob/205ae574c0a24e092899ea8610c360f93f5d8142/torch_xla/utils/closures.py#L7-L83

CC: @wconstab @JackCaoG @Krovatkin

Pull Request resolved: pytorch#84300
Approved by: https://github.com/JackCaoG, https://github.com/wconstab
1 parent 02da943 · commit bab1304
Showing 6 changed files with 267 additions and 2 deletions.
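For orientation, here is a minimal sketch of how the ported API is meant to be used, modeled on the test file below and the XLA references above. The training loop, `model`, `loader`, `optimizer`, and `report_fn` are illustrative placeholders, not part of this commit; only `torch._lazy.add_step_closure`, `torch._lazy.mark_step`, and `torch._lazy.ts_backend.init` come from the changed code.

import torch
import torch._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()


def report_fn(step, loss):
    # Runs after mark_step(), once `loss` has been materialized on the device.
    print(f"step {step}: loss={loss.item():.4f}")


def train(model, loader, optimizer):
    # model and data are assumed to already live on the lazy device
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        loss = torch.nn.functional.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()
        # Queue the report; run_async=True keeps it off the training thread.
        torch._lazy.add_step_closure(report_fn, args=(step, loss), run_async=True)
        torch._lazy.mark_step()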
@@ -0,0 +1,91 @@
# Owner(s): ["oncall: jit"]

from threading import Event
from time import sleep

import torch._lazy
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, TestCase

torch._lazy.ts_backend.init()


class ClosuresTest(TestCase):
    def test_synchronous(self):
        flag = Event()
        assert not flag.is_set()

        def closure():
            sleep(1)
            assert not flag.is_set()
            flag.set()

        torch._lazy.add_step_closure(closure)
        torch._lazy.mark_step()

        # should not get to this part before closure is finished running
        assert flag.is_set()

    def test_asynchronous(self):
        flag = Event()
        assert not flag.is_set()

        def closure():
            sleep(1)
            assert flag.is_set()

        torch._lazy.add_step_closure(closure, run_async=True)
        torch._lazy.mark_step()

        # should get to this part and complete before closure is finished running
        assert not flag.is_set()
        flag.set()

    def test_synchronous_exception(self):
        flag = Event()
        assert not flag.is_set()

        try:

            def closure():
                flag.set()
                raise RuntimeError("Simulating exception in closure")

            torch._lazy.add_step_closure(closure)
            torch._lazy.mark_step()

            raise AssertionError()  # Should not reach here
        except RuntimeError as e:
            assert flag.is_set(), "Should have caught exception from closure"

    def test_asynchronous_exception(self):
        flag = Event()
        assert not flag.is_set()

        def closure1():
            flag.set()
            raise RuntimeError("Simulating exception in closure1")

        torch._lazy.add_step_closure(closure1, run_async=True)
        torch._lazy.mark_step()

        flag.wait(timeout=5)

        try:

            def closure2():  # Should never execute
                flag.clear()

            torch._lazy.add_step_closure(closure2, run_async=True)
            torch._lazy.mark_step()

            raise AssertionError()  # Should not reach here
        except RuntimeError as e:
            # Should have caught exception from closure1
            pass

        assert flag.is_set()


if __name__ == "__main__":
    run_tests()
@@ -0,0 +1,134 @@
import os
import threading
from queue import Empty as EmptyQueue, Queue

from torch._lazy.device_context import get_device_context


class ClosureHandler:
    def __init__(self):
        pass

    def run(self, closure):
        """Run a closure function.

        Args:
            closure: callable function to run
        """
        closure()

    def __call__(self, closures):
        for closure in closures:
            self.run(closure)


class AsyncClosureHandler(ClosureHandler):
    """Handler for asynchronous step closures.

    Args:
        max_queue_size: The maximum length of the closure queue, after which
            the training loop will block until queued closures are evaluated.
            Defaults to a maximum of 100 queued closures. This value can be
            overridden with the `LTC_MAX_ASYNC_QUEUE` environment variable.
    """

    def __init__(self, max_queue_size=100):
        super().__init__()
        self._closure_queue: Queue = Queue(
            int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size))
        )
        self._closure_exception: Queue = Queue()
        self._closure_lock = threading.Lock()
        self._closure_event_loop_finished = threading.Event()
        self._closure_event_loop = None

    def start_event_loop(self):
        """Start the closure event loop if it is not already running."""
        if self._closure_event_loop is None:

            def event_loop():
                # Run loop until the closure queue is drained and empty
                while True:
                    try:
                        closure = self._closure_queue.get(block=True, timeout=3)
                        closure()
                        self._closure_queue.task_done()
                    except EmptyQueue:
                        with self._closure_lock:
                            if self._closure_queue.empty():
                                self._closure_event_loop_finished.set()
                                return
                    except Exception as e:
                        self._closure_exception.put(e)
                        return

            self._closure_event_loop = threading.Thread(target=event_loop)
            self._closure_event_loop.start()

    def run(self, closure):
        with self._closure_lock:
            self._closure_queue.put(closure, block=True)
            if (
                self._closure_event_loop is None
                or not self._closure_event_loop.is_alive()
            ):
                try:
                    e = self._closure_exception.get(block=False)
                    raise RuntimeError(
                        "Cannot run asynchronous closure due to previously raised exception"
                    ) from e
                except EmptyQueue:
                    self._closure_event_loop = None
                    self.start_event_loop()


def add_step_closure(closure, args=(), run_async=False):
    """Adds a closure to the list of the ones to be run at the end of the step.

    Many times during model training there is the need to print/report (print to
    console, post to tensorboard, etc...) information that requires the content of
    intermediary tensors to be inspected.
    Inspecting the content of different tensors at different points of the model
    code requires many executions and typically causes performance issues.
    Adding a step closure ensures that it will be run after the barrier, when all
    the live tensors have already been materialized to device data. Live tensors
    include the ones captured by the closure arguments.
    So using `add_step_closure()` ensures that a single execution is performed,
    even when multiple closures are queued, each requiring multiple tensors to be
    inspected.
    Step closures are run sequentially in the order they have been queued.
    Note that even though execution is optimized by this API, it is advised to
    throttle printing/reporting events to once every N steps.

    Args:
        closure (callable): The function to be called.
        args (tuple): The arguments to be passed to the closure.
        run_async: If True, run the closure asynchronously.
    """
    devctx = get_device_context()
    closures_type = "async_step_closures" if run_async else "step_closures"
    step_closures = getattr(devctx, closures_type, None)
    if step_closures is None:
        step_closures = []
        setattr(devctx, closures_type, step_closures)
    step_closures.append(lambda a=args: closure(*a))


def run_step_closures():
    devctx = get_device_context()
    async_step_closures = getattr(devctx, "async_step_closures", None)
    if async_step_closures is not None:
        devctx.async_step_closures = []
        async_closure_handler = getattr(devctx, "async_closure_handler", None)
        if async_closure_handler is None:
            async_closure_handler = AsyncClosureHandler()
            devctx.async_closure_handler = async_closure_handler
        async_closure_handler(async_step_closures)

    step_closures = getattr(devctx, "step_closures", None)
    if step_closures is not None:
        devctx.step_closures = []
        closure_handler = getattr(devctx, "closure_handler", None)
        if closure_handler is None:
            closure_handler = ClosureHandler()
            devctx.closure_handler = closure_handler
        closure_handler(step_closures)
    return devctx
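The glue that makes `mark_step()` drain these queues lives in one of the changed files not rendered on this page (only three of the six diffs are shown). Based on the test expectations above, synchronous closures must have completed by the time `mark_step()` returns and an async failure surfaces on the next queued closure, so the wiring presumably looks like the sketch below; treat the exact module and the `torch._C._lazy._mark_step` call as assumptions rather than a quote of the commit.

# Hypothetical excerpt of the torch._lazy entry point (not shown in this diff).
import torch._C._lazy
from torch._lazy.closures import add_step_closure, run_step_closures  # re-exported


def mark_step(device: str = "", wait=False):
    # Cut the lazy trace and kick off compilation/execution of the live tensors,
    # then run any queued step closures against the now-materialized data.
    torch._C._lazy._mark_step(device, [], wait=wait)
    run_step_closures()

Note also that the async queue depth is controlled by the `LTC_MAX_ASYNC_QUEUE` environment variable read in `AsyncClosureHandler.__init__` above.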
@@ -0,0 +1,25 @@
import threading
from typing import Any, Dict

import torch._C._lazy


class DeviceContext:
    _CONTEXTS: Dict[str, Any] = dict()
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device):
        self.device = device


def get_device_context(device=None):
    if device is None:
        device = torch._C._lazy._get_default_device_type()
    else:
        device = str(device)
    with DeviceContext._CONTEXTS_LOCK:
        devctx = DeviceContext._CONTEXTS.get(device, None)
        if devctx is None:
            devctx = DeviceContext(device)
            DeviceContext._CONTEXTS[device] = devctx
        return devctx
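`get_device_context()` hands out one lazily created context object per device type, guarded by a lock, and `closures.py` above simply hangs its queues and handler instances off that object as plain attributes. A short illustrative sketch of that pattern follows; the attribute name `step_closures` mirrors the one used above, everything else is demonstration code only.

import torch._lazy.ts_backend
from torch._lazy.device_context import get_device_context

torch._lazy.ts_backend.init()  # ensure a lazy backend/default device type exists

devctx = get_device_context()          # context for the default lazy device type
assert devctx is get_device_context()  # cached: one context object per device

# Arbitrary per-device state can be attached to the context; this is how
# closures.py stores queued closures and its handler instances.
devctx.step_closures = getattr(devctx, "step_closures", [])
devctx.step_closures.append(lambda: print("runs at the end of the step"))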