
Commit

more, wip, todo
albertz committed Jul 1, 2024
1 parent 907ca81 commit a55823e
Showing 1 changed file with 16 additions and 3 deletions.
returnn/torch/util/gradient_checkpoint.py: 19 changes (16 additions & 3 deletions)
@@ -23,6 +23,8 @@
from typing import Optional, Union, Any, Sequence, List, Dict
from dataclasses import dataclass, field
import contextlib
from weakref import ref
import threading

import torch
from torch.utils.weak import WeakTensorKeyDictionary # needs Torch >=2.0.0
@@ -57,15 +59,18 @@ def __init__(self):
        self.record_graph_scope = _RecordGraph()
        self.record_graph_scope.graph.gradient_checkpoint_scope_backref = self
        # Note: saved_tensors_hooks is thread local.
        # TODO maybe hook into saved_tensors_hooks.__enter__ and __exit__ to fix our del issue?
        self.saved_tensors_hooks_scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
        self.entered = False
        self.entered_thread_ref = None
        self.exit_args: Optional[tuple] = None
        self.exited_saved_tensors_hooks_scope = False

    def __enter__(self):
        self.record_graph_scope.__enter__()
        self.saved_tensors_hooks_scope.__enter__()
        self.entered = True
        self.entered_thread_ref = ref(threading.current_thread())

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_args = (exc_type, exc_val, exc_tb)
@@ -78,9 +83,17 @@ def __exit__(self, exc_type, exc_val, exc_tb):
        else:
            self.exit_saved_tensors_hooks_scope()

    # Note, be very careful what we do in __del__.
    # We do not directly want to exit_saved_tensors_hooks_scope() there
    # because it might be called in a different thread.
    def __del__(self):
        # Note that we keep this alive via _Graph.gradient_checkpoint_scope_backref
        # as long as any _GraphTensor is alive due to backprop pack_hook.
        # Note, be very careful what we do in __del__ because it might be called in a different thread!
        if self.entered_thread_ref() is threading.current_thread():
            # We are still in the same thread.
            # This is fine, we can exit the scope.
            self.exit_saved_tensors_hooks_scope()
        else:
            # TODO what now?
            pass

    def exit_saved_tensors_hooks_scope(self):
        """
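For context on the mechanism this class builds on (standard PyTorch, not RETURNN-specific code): torch.autograd.graph.saved_tensors_hooks installs a pack hook and an unpack hook for tensors that autograd saves for backward, and the scope is thread-local, i.e. it only affects autograd operations run in the thread that entered it. A minimal, self-contained example of the hooks themselves (identity hooks plus a log of saved shapes, purely for illustration):

import torch

packed_log = []  # record of what autograd saved for backward inside the scope


def pack_hook(tensor):
    # Called whenever autograd saves a tensor for backward while the scope is active.
    packed_log.append(tensor.shape)
    return tensor  # a real implementation could return something cheaper and recompute later


def unpack_hook(saved):
    # Called during backward to recover whatever pack_hook returned.
    return saved


x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()  # the inputs saved by the multiplication go through pack_hook
y.backward()
print(packed_log)  # shapes of the tensors that were saved for backward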

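The new __del__ above only exits the thread-local saved_tensors_hooks scope when __del__ happens to run in the same thread that entered it; what to do when it runs in another thread is left as a TODO in this commit. A minimal standalone sketch of that same pattern (the class name, identity hooks, and the exited flag are illustrative assumptions, not RETURNN's actual implementation):

import threading
from weakref import ref

import torch


class _ScopeSketch:
    """Illustrative only: exit a thread-local saved_tensors_hooks scope from __del__
    only if __del__ runs in the same thread that entered the scope."""

    def __init__(self):
        # saved_tensors_hooks is thread-local, so it must be exited in the entering thread.
        self.saved_tensors_hooks_scope = torch.autograd.graph.saved_tensors_hooks(
            lambda t: t, lambda t: t  # identity pack/unpack hooks, just for the sketch
        )
        self.entered_thread_ref = None
        self.exited = False

    def __enter__(self):
        self.saved_tensors_hooks_scope.__enter__()
        # Weak reference only, so the scope object does not keep the thread object alive.
        self.entered_thread_ref = ref(threading.current_thread())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.exited:
            self.exited = True
            self.saved_tensors_hooks_scope.__exit__(exc_type, exc_val, exc_tb)

    def __del__(self):
        # __del__ can run in any thread (wherever the last reference happens to be dropped).
        if self.exited or self.entered_thread_ref is None:
            return
        if self.entered_thread_ref() is threading.current_thread():
            # Same thread that entered: safe to pop our hooks from the thread-local stack.
            self.exited = True
            self.saved_tensors_hooks_scope.__exit__(None, None, None)
        # else: cannot safely exit here -- this is the open TODO in the commit.

Holding only a weak reference to the entering thread (as the commit does via ref(threading.current_thread())) avoids keeping the thread object alive just because the scope object is still referenced, e.g. from the autograd graph via the pack hook.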