From d5a4f66fdc7008805c50550d6cfbfac79b9e8902 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 7 Sep 2023 22:26:22 -0700 Subject: [PATCH] [FFI] Propagate Python errors across FFI boundaries (#15596) * [Runtime] Re-organize BacktraceFullCallback Prior to this commit, the `BacktraceFullCallback` function returned zero for frames that should be excluded from the stack trace, even if they were the `"TVMFuncCall"` frame that should terminate the stack trace. This commit re-organizes `BacktraceFullCallback`, moving the terminating checks to occur prior to the suppression checks, and adding comments to indicate why each suppression is present. * [FFI] Propagate Python errors across FFI boundaries Prior to this commit, if a Python script passed a callback to a C++ function, and that callback raised an exception, the original exception could not be caught in the outer Python script. As a result, interactive debugging, such as with `pdb` or `ipdb`, could only inspect stack frames outside the first Python-to-C++ FFI call. This commit updates the FFI API to propagate the Python exception through an FFI boundary. As a result, all Python frames in the stack trace can be inspected. * Updated unit tests that depended on exception coercion. Previously, Python exceptions were coerced to `tvm.error.TVMError` if no corresponding error type had been registered with `tvm._ffi.register_error`. With the pass-through of Python exceptions, this coercion no longer applies. Unit tests that relied on this coercion needed to be updated as a result. 
--------- Co-authored-by: Chris Sullivan --- include/tvm/runtime/c_runtime_api.h | 7 + include/tvm/runtime/registry.h | 45 ++++++ python/tvm/_ffi/_ctypes/packed_func.py | 26 ++- python/tvm/_ffi/_cython/base.pxi | 5 +- python/tvm/_ffi/_cython/packed_func.pxi | 19 ++- python/tvm/_ffi/base.py | 149 +++++++++++++++++- src/ir/transform.cc | 88 +++++++---- src/relay/analysis/type_solver.cc | 2 - src/runtime/c_runtime_api.cc | 94 ++++++++++- src/runtime/logging.cc | 137 +++++++++++----- src/runtime/registry.cc | 62 +++++++- src/support/ffi_testing.cc | 12 ++ tests/python/relay/test_pass_instrument.py | 16 +- tests/python/relay/test_type_infer.py | 2 +- ...chedule_schedule_rule_apply_custom_rule.py | 2 +- tests/python/unittest/test_runtime_error.py | 102 ++++++++++-- 16 files changed, 653 insertions(+), 115 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 36ae5c6b158e..43cf49948108 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -244,6 +244,13 @@ typedef void* TVMObjectHandle; */ TVM_DLL void TVMAPISetLastError(const char* msg); +/*! + * \brief Used for implementing C API function. + * Set last exception before return. + * \param py_object The python exception to be set + */ +TVM_DLL void TVMAPISetLastPythonError(void* py_object); + /*! * \brief return str message of the last error * all function in this file will return 0 when success diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 3a1e86e87f11..71ea9f4a34e6 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -97,6 +97,51 @@ namespace runtime { */ TVM_DLL void EnvCheckSignals(); +/*! \brief A class that wraps a Python object and preserves its ownership. + + * This class is used to wrap a PyObject* from the Python API and preserve its ownership. 
+ * Allows for the creation of strong references to Python objects, which prevent them from being + * garbage-collected as long as the wrapper object exists. + */ +class WrappedPythonObject { + public: + /*! \brief Construct a wrapper that doesn't own anything */ + WrappedPythonObject() : python_obj_(nullptr) {} + + /*! \brief Conversion constructor from nullptr */ + explicit WrappedPythonObject(std::nullptr_t) : python_obj_(nullptr) {} + + /*! \brief Take ownership of a python object + * + * A new strong reference is created for the underlying python + * object. + * + * \param python_obj A PyObject* from the Python.h API. A new + * strong reference is created using Py_IncRef. + */ + explicit WrappedPythonObject(void* python_obj); + + /*! \brief Drop ownership of a python object + * + * Removes the strong reference held by the wrapper. + */ + ~WrappedPythonObject(); + + WrappedPythonObject(WrappedPythonObject&&); + WrappedPythonObject& operator=(WrappedPythonObject&&); + + WrappedPythonObject(const WrappedPythonObject&); + WrappedPythonObject& operator=(const WrappedPythonObject&); + WrappedPythonObject& operator=(std::nullptr_t); + + operator bool() { return python_obj_; } + + void* raw_pointer() { return python_obj_; } + + private: + void* python_obj_ = nullptr; +}; + /*! \brief Registry for global function */ class Registry { public: diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 32ffe3d8c605..e8680afcdf98 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -22,7 +22,7 @@ import traceback from numbers import Number, Integral -from ..base import _LIB, get_last_ffi_error, py2cerror, check_call +from ..base import _LIB, get_last_ffi_error, py2cerror, check_call, raise_last_ffi_error from ..base import c_str, string_types from ..runtime_ctypes import DataType, TVMByteArray, Device, ObjectRValueRef from . 
import ndarray as _nd @@ -80,10 +80,11 @@ def cfun(args, type_codes, num_args, ret, _): # pylint: disable=broad-except try: rv = local_pyfunc(*pyargs) - except Exception: + except Exception as err: msg = traceback.format_exc() msg = py2cerror(msg) - _LIB.TVMAPISetLastError(c_str(msg)) + _LIB.TVMAPISetLastPythonError(ctypes.py_object(err)) + return -1 if rv is not None: @@ -94,7 +95,7 @@ def cfun(args, type_codes, num_args, ret, _): if not isinstance(ret, TVMRetValueHandle): ret = TVMRetValueHandle(ret) if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0: - raise get_last_ffi_error() + raise_last_ffi_error() _ = temp_args _ = rv return 0 @@ -106,7 +107,7 @@ def cfun(args, type_codes, num_args, ret, _): pyobj = ctypes.py_object(f) ctypes.pythonapi.Py_IncRef(pyobj) if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0: - raise get_last_ffi_error() + raise_last_ffi_error() return _make_packed_func(handle, False) @@ -212,7 +213,7 @@ def __init__(self, handle, is_global): def __del__(self): if not self.is_global and _LIB is not None: if _LIB.TVMFuncFree(self.handle) != 0: - raise get_last_ffi_error() + raise_last_ffi_error() def __call__(self, *args): """Call the function with positional arguments @@ -235,7 +236,7 @@ def __call__(self, *args): ) != 0 ): - raise get_last_ffi_error() + raise_last_ffi_error() _ = temp_args _ = args return RETURN_SWITCH[ret_tcode.value](ret_val) @@ -258,7 +259,7 @@ def __init_handle_by_constructor__(fconstructor, args): ) != 0 ): - raise get_last_ffi_error() + raise_last_ffi_error() _ = temp_args _ = args assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE @@ -333,3 +334,12 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object): global _FUNC_CONVERT_TO_OBJECT _CLASS_OBJECT_GENERIC = object_generic_class _FUNC_CONVERT_TO_OBJECT = func_convert_to_object + + +def _init_pythonapi_inc_def_ref(): + register_func = _LIB.TVMBackendRegisterEnvCAPI + 
register_func(c_str("Py_IncRef"), ctypes.pythonapi.Py_IncRef) + register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef) + + +_init_pythonapi_inc_def_ref() diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index c2c06674978d..69e1355f7d13 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from ..base import get_last_ffi_error +from ..base import raise_last_ffi_error from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -113,6 +113,7 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle) # We mark the possibly long running function as nogil below. cdef extern from "tvm/runtime/c_runtime_api.h": void TVMAPISetLastError(const char* msg) + void TVMAPISetLastPythonError(void* py_object) except + const char *TVMGetLastError() int TVMFuncGetGlobal(const char* name, TVMPackedFuncHandle* out) @@ -178,7 +179,7 @@ cdef inline int CHECK_CALL(int ret) except -2: if ret == -2: return -2 if ret != 0: - raise get_last_ffi_error() + raise_last_ffi_error() return 0 diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 7c9ef51bd6f8..ae528bcb7828 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -54,10 +54,11 @@ cdef int tvm_callback(TVMValue* args, pyargs.append(c_make_array(value.v_handle, True, False)) try: rv = local_pyfunc(*pyargs) - except Exception: + except Exception as err: msg = traceback.format_exc() msg = py2cerror(msg) - TVMAPISetLastError(c_str(msg)) + TVMAPISetLastPythonError(err) + return -1 if rv is not None: if isinstance(rv, tuple): @@ -368,3 +369,17 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object): global _FUNC_CONVERT_TO_OBJECT _CLASS_OBJECT_GENERIC = object_generic_class _FUNC_CONVERT_TO_OBJECT = 
func_convert_to_object + +# Py_INCREF and Py_DECREF are C macros, not function objects. +# Therefore, providing a wrapper function. +cdef void _py_incref_wrapper(void* py_object): + Py_INCREF(py_object) +cdef void _py_decref_wrapper(void* py_object): + Py_DECREF(py_object) + +def _init_pythonapi_inc_def_ref(): + register_func = TVMBackendRegisterEnvCAPI + register_func(c_str("Py_IncRef"), _py_incref_wrapper) + register_func(c_str("Py_DecRef"), _py_decref_wrapper) + +_init_pythonapi_inc_def_ref() diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 744e4c93e181..f0eddf8b3636 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -17,10 +17,17 @@ # coding: utf-8 # pylint: disable=invalid-name, import-outside-toplevel """Base library for TVM FFI.""" -import sys -import os import ctypes +import functools +import os +import re +import sys +import types + +from typing import Callable, Sequence + import numpy as np + from . import libinfo # ---------------------------- @@ -333,6 +340,142 @@ def get_last_ffi_error(): return ERROR_TYPE.get(err_type, TVMError)(py_err_msg) +def _append_traceback_frame(tb, func_name, filepath, lineno): + """Append a dummy frame to appear in the Python traceback""" + + # Compile a dummy function to Python bytecode, so that with the + # filepath that we want to appear in the traceback. Any external + # debugger (e.g. pdb) that catches the exception will use the + # filepath to show code snippets from that FFI file. + code = compile( + "{}def dummy_func(): raise NotImplementedError()".format("\n" * (lineno - 1)), + filepath, + "exec", + ) + + # Replacing the name by updating the bytecode allows the function + # name to be values that would normally be forbidden by python + # syntax. For example, "operator()". 
+ code = code.replace(co_consts=(code.co_consts[0].replace(co_name=func_name), func_name, None)) + namespace = {} + exec(code, namespace) # pylint: disable=exec-used + dummy_func = namespace["dummy_func"] + + # Execute the dummy function in order to generate a stack frame. + dummy_tb = None + try: + dummy_func() + except NotImplementedError as err: + dummy_tb = err.__traceback__ + + # Insert the dummy function into the stack trace. + new_frame = dummy_tb.tb_next + return types.TracebackType(tb, new_frame.tb_frame, new_frame.tb_lasti, new_frame.tb_lineno) + + +def _filter_traceback_frames(tb, filter_funcs: Sequence[Callable[[types.CodeType], bool]]): + orig = tb + filtered_at_least_one = False + temp_all_frames = [] + filtered_frames = [] + + while tb is not None: + frame_code = tb.tb_frame.f_code + should_remove = any(filter_func(frame_code) for filter_func in filter_funcs) + if not should_remove: + filtered_at_least_one = True + filtered_frames.append(tb) + temp_all_frames.append(tb) + tb = tb.tb_next + + if not filtered_at_least_one: + return orig + + def _append_frame(tb, next_tb_frame): + return types.TracebackType( + tb, next_tb_frame.tb_frame, next_tb_frame.tb_lasti, next_tb_frame.tb_lineno + ) + + new_tb = functools.reduce(_append_frame, reversed(filtered_frames)) + + return new_tb + + +def raise_last_ffi_error(): + """Raise the previous error from FFI + + This should be used instead of `raise get_last_ffi_error()`, as it + handle propagation of errors across an FFI boundary. For example, + if Python passes a callback to a C++ function, and the callback + raises an exception, the re-thrown exception should contain the + full stack trace, not just the stack frames that are above the + outermost FFI call. 
+ """ + + _LIB.TVMGetLastPythonError.restype = ctypes.c_void_p + _LIB.TVMGetLastBacktrace.restype = ctypes.c_char_p + py_err = _LIB.TVMGetLastPythonError() + if py_err is None: + c_err_msg = py_str(_LIB.TVMGetLastError()) + py_err_msg, err_type = c2pyerror(c_err_msg) + if err_type is not None and err_type.startswith("tvm.error."): + err_type = err_type[10:] + py_err = ERROR_TYPE.get(err_type, TVMError)(py_err_msg) + + else: + # TVMGetLastPythonError returns a PyObject*, with NULL when + # there is no such value. If we annotated the restype as + # ctypes.py_object, we would need to return Py_None from the + # C++ implementation. This would require introducing a + # dependency on libpython that we want to avoid when not in a + # Python environment. Therefore, casting the resulting void* + # pointer to PyObject* using ctypes. + py_err = ctypes.cast(ctypes.c_void_p(py_err), ctypes.py_object).value + + tb = py_err.__traceback__ + + # The py_err.__traceback__ only goes from the location thrown + # up to the next FFI handoff. To have the stacktrace also + # include the C++ side, we need to adjust the __traceback__ + # before re-throwing. + backtrace = _LIB.TVMGetLastBacktrace() + if backtrace: + frames = re.split(r"\n\W+\d+:\W+", py_str(backtrace)) + frames = frames[1:] # Skip "Stack trace: " + + for frame in frames: + if " at " in frame: + func_name, frame = frame.split(" at ", 1) + filename, lineno = frame.rsplit(":", 1) + func_name = func_name.strip() + filename = filename.strip() + lineno = int(lineno.strip()) + + tb = _append_traceback_frame(tb, func_name, filename, lineno) + + # Remove stack frames that provide little benefit to + # debugging. These are only removed from the stack frames + # contained within the exception we are re-raising, and not to + # the stack frames that it will continue to collect. + # Therefore, there may still be a single instance of these + # frames in the outermost Python-to-FFI call. 
+ filter_funcs = [ + lambda code: "tvm/_ffi/_ctypes/packed_func.py" in code.co_filename, + lambda code: "tvm/_ffi/base.py" in code.co_filename, + ] + tb = _filter_traceback_frames(tb, filter_funcs) + + py_err = py_err.with_traceback(tb) + + # The exception PyObject may contain a large amount of state, + # including all stack frames that may be inspected in a later + # PDB post-mortem. Therefore, we must make sure to remove the + # underlying PyObject* from the C++ side after we retrieve it. + _LIB.TVMDropLastPythonError() + + raise py_err + + def check_call(ret): """Check the return value of C API call @@ -345,4 +488,4 @@ def check_call(ret): return value from API calls """ if ret != 0: - raise get_last_ffi_error() + raise_last_ffi_error() diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 66b06e6b505d..9f98977790bb 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -182,44 +182,78 @@ Map> PassContext::ListConfigs() { PassContext PassContext::Create() { return PassContext(make_object()); } +namespace { +struct ClearOnError { + Array* instruments{nullptr}; + + ~ClearOnError() { + if (instruments) { + LOG(INFO) << "Pass instrumentation enter/exti failed."; + LOG(INFO) << "Disabling pass instrumentation."; + instruments->clear(); + } + } +}; +struct ExitContextOnError { + std::vector successes; + + ~ExitContextOnError() { + for (auto it = successes.rbegin(); it != successes.rend(); it++) { + LOG(INFO) << (*it)->name << " exiting PassContext ..."; + (*it)->ExitPassContext(); + LOG(INFO) << (*it)->name << " exited PassContext."; + } + } +}; +} // namespace + void PassContext::InstrumentEnterPassContext() { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { - Array enter_successes; - try { - for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - pi->EnterPassContext(); - enter_successes.push_back(pi); - } - } catch (const Error& e) { - LOG(INFO) << "Pass instrumentation entering pass context failed."; - 
LOG(INFO) << "Disable pass instrumentation."; - pass_ctx_node->instruments.clear(); - - for (instrument::PassInstrument pi : enter_successes) { - LOG(INFO) << pi->name << " exiting PassContext ..."; - pi->ExitPassContext(); - LOG(INFO) << pi->name << " exited PassContext."; - } - enter_successes.clear(); - - throw e; + ClearOnError clear_context{&pass_ctx_node->instruments}; + ExitContextOnError exit_context; + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->EnterPassContext(); + exit_context.successes.push_back(pi); } + exit_context.successes.clear(); + clear_context.instruments = nullptr; } } +namespace { + +struct ExitPassSuccesses { + ~ExitPassSuccesses() { + if (all_initialized) { + return; + } + + LOG(INFO) << "Pass instrumentation entering pass context failed."; + LOG(INFO) << "Disable pass instrumentation."; + instruments->clear(); + + for (auto it = successes.rbegin(); it != successes.rend(); it++) { + LOG(INFO) << (*it)->name << " exiting PassContext ..."; + (*it)->ExitPassContext(); + LOG(INFO) << (*it)->name << " exited PassContext."; + } + } + + bool all_initialized{false}; + std::vector successes; + Array* instruments{nullptr}; +}; +} // namespace + void PassContext::InstrumentExitPassContext() { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->instruments.defined()) { - try { - for (instrument::PassInstrument pi : pass_ctx_node->instruments) { - pi->ExitPassContext(); - } - } catch (const Error& e) { - LOG(INFO) << "Pass instrumentation exiting pass context failed."; - pass_ctx_node->instruments.clear(); - throw e; + ClearOnError clear_context{&pass_ctx_node->instruments}; + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->ExitPassContext(); } + clear_context.instruments = nullptr; } } diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 79b340390ba1..5bd5698d8321 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ 
-639,8 +639,6 @@ bool TypeSolver::Solve() { } catch (const CompileError& err) { this->Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; - } catch (const Error& e) { - ICHECK(false) << e.what(); } // Mark inqueue as false after the function call diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 93ca8a924a98..980447214a67 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -35,6 +35,8 @@ #include #include #include +#include +#include #include "object_internal.h" #include "runtime_base.h" @@ -368,22 +370,102 @@ std::string NormalizeError(std::string err_msg) { using namespace tvm::runtime; +struct WrappedPythonError : Error { + WrappedPythonError() : Error("") {} + explicit WrappedPythonError(WrappedPythonObject obj) + : Error(""), obj(std::move(obj)), cpp_backtrace(tvm::runtime::Backtrace()) {} + + WrappedPythonObject obj; + std::string cpp_backtrace; +}; + struct TVMRuntimeEntry { std::string ret_str; - std::string last_error; TVMByteArray ret_bytes; + + std::variant last_error; + std::string last_error_formatted; }; typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; -const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); } +const char* TVMGetLastError() { + auto* store = TVMAPIRuntimeStore::Get(); + const auto& last_error = store->last_error; + if (const auto* message = std::get_if(&last_error)) { + return message->c_str(); + } else if (const auto* internal = std::get_if(&last_error)) { + // Use last_error_formatted to store the formatted error message, to avoid + // dangling pointer. 
+ store->last_error_formatted = NormalizeError(internal->full_message()); + return store->last_error_formatted.c_str(); + } else { + return nullptr; + } +} + +extern "C" void* TVMGetLastPythonError() { + auto& last_error = TVMAPIRuntimeStore::Get()->last_error; + if (auto* wrapped = std::get_if(&last_error)) { + return wrapped->obj.raw_pointer(); + } else { + return nullptr; + } +} + +extern "C" const char* TVMGetLastBacktrace() { + const auto& last_error = TVMAPIRuntimeStore::Get()->last_error; + if (const auto* wrapped = std::get_if(&last_error)) { + return wrapped->cpp_backtrace.data(); + } else if (const auto* wrapped = std::get_if(&last_error)) { + return wrapped->backtrace().data(); + } else { + return nullptr; + } +} + +extern "C" void TVMDropLastPythonError() { + auto& last_error = TVMAPIRuntimeStore::Get()->last_error; + if (std::get_if(&last_error)) { + last_error = ""; + } +} int TVMAPIHandleException(const std::exception& e) { - TVMAPISetLastError(NormalizeError(e.what()).c_str()); + auto& last_error = TVMAPIRuntimeStore::Get()->last_error; + + if (const auto* wrapped = dynamic_cast(&e)) { + last_error = *wrapped; + } else if (const auto* internal = dynamic_cast(&e)) { + last_error = *internal; + } else { + last_error = NormalizeError(e.what()); + } return -1; } -void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; } +extern "C" void TVMAPISetLastPythonError(void* obj) { + auto& last_error = TVMAPIRuntimeStore::Get()->last_error; + last_error = WrappedPythonError(WrappedPythonObject(obj)); +} + +void ThrowLastError() { + auto& last_error = TVMAPIRuntimeStore::Get()->last_error; + if (auto* wrapped = std::get_if(&last_error)) { + WrappedPythonError wrapped_err; + std::swap(wrapped_err, *wrapped); + throw wrapped_err; + } else if (auto* internal = std::get_if(&last_error)) { + throw *internal; + } else if (auto* message = std::get_if(&last_error)) { + throw tvm::Error(NormalizeError(*message) + 
tvm::runtime::Backtrace()); + } +} + +void TVMAPISetLastError(const char* msg) { + auto& last_error = TVMAPIRuntimeStore::Get()->last_error; + last_error = msg; +} int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { API_BEGIN(); @@ -515,7 +597,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, resource_handle); if (ret != 0) { - throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); + ThrowLastError(); } }); TVMValue val; @@ -531,7 +613,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, rpack.get()); if (ret != 0) { - throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); + ThrowLastError(); } }); TVMValue val; diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc index 04b25f764c8a..844a8bcf1cc2 100644 --- a/src/runtime/logging.cc +++ b/src/runtime/logging.cc @@ -94,64 +94,115 @@ void BacktraceSyminfoCallback(void* data, uintptr_t pc, const char* symname, uin int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int lineno, const char* symbol) { - if (filename != nullptr) { - if (strstr(filename, "include/tvm/runtime/packed_func.h") != nullptr || - strstr(filename, "include/tvm/runtime/registry.h") != nullptr || - strstr(filename, "include/tvm/node/functor.h") != nullptr || - strstr(filename, "include/tvm/relax/expr_functor.h") != nullptr || - strstr(filename, "include/tvm/tir/stmt_functor.h") != nullptr || - strstr(filename, "include/tvm/tir/expr_functor.h") != nullptr || - strstr(filename, "include/tvm/node/reflection.h") != nullptr || - strstr(filename, "src/node/structural_equal.cc") != nullptr || - strstr(filename, "src/ir/transform.cc") != nullptr || - strstr(filename, "src/tir/ir/stmt_functor.cc") != nullptr || - strstr(filename, 
"src/tir/ir/expr_functor.cc") != nullptr || - strstr(filename, "src/relax/ir/expr_functor.cc") != nullptr || - strstr(filename, "src/relax/ir/py_expr_functor.cc") != nullptr || - strstr(filename, "src/runtime/c_runtime_api.cc") != nullptr || - strstr(filename, "/python-") != nullptr || // - strstr(filename, "include/c++/") != nullptr) { - return 0; - } - } - if (symbol != nullptr) { - if (strstr(symbol, "__libc_") != nullptr) { - return 0; - } - } auto stack_trace = reinterpret_cast(data); - std::stringstream s; std::unique_ptr symbol_str = std::make_unique(""); - if (symbol != nullptr) { + if (symbol) { *symbol_str = DemangleName(symbol); } else { // see if syminfo gives anything backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, BacktraceErrorCallback, symbol_str.get()); } - if (filename == nullptr && strstr(symbol_str.get()->data(), "ffi_call_")) { + symbol = symbol_str->data(); + + // TVMFuncCall denotes the API boundary so we stop there. Exceptions + // should be caught there. This is before any frame suppressions, + // as it would otherwise be suppressed. + bool should_stop_collecting = + (*symbol_str == "TVMFuncCall" || stack_trace->lines.size() >= stack_trace->max_size); + if (should_stop_collecting) { + return 1; + } + + // Exclude frames that contain little useful information for most + // debugging purposes + bool should_exclude = [&]() -> bool { + if (filename) { + // Stack frames for TVM FFI + if (strstr(filename, "include/tvm/runtime/packed_func.h") || + strstr(filename, "include/tvm/runtime/registry.h") || + strstr(filename, "src/runtime/c_runtime_api.cc")) { + return true; + } + // Stack frames for nested tree recursion. + // tir/ir/stmt_functor.cc and tir/ir/expr_functor.cc define + // Expr/Stmt Visitor/Mutator, which should be suppressed, but + // also Substitute which should not be suppressed. Therefore, + // they are suppressed based on the symbol name. 
+ if (strstr(filename, "include/tvm/node/functor.h") || // + strstr(filename, "include/tvm/relax/expr_functor.h") || // + strstr(filename, "include/tvm/tir/stmt_functor.h") || // + strstr(filename, "include/tvm/tir/expr_functor.h") || // + strstr(filename, "include/tvm/node/reflection.h") || // + strstr(filename, "src/node/structural_equal.cc") || // + strstr(filename, "src/ir/transform.cc") || // + strstr(filename, "src/relax/ir/expr_functor.cc") || // + strstr(filename, "src/relax/ir/py_expr_functor.cc")) { + return true; + } + // Python interpreter stack frames + if (strstr(filename, "/python-") || strstr(filename, "/Python/ceval.c") || + strstr(filename, "/Modules/_ctypes")) { + return true; + } + // C++ stdlib frames + if (strstr(filename, "include/c++/")) { + return true; + } + } + if (symbol) { + // C++ stdlib frames + if (strstr(symbol, "__libc_")) { + return true; + } + // Stack frames for nested tree visiting + if (strstr(symbol, "tvm::tir::StmtMutator::VisitStmt_") || + strstr(symbol, "tvm::tir::ExprMutator::VisitExpr_") || + strstr(symbol, "tvm::tir::IRTransformer::VisitExpr") || + strstr(symbol, "tvm::tir::IRTransformer::VisitStmt") || + strstr(symbol, "tvm::tir::IRTransformer::BaseVisitExpr") || + strstr(symbol, "tvm::tir::IRTransformer::BaseVisitStmt")) { + return true; + } + // Python interpreter stack frames + if (strstr(symbol, "_Py") == symbol || strstr(symbol, "PyObject")) { + return true; + } + } + + // libffi.so stack frames. These may also show up as numeric + // addresses with no symbol name. This could be improved in the + // future by using dladdr() to check whether an address is contained + // in libffi.so + if (filename == nullptr && strstr(symbol, "ffi_call_")) { + return true; + } + + // Skip tvm::backtrace and tvm::LogFatal::~LogFatal at the beginning + // of the trace as they don't add anything useful to the backtrace. 
+ if (stack_trace->lines.size() == 0 && (strstr(symbol, "tvm::runtime::Backtrace") || + strstr(symbol, "tvm::runtime::detail::LogFatal"))) { + return true; + } + + return false; + }(); + if (should_exclude) { return 0; } - s << *symbol_str; - if (filename != nullptr) { - s << std::endl << " at " << filename; + std::stringstream frame_str; + frame_str << *symbol_str; + + if (filename) { + frame_str << std::endl << " at " << filename; if (lineno != 0) { - s << ":" << lineno; + frame_str << ":" << lineno; } } - // Skip tvm::backtrace and tvm::LogFatal::~LogFatal at the beginning of the trace as they don't - // add anything useful to the backtrace. - if (!(stack_trace->lines.size() == 0 && - (symbol_str->find("tvm::runtime::Backtrace", 0) == 0 || - symbol_str->find("tvm::runtime::detail::LogFatal", 0) == 0))) { - stack_trace->lines.push_back(s.str()); - } - // TVMFuncCall denotes the API boundary so we stop there. Exceptions should be caught there. - if (*symbol_str == "TVMFuncCall" || stack_trace->lines.size() >= stack_trace->max_size) { - return 1; - } + stack_trace->lines.push_back(frame_str.str()); + return 0; } diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 84586ff630d6..0db8786145d3 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -128,13 +128,26 @@ class EnvCAPIRegistry { */ typedef int (*F_PyErr_CheckSignals)(); - // NOTE: the following function are only registered - // in a python environment. + /*! \brief Callback to increment/decrement the python ref count */ + typedef void (*F_Py_IncDefRef)(void*); + + // NOTE: the following functions are only registered in a python + // environment. /*! * \brief PyErr_CheckSignal function */ F_PyErr_CheckSignals pyerr_check_signals = nullptr; + /*! + * \brief Py_IncRef function + */ + F_Py_IncDefRef py_inc_ref = nullptr; + + /*! 
+ * \brief Py_IncRef function + */ + F_Py_IncDefRef py_dec_ref = nullptr; + static EnvCAPIRegistry* Global() { static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); return inst; @@ -144,6 +157,10 @@ class EnvCAPIRegistry { void Register(const String& symbol_name, void* fptr) { if (symbol_name == "PyErr_CheckSignals") { Update(symbol_name, &pyerr_check_signals, fptr); + } else if (symbol_name == "Py_IncRef") { + Update(symbol_name, &py_inc_ref, fptr); + } else if (symbol_name == "Py_DecRef") { + Update(symbol_name, &py_dec_ref, fptr); } else { LOG(FATAL) << "Unknown env API " << symbol_name; } @@ -159,6 +176,18 @@ class EnvCAPIRegistry { } } + void IncRef(void* python_obj) { + ICHECK(py_inc_ref) << "Attempted to call Py_IncRef through EnvCAPIRegistry, " + << "but Py_IncRef wasn't registered"; + (*py_inc_ref)(python_obj); + } + + void DecRef(void* python_obj) { + ICHECK(py_inc_ref) << "Attempted to call Py_IncRef through EnvCAPIRegistry, " + << "but Py_IncRef wasn't registered"; + (*py_inc_ref)(python_obj); + } + private: // update the internal API table template @@ -173,6 +202,35 @@ class EnvCAPIRegistry { void EnvCheckSignals() { EnvCAPIRegistry::Global()->CheckSignals(); } +WrappedPythonObject::WrappedPythonObject(void* python_obj) : python_obj_(python_obj) { + if (python_obj_) { + EnvCAPIRegistry::Global()->IncRef(python_obj_); + } +} + +WrappedPythonObject::~WrappedPythonObject() { + if (python_obj_) { + EnvCAPIRegistry::Global()->DecRef(python_obj_); + } +} + +WrappedPythonObject::WrappedPythonObject(WrappedPythonObject&& other) : python_obj_(nullptr) { + std::swap(python_obj_, other.python_obj_); +} +WrappedPythonObject& WrappedPythonObject::operator=(WrappedPythonObject&& other) { + std::swap(python_obj_, other.python_obj_); + return *this; +} + +WrappedPythonObject::WrappedPythonObject(const WrappedPythonObject& other) + : WrappedPythonObject(other.python_obj_) {} +WrappedPythonObject& WrappedPythonObject::operator=(const WrappedPythonObject& other) { + 
return *this = WrappedPythonObject(other); +} +WrappedPythonObject& WrappedPythonObject::operator=(std::nullptr_t) { + return *this = WrappedPythonObject(nullptr); +} + } // namespace runtime } // namespace tvm diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 6e7dec4cb776..e00b9b8d056a 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -63,6 +63,18 @@ TVM_REGISTER_GLOBAL("testing.test_wrap_callback").set_body([](TVMArgs args, TVMR *ret = runtime::TypedPackedFunc([pf]() { pf(); }); }); +TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") + .set_body([](TVMArgs args, TVMRetValue* ret) { + PackedFunc pf = args[0]; + auto result = runtime::TypedPackedFunc([pf]() { + try { + pf(); + } catch (std::exception& err) { + } + }); + *ret = result; + }); + TVM_REGISTER_GLOBAL("testing.test_raise_error_callback") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string msg = args[0]; diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 83ddc4ef3731..455cf20b5de0 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -226,7 +226,7 @@ def enter_pass_ctx(self): raise RuntimeError("Just a dummy error") pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: with pass_ctx: pass assert "Just a dummy error" in str(cm.execption) @@ -246,7 +246,7 @@ def enter_pass_ctx(self): raise RuntimeError("Just a dummy error") cur_pass_ctx = tvm.transform.PassContext.current() - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: cur_pass_ctx.override_instruments([PIBroken()]) assert "Just a dummy error" in str(cm.exception) assert not cur_pass_ctx.instruments @@ -273,7 +273,7 @@ def exit_pass_ctx(self): raise RuntimeError("Just a dummy error") pass_ctx = 
tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")]) - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: with pass_ctx: pass assert "Just a dummy error" in str(cm.exception) @@ -293,7 +293,7 @@ def exit_pass_ctx(self): raise RuntimeError("Just a dummy error") cur_pass_ctx = tvm.transform.PassContext.current() - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: cur_pass_ctx.override_instruments([PIBroken()]) cur_pass_ctx.override_instruments([PIBroken()]) assert "Just a dummy error" in str(cm.exception) @@ -328,7 +328,7 @@ def transform(mod, ctx): return mod mod = get_test_model() - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: with tvm.transform.PassContext(instruments=[PI()]): mod = transform(mod) assert "Just a dummy error" in str(cm.exception) @@ -373,7 +373,7 @@ def transform(mod, ctx): return mod mod = get_test_model() - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): mod = transform(mod) assert "Just a dummy error" in str(cm.exception) @@ -418,7 +418,7 @@ def transform(mod, ctx): return mod mod = get_test_model() - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): mod = transform(mod) assert "Just a dummy error" in str(cm.exception) @@ -467,7 +467,7 @@ def transform(mod, ctx): x, y = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xy"] mod = tvm.IRModule.from_expr(tvm.relay.add(x, y)) - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(RuntimeError) as cm: with tvm.transform.PassContext(instruments=[PI("%1"), PI("%2")]): mod = transform(mod) assert "Just a dummy error" in str(cm.exception) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 
7fbb656b367a..ec88143db6a6 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -541,7 +541,7 @@ def clog(x): t2 = sb.let("t2", relay.add(t1, x)) sb.ret(t2) f = relay.Function([x], sb.get()) - with pytest.raises(tvm.error.TVMError) as cm: + with pytest.raises(AssertionError) as cm: fchecked = infer_expr(f) assert "type relation arg number mismatch" in str(cm.execption) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py index 2bfa3070d1b4..7222c4d64972 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py @@ -59,7 +59,7 @@ def test_custom_rule(): max_trials_global=10, space=space_gen, ) - assert "ValueError: Intended for meta_schedule.cpu.test_apply_custom_rule" in str(e_info.value) + assert "Intended for meta_schedule.cpu.test_apply_custom_rule" in str(e_info.value) if __name__ == "__main__": diff --git a/tests/python/unittest/test_runtime_error.py b/tests/python/unittest/test_runtime_error.py index 3d7a2180994b..efb373ac87ed 100644 --- a/tests/python/unittest/test_runtime_error.py +++ b/tests/python/unittest/test_runtime_error.py @@ -15,12 +15,19 @@ # specific language governing permissions and limitations # under the License. 
"""Test runtime error handling""" + +import functools +import platform +import subprocess +import traceback + +import pytest + import tvm -from tvm import te import tvm.testing -def test_op_translation(): +def test_op_translation_to_not_implemented(): ferror = tvm.testing.test_raise_error_callback("OpNotImplemented: myop") try: ferror() @@ -30,6 +37,8 @@ def test_op_translation(): assert isinstance(e, NotImplementedError) assert msg.find("ffi_testing.cc") != -1 + +def test_op_translation_to_internal_error(): fchk_eq = tvm.testing.test_check_eq_callback("InternalError: myop") try: fchk_eq(0, 1) @@ -38,6 +47,8 @@ def test_op_translation(): msg = str(e) assert msg.find("ffi_testing.cc") != -1 + +def test_op_translation_to_value_error(): try: tvm.testing.ErrorTest(0, 1) assert False @@ -47,6 +58,18 @@ def test_op_translation(): def test_deep_callback(): + """Propagate python errors through API calls + + If a Python exception is raised, and that exception is caught in + Python, the original exception should be propagated so that the + traceback contains all intermediate python frames. 
+ + Stack + - test_deep_callback + - test + + """ + def error_callback(): raise ValueError("callback error") @@ -65,14 +88,73 @@ def flevel3(): try: wrap3() assert False - except ValueError as e: - msg = str(e) - idx2 = msg.find("in flevel2") - idx3 = msg.find("in flevel3") - assert idx2 != -1 and idx3 != -1 - assert idx2 > idx3 + except ValueError as err: + frames = traceback.extract_tb(err.__traceback__) + + local_frames = [frame.name for frame in frames if frame.filename == __file__] + assert local_frames == ["test_deep_callback", "flevel3", "flevel2", "error_callback"] + + +@functools.lru_cache() +def _has_debug_symbols(): + lib = tvm._ffi.base._LIB + headers = subprocess.check_output(["objdump", "--section-headers", lib._name], encoding="utf-8") + return ".debug" in headers + + +@pytest.mark.skipif( + not _has_debug_symbols() or platform.machine() != "x86_64", + reason="C++ stack frames require debug symbols, only implemented for x86", +) +def test_cpp_frames_in_stack_trace_from_python_error(): + """A python exception crossing C++ boundaries should have C++ stack frames""" + + def error_callback(): + raise ValueError("callback error") + + wrapped = tvm.testing.test_wrap_callback(error_callback) + + try: + wrapped() + assert False + except ValueError as err: + frames = traceback.extract_tb(err.__traceback__) + + cpp_frames = [ + frame + for frame in frames + if frame.filename.endswith(".cc") or frame.filename.endswith(".c") + ] + assert len(cpp_frames) >= 1, ( + f"Traceback through files '{[frame.filename for frame in frames]}'" + f" expected to contain C/C++ frames, " + f" but instead caught exception {err}" + ) + + +@pytest.mark.skipif( + not _has_debug_symbols() or platform.machine() != "x86_64", + reason="C++ stack frames require debug symbols, only implemented for x86", +) +def test_stack_trace_from_cpp_error(): + """A python exception originating in C++ should have C++ stack frames""" + try: + tvm.testing.ErrorTest(0, 1) + assert False + except ValueError as 
err: + frames = traceback.extract_tb(err.__traceback__) + + cpp_frames = [ + frame + for frame in frames + if frame.filename.endswith(".cc") or frame.filename.endswith(".c") + ] + assert len(cpp_frames) >= 1, ( + f"Traceback through files '{[frame.filename for frame in frames]}'" + f" expected to contain C/C++ frames, " + f" but instead caught exception {err}" + ) if __name__ == "__main__": - test_op_translation() - test_deep_callback() + tvm.testing.main()