Revert "[Modes] Add assert that the mode isn't already on the stack (p…
Browse files Browse the repository at this point in the history
…ytorch#90770)"

This reverts commit 7028386.

Reverted pytorch#90770 on behalf of https://github.com/DanilBaibak due to breaking an internal build
pytorchmergebot committed Jan 12, 2023
1 parent a383789 commit db466ae
Showing 13 changed files with 22 additions and 191 deletions.
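For context, the reverted assert changed what happens when the same TorchDispatchMode instance is entered twice. Below is a minimal sketch of the behavior this revert restores, based on the test_nesting_same_mode test further down in this diff; PrintingMode is a hypothetical name used only for illustration, not code from this commit.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PrintingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"dispatch: {func}")           # log every op that reaches this mode
        return func(*args, **(kwargs or {}))

m = PrintingMode()
with m:
    with m:  # re-entering the same instance: allowed again after this revert;
             # with pytorch#90770 applied it raised "Illegal attempt to push an
             # already pushed mode onto the stack"
        torch.empty([])  # dispatches through the mode once per stack entry,
                         # matching the two log lines in the restored test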
5 changes: 0 additions & 5 deletions c10/core/impl/PyInterpreter.cpp
@@ -91,11 +91,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void trace_gpu_device_synchronization() const override {}
void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
void trace_gpu_event_synchronization(uintptr_t event) const override {}

void mode_state_push_trampoline(
std::shared_ptr<SafePyObject> mode) const override{};
void mode_state_pop_trampoline(
std::shared_ptr<SafePyObject> mode) const override{};
};

void PyInterpreter::disarm() noexcept {
5 changes: 0 additions & 5 deletions c10/core/impl/PyInterpreter.h
@@ -183,11 +183,6 @@ struct C10_API PyInterpreterVTable {
virtual void trace_gpu_device_synchronization() const = 0;
virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;

virtual void mode_state_push_trampoline(
std::shared_ptr<SafePyObject> mode) const = 0;
virtual void mode_state_pop_trampoline(
std::shared_ptr<SafePyObject> mode) const = 0;
};

struct C10_API PyInterpreter {
21 changes: 5 additions & 16 deletions c10/core/impl/TorchDispatchModeTLS.cpp
@@ -1,4 +1,5 @@
#include <c10/core/DispatchKeySet.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>

@@ -7,25 +8,21 @@ namespace impl {

thread_local TorchDispatchModeTLS torchDispatchModeState;

void TorchDispatchModeTLS::push_onto_stack(
std::shared_ptr<c10::SafePyObject> mode) {
void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
if (torchDispatchModeState.stack_.size() == 0) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}
mode->pyinterpreter()->mode_state_push_trampoline(mode);
torchDispatchModeState.stack_.push_back(std::move(mode));
}

const std::shared_ptr<c10::SafePyObject> TorchDispatchModeTLS::pop_stack() {
const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
TORCH_CHECK(
torchDispatchModeState.stack_.size() > 0,
"trying to pop from empty mode stack");

std::shared_ptr<c10::SafePyObject> out = torchDispatchModeState.stack_.back();
std::shared_ptr<SafePyObject> out = torchDispatchModeState.stack_.back();
torchDispatchModeState.stack_.pop_back();
out->pyinterpreter()->mode_state_pop_trampoline(out);

if (torchDispatchModeState.stack_.size() == 0) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
@@ -35,7 +32,7 @@ const std::shared_ptr<c10::SafePyObject>& TorchDispatchModeTLS::pop_stack() {
return out;
}

const std::shared_ptr<c10::SafePyObject>& TorchDispatchModeTLS::get_stack_at(
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
int64_t idx) {
TORCH_CHECK(
idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
@@ -52,15 +49,7 @@ const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
}

void TorchDispatchModeTLS::set_state(const TorchDispatchModeTLS& state) {
for (const std::shared_ptr<c10::SafePyObject>& state :
torchDispatchModeState.stack_) {
state->pyinterpreter()->mode_state_pop_trampoline(state);
}
for (const std::shared_ptr<c10::SafePyObject>& state : state.stack_) {
state->pyinterpreter()->mode_state_push_trampoline(state);
}
torchDispatchModeState = state;

if (torchDispatchModeState.stack_.size() == 0) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
6 changes: 3 additions & 3 deletions c10/core/impl/TorchDispatchModeTLS.h
@@ -9,9 +9,9 @@ namespace c10 {
namespace impl {

struct C10_API TorchDispatchModeTLS {
static void push_onto_stack(std::shared_ptr<c10::SafePyObject> mode);
static const std::shared_ptr<c10::SafePyObject> pop_stack();
static const std::shared_ptr<c10::SafePyObject>& get_stack_at(int64_t idx);
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
static int64_t stack_len();

static const TorchDispatchModeTLS& get_state();
5 changes: 2 additions & 3 deletions test/functorch/test_aotdispatch.py
@@ -1905,11 +1905,10 @@ def forward(self, x):
# constant to make_fx, and result in the tensor being traced
# into the graph, which is an error condition. Make sure we
# report adequately in this case.
return (torch.add(x, fake_z), )
return (x + fake_z, )

with self.assertRaisesRegex(
TypeError,
"no implementation found for .*FakeTensor"
AssertionError, "Unexpected fake buffer"
):
aot_module_simplified(MockModule(), (fake_x,), nop)

28 changes: 10 additions & 18 deletions test/test_python_dispatch.py
@@ -947,16 +947,9 @@ def __init__(self, msg):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA(self.msg)

class B(TorchDispatchMode):
def __init__(self, msg):
self.msg = msg

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA(self.msg)

with self.assertRaisesRegex(ErrorA, "layer2"):
with A("layer1"):
with B("layer2"):
with A("layer2"):
torch.empty([])

def test_make_subclass_with_modes(self):
@@ -1056,18 +1049,17 @@ def unwrap(t):
with PoliteMode():
a.abs()

def test_nesting_across_instances(self):
# If the pushed mode is a different instance from current mode, we raise
modeA = LoggingTensorMode()
def test_nesting_same_mode(self):
# If the pushed mode is the same instance as the current mode, we allow pushing an already active mode.

def foo():
with modeA, modeA:
torch.empty([])
with capture_logs(is_mode=True) as logs:
with LoggingTensorMode() as reenabled:
with reenabled:
torch.empty([])
self.assertExpectedInline('\n'.join(logs), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")

self.assertExpectedRaisesInline(
AssertionError, lambda: foo(),
"""Illegal attempt to push an already pushed mode onto the stack"""
)

def test_error_using_class_method_on_mode(self):
class A(TorchDispatchMode):
7 changes: 0 additions & 7 deletions torch/_subclasses/fake_tensor.py
@@ -635,13 +635,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
assert fake_mode is arg.fake_mode, "Mixing modes NYI"

assert fake_mode is not None

# if we've hit this instead of the mode, then a higher pri mode must
# have returned NotImplemented. Redispatching will cause an infinite
# loop but one of the other args may be a supported subclass
if hasattr(fake_mode, "tracking") and fake_mode.tracking.on_stack:
return NotImplemented

with fake_mode: # type: ignore[attr-defined]
return func(*args, **kwargs)

1 change: 0 additions & 1 deletion torch/csrc/autograd/init.cpp
@@ -29,7 +29,6 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>

#include <set>
#include <unordered_set>
43 changes: 0 additions & 43 deletions torch/csrc/autograd/python_variable.cpp
@@ -277,11 +277,6 @@ struct ConcretePyInterpreterVTable final
CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
}

void mode_state_push_trampoline(
std::shared_ptr<c10::SafePyObject> mode) const override;
void mode_state_pop_trampoline(
std::shared_ptr<c10::SafePyObject> mode) const override;

static ConcretePyInterpreterVTable* instance() {
static ConcretePyInterpreterVTable s;
return &s;
@@ -2812,42 +2807,4 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
END_HANDLE_TH_ERRORS_PYBIND
}

void ConcretePyInterpreterVTable::mode_state_push_trampoline(
const std::shared_ptr<SafePyObject> mode) const {
PyObject* mode_obj = mode->ptr(getPyInterpreter());
const char* check_mode_push_name = "check_mode_state_push";
py::gil_scoped_acquire acquire;

py::object run_function =
PyObject_FastGetAttrString(mode_obj, check_mode_push_name);
if (!run_function) {
TORCH_INTERNAL_ASSERT(0);
}

const auto ret = py::reinterpret_steal<py::object>(
PyObject_CallMethod(mode_obj, check_mode_push_name, ""));
if (ret.ptr() == nullptr) {
throw python_error();
}
}

void ConcretePyInterpreterVTable::mode_state_pop_trampoline(
const std::shared_ptr<SafePyObject> mode) const {
PyObject* mode_obj = mode->ptr(getPyInterpreter());
const char* check_mode_pop_name = "check_mode_state_pop";
py::gil_scoped_acquire acquire;

const auto run_function =
PyObject_FastGetAttrString(mode_obj, check_mode_pop_name);
if (!run_function) {
TORCH_INTERNAL_ASSERT(0);
}

const auto ret = py::reinterpret_steal<py::object>(
PyObject_CallMethod(mode_obj, check_mode_pop_name, ""));
if (ret.ptr() == nullptr) {
throw python_error();
}
}

} // anonymous namespace
48 changes: 0 additions & 48 deletions torch/csrc/utils/ConcreteModePyObjTrampoline.h

This file was deleted.

19 changes: 1 addition & 18 deletions torch/csrc/utils/torch_dispatch_mode.h
@@ -1,6 +1,5 @@
#pragma once

#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>

namespace torch {
@@ -13,23 +12,15 @@ struct StashTorchDispatchModeGuard {
}

~StashTorchDispatchModeGuard() {
// since we're in the destructor, there might be active exceptions.
// This temporarily removes them in order to update the state of the mode
// before putting it back on the stack

// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
PyObject *type, *value, *traceback;
PyErr_Fetch(&type, &value, &traceback);
c10::impl::TorchDispatchModeTLS::push_onto_stack(std::move(saved_mode_));
PyErr_Restore(type, value, traceback);
}

const std::shared_ptr<c10::SafePyObject>& get_cur_mode() {
return saved_mode_;
}

private:
std::shared_ptr<c10::SafePyObject> saved_mode_;
std::shared_ptr<at::SafePyObject> saved_mode_;
};

struct StashTorchDispatchStackGuard {
@@ -41,15 +32,7 @@ struct StashTorchDispatchStackGuard {
}

~StashTorchDispatchStackGuard() {
// since we're in the destructor, there might be active exceptions.
// This temporarily removes them in order to update the state of modes
// on the stack before putting them back

// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
PyObject *type, *value, *traceback;
PyErr_Fetch(&type, &value, &traceback);
c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_));
PyErr_Restore(type, value, traceback);
}

private:
5 changes: 1 addition & 4 deletions torch/fx/passes/fake_tensor_prop.py
@@ -7,9 +7,6 @@

__all__ = ['FakeTensorProp']

from torch.utils._python_dispatch import push_if_not_on_stack


@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
"""
@@ -36,6 +33,6 @@ def run_node(self, n: Node):
return result

def propagate(self, *args):
with push_if_not_on_stack(self._mode):
with self._mode:
fake_args = [self._mode.from_tensor(a) for a in args]
return super().run(*fake_args)
20 changes: 0 additions & 20 deletions torch/utils/_python_dispatch.py
@@ -1,8 +1,6 @@
import contextlib

import warnings
import threading

from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack

@@ -53,20 +51,6 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
_pop_mode()

def check_mode_state_push(self):
if not hasattr(self, "tracking"):
self.tracking = threading.local()
else:
if hasattr(self.tracking, "on_stack"):
assert self.tracking.on_stack is False, "Illegal attempt to push an already pushed mode onto the stack"

self.tracking.on_stack = True

def check_mode_state_pop(self):
assert hasattr(self, "tracking"), "Unexpected, popping a mode we are not tracking"
assert self.tracking.on_stack is True, "Unexpected, popping a mode that thinks its already off the stack"
self.tracking.on_stack = False

@classmethod
def push(cls, *args, **kwargs):
warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
@@ -99,10 +83,6 @@ def _pop_mode_temporarily():
_push_mode(old)


def push_if_not_on_stack(mode):
return contextlib.nullcontext() if hasattr(mode, "tracking") and mode.tracking.on_stack else mode


@contextlib.contextmanager
def _disable_current_modes():
mode_len = _len_torch_dispatch_stack()