From a55823e00bae330f7f118baf79a3a5330f8be445 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Mon, 1 Jul 2024 12:20:42 +0200
Subject: [PATCH] more, wip, todo

---
 returnn/torch/util/gradient_checkpoint.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/returnn/torch/util/gradient_checkpoint.py b/returnn/torch/util/gradient_checkpoint.py
index c6315d837..e6a90a2e7 100644
--- a/returnn/torch/util/gradient_checkpoint.py
+++ b/returnn/torch/util/gradient_checkpoint.py
@@ -23,6 +23,8 @@
 from typing import Optional, Union, Any, Sequence, List, Dict
 from dataclasses import dataclass, field
 import contextlib
+from weakref import ref
+import threading
 
 import torch
 from torch.utils.weak import WeakTensorKeyDictionary  # needs Torch >=2.0.0
@@ -57,8 +59,10 @@ def __init__(self):
         self.record_graph_scope = _RecordGraph()
         self.record_graph_scope.graph.gradient_checkpoint_scope_backref = self
         # Note: saved_tensors_hooks is thread local.
+        # TODO maybe hook into saved_tensors_hooks.__enter__ and __exit__ to fix our del issue?
         self.saved_tensors_hooks_scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
         self.entered = False
+        self.entered_thread_ref = None
         self.exit_args: Optional[tuple] = None
         self.exited_saved_tensors_hooks_scope = False
 
@@ -66,6 +70,7 @@ def __enter__(self):
         self.record_graph_scope.__enter__()
         self.saved_tensors_hooks_scope.__enter__()
         self.entered = True
+        self.entered_thread_ref = ref(threading.current_thread())
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.exit_args = (exc_type, exc_val, exc_tb)
@@ -78,9 +83,17 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         else:
             self.exit_saved_tensors_hooks_scope()
 
-    # Note, be very careful what we do in __del__.
-    # We do not directly want to exit_saved_tensors_hooks_scope() there
-    # because it might be called in a different thread.
+    def __del__(self):
+        # Note that we keep this alive via _Graph.gradient_checkpoint_scope_backref
+        # as long as any _GraphTensor is alive due to backprop pack_hook.
+        # Note, be very careful what we do in __del__ because it might be called in a different thread!
+        if self.entered_thread_ref() is threading.current_thread():
+            # We are still in the same thread.
+            # This is fine, we can exit the scope.
+            self.exit_saved_tensors_hooks_scope()
+        else:
+            # TODO what now?
+            pass
 
     def exit_saved_tensors_hooks_scope(self):
         """
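
A minimal, self-contained sketch of the pattern the __del__ change above is heading towards. This is not part of the patch and not RETURNN's actual implementation; the class name DeferredScopeExit, the exit_pending flag and the trivial pack/unpack hooks are made up for illustration. The underlying constraint is real: torch.autograd.graph.saved_tensors_hooks is thread-local, so the hooks scope may only be exited on the thread that entered it, while __del__ can run on an arbitrary thread (e.g. via the garbage collector). One option for the open "TODO what now?" branch is to merely mark the exit as pending and let code that is known to run on the entering thread perform it later.

# Sketch only, not RETURNN code. DeferredScopeExit and exit_pending are hypothetical.
import threading
from weakref import ref

import torch


class DeferredScopeExit:
    """Enters a saved_tensors_hooks scope and exits it only on the entering thread."""

    def __init__(self):
        self.scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
        self.entered_thread_ref = None  # weakref to the thread which entered the scope
        self.exited = False
        self.exit_pending = False  # hypothetical: set when __del__ runs on a foreign thread

    def __enter__(self):
        self.entered_thread_ref = ref(threading.current_thread())
        self.scope.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._exit_scope()

    def __del__(self):
        # __del__ may run on any thread. Only exit the (thread-local) hooks scope
        # directly if we are still on the entering thread; otherwise just remember
        # that an exit is still due.
        if self.entered_thread_ref is not None and self.entered_thread_ref() is threading.current_thread():
            self._exit_scope()
        else:
            # A real implementation would have to check this flag later from code
            # that is guaranteed to run on the entering thread.
            self.exit_pending = True

    def _exit_scope(self):
        if self.entered_thread_ref is not None and not self.exited:
            self.exited = True
            self.scope.__exit__(None, None, None)

    def _pack_hook(self, tensor: torch.Tensor):
        return tensor  # placeholder; RETURNN's real hooks do the graph recording

    def _unpack_hook(self, tensor: torch.Tensor):
        return tensor  # placeholder

Whether merely marking the exit as pending is sufficient depends on when code on the entering thread next gets a chance to clean up, which is exactly what the TODO in the patch (hooking into saved_tensors_hooks.__enter__ and __exit__) is about.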