[FFI] Propagate Python errors across FFI boundaries (#15596)
* [Runtime] Re-organize BacktraceFullCallback

Prior to this commit, the `BacktraceFullCallback` function returned
zero for frames that should be excluded from the stack trace, even if
they were the `"TVMFuncCall"` that should terminate the stack trace.

This commit re-organized `BacktraceFullCallback`, moving the
terminating checks to occur prior to the suppression checks, and
adding comments to indicate why each suppression is present.

* [FFI] Propagate Python errors across FFI boundaries

Prior to this commit, if a Python script passed a callback to a C++
function, and that callback raised an exception, the original
exception could not be caught in the outer Python script.  As a result,
interactive debugging, such as with `pdb` or `ipdb`, could only
inspect stack frames outside the first Python-to-C++ FFI call.

This commit updates the FFI API to propagate the Python exception
through an FFI boundary.  As a result, all Python frames in the stack
trace can be inspected.

* Updated unit tests that depended on exception coercion.

Previously, Python exceptions were coerced to `tvm.error.TVMError` if
no corresponding error type had been registered with
`tvm._ffi.register_error`.  With the pass-through of Python
exceptions, this coercion no longer applies.  Unit tests that relied
on this coercion needed to be updated as a result.


---------

Co-authored-by: Chris Sullivan <csullivan@octoml.ai>
Lunderberg and csullivan authored Sep 8, 2023
1 parent 666bd14 commit d5a4f66
Showing 16 changed files with 653 additions and 115 deletions.
7 changes: 7 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
@@ -244,6 +244,13 @@ typedef void* TVMObjectHandle;
  */
 TVM_DLL void TVMAPISetLastError(const char* msg);

+/*!
+ * \brief Used for implementing C API function.
+ * Set last exception before return.
+ * \param py_object The python exception to be set
+ */
+TVM_DLL void TVMAPISetLastPythonError(void* py_object);
+
 /*!
  * \brief return str message of the last error
  * all function in this file will return 0 when success
45 changes: 45 additions & 0 deletions include/tvm/runtime/registry.h
@@ -97,6 +97,51 @@ namespace runtime {
  */
 TVM_DLL void EnvCheckSignals();

+/*! \brief A class that wraps a Python object and preserves its ownership.
+ * This class is used to wrap a PyObject* from the Python API and preserve its ownership.
+ * Allows for the creation of strong references to Python objects, which prevent them from being
+ * garbage-collected as long as the wrapper object exists.
+ */
+class WrappedPythonObject {
+ public:
+  /*! \brief Construct a wrapper that doesn't own anything */
+  WrappedPythonObject() : python_obj_(nullptr) {}
+
+  /*! \brief Conversion constructor from nullptr */
+  explicit WrappedPythonObject(std::nullptr_t) : python_obj_(nullptr) {}
+
+  /*! \brief Take ownership of a python object
+   *
+   * A new strong reference is created for the underlying python
+   * object.
+   *
+   * \param python_obj A PyObject* from the Python.h API. A new
+   * strong reference is created using Py_IncRef.
+   */
+  explicit WrappedPythonObject(void* python_obj);
+
+  /*! \brief Drop ownership of a python object
+   *
+   * Removes the strong reference held by the wrapper.
+   */
+  ~WrappedPythonObject();
+
+  WrappedPythonObject(WrappedPythonObject&&);
+  WrappedPythonObject& operator=(WrappedPythonObject&&);
+
+  WrappedPythonObject(const WrappedPythonObject&);
+  WrappedPythonObject& operator=(const WrappedPythonObject&);
+  WrappedPythonObject& operator=(std::nullptr_t);
+
+  operator bool() { return python_obj_; }
+
+  void* raw_pointer() { return python_obj_; }
+
+ private:
+  void* python_obj_ = nullptr;
+};
+
 /*! \brief Registry for global function */
 class Registry {
  public:
26 changes: 18 additions & 8 deletions python/tvm/_ffi/_ctypes/packed_func.py
@@ -22,7 +22,7 @@
 import traceback
 from numbers import Number, Integral

-from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
+from ..base import _LIB, get_last_ffi_error, py2cerror, check_call, raise_last_ffi_error
 from ..base import c_str, string_types
 from ..runtime_ctypes import DataType, TVMByteArray, Device, ObjectRValueRef
 from . import ndarray as _nd
@@ -80,10 +80,11 @@ def cfun(args, type_codes, num_args, ret, _):
         # pylint: disable=broad-except
         try:
             rv = local_pyfunc(*pyargs)
-        except Exception:
+        except Exception as err:
             msg = traceback.format_exc()
             msg = py2cerror(msg)
-            _LIB.TVMAPISetLastError(c_str(msg))
+            _LIB.TVMAPISetLastPythonError(ctypes.py_object(err))
+
             return -1

         if rv is not None:
@@ -94,7 +95,7 @@ def cfun(args, type_codes, num_args, ret, _):
             if not isinstance(ret, TVMRetValueHandle):
                 ret = TVMRetValueHandle(ret)
             if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0:
-                raise get_last_ffi_error()
+                raise_last_ffi_error()
             _ = temp_args
             _ = rv
         return 0
@@ -106,7 +107,7 @@ def cfun(args, type_codes, num_args, ret, _):
     pyobj = ctypes.py_object(f)
     ctypes.pythonapi.Py_IncRef(pyobj)
     if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
-        raise get_last_ffi_error()
+        raise_last_ffi_error()
     return _make_packed_func(handle, False)


@@ -212,7 +213,7 @@ def __init__(self, handle, is_global):
     def __del__(self):
         if not self.is_global and _LIB is not None:
             if _LIB.TVMFuncFree(self.handle) != 0:
-                raise get_last_ffi_error()
+                raise_last_ffi_error()

     def __call__(self, *args):
         """Call the function with positional arguments
@@ -235,7 +236,7 @@ def __call__(self, *args):
             )
             != 0
         ):
-            raise get_last_ffi_error()
+            raise_last_ffi_error()
         _ = temp_args
         _ = args
         return RETURN_SWITCH[ret_tcode.value](ret_val)
@@ -258,7 +259,7 @@ def __init_handle_by_constructor__(fconstructor, args):
         )
         != 0
     ):
-        raise get_last_ffi_error()
+        raise_last_ffi_error()
     _ = temp_args
     _ = args
     assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE
@@ -333,3 +334,12 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
     global _FUNC_CONVERT_TO_OBJECT
     _CLASS_OBJECT_GENERIC = object_generic_class
     _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
+
+
+def _init_pythonapi_inc_def_ref():
+    register_func = _LIB.TVMBackendRegisterEnvCAPI
+    register_func(c_str("Py_IncRef"), ctypes.pythonapi.Py_IncRef)
+    register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef)
+
+
+_init_pythonapi_inc_def_ref()
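
The two functions registered here are what allow the C++ side to hold a strong reference to a Python object without linking against libpython. As a rough illustration of their effect, the sketch below calls them directly through `ctypes` and watches the reference count. It assumes CPython; the `Payload` class and the `inc_ref`/`dec_ref` names are made up for this example, and setting `argtypes`/`restype` changes the shared `ctypes.pythonapi` function objects for the whole process.

```python
import ctypes
import sys

# Declare the signatures explicitly; ctypes otherwise assumes an int return.
inc_ref = ctypes.pythonapi.Py_IncRef
dec_ref = ctypes.pythonapi.Py_DecRef
for func in (inc_ref, dec_ref):
    func.argtypes = [ctypes.py_object]
    func.restype = None


class Payload:
    """Placeholder object whose reference count we observe."""


obj = Payload()
before = sys.getrefcount(obj)

inc_ref(obj)  # take an extra strong reference, as a C++ wrapper would
held = sys.getrefcount(obj)

dec_ref(obj)  # release it again
after = sys.getrefcount(obj)

# Differences relative to the starting count: one extra while held,
# none after release.
print(held - before, after - before)
```

Every `Py_IncRef` must be paired with exactly one `Py_DecRef`; an unpaired increment leaks the object, and an unpaired decrement can free it while it is still in use.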
5 changes: 3 additions & 2 deletions python/tvm/_ffi/_cython/base.pxi
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-from ..base import get_last_ffi_error
+from ..base import raise_last_ffi_error
 from libcpp.vector cimport vector
 from cpython.version cimport PY_MAJOR_VERSION
 from cpython cimport pycapsule
@@ -113,6 +113,7 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
 # We mark the possibly long running function as nogil below.
 cdef extern from "tvm/runtime/c_runtime_api.h":
     void TVMAPISetLastError(const char* msg)
+    void TVMAPISetLastPythonError(void* py_object) except +
     const char *TVMGetLastError()
     int TVMFuncGetGlobal(const char* name,
                          TVMPackedFuncHandle* out)
@@ -178,7 +179,7 @@ cdef inline int CHECK_CALL(int ret) except -2:
     if ret == -2:
         return -2
     if ret != 0:
-        raise get_last_ffi_error()
+        raise_last_ffi_error()
     return 0


19 changes: 17 additions & 2 deletions python/tvm/_ffi/_cython/packed_func.pxi
@@ -54,10 +54,11 @@ cdef int tvm_callback(TVMValue* args,
             pyargs.append(c_make_array(value.v_handle, True, False))
     try:
         rv = local_pyfunc(*pyargs)
-    except Exception:
+    except Exception as err:
         msg = traceback.format_exc()
         msg = py2cerror(msg)
-        TVMAPISetLastError(c_str(msg))
+        TVMAPISetLastPythonError(<void*>err)
+
         return -1
     if rv is not None:
         if isinstance(rv, tuple):
@@ -368,3 +369,17 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
     global _FUNC_CONVERT_TO_OBJECT
     _CLASS_OBJECT_GENERIC = object_generic_class
     _FUNC_CONVERT_TO_OBJECT = func_convert_to_object
+
+# Py_INCREF and Py_DECREF are C macros, not function objects.
+# Therefore, providing a wrapper function.
+cdef void _py_incref_wrapper(void* py_object):
+    Py_INCREF(<object>py_object)
+cdef void _py_decref_wrapper(void* py_object):
+    Py_DECREF(<object>py_object)
+
+def _init_pythonapi_inc_def_ref():
+    register_func = TVMBackendRegisterEnvCAPI
+    register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
+    register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)
+
+_init_pythonapi_inc_def_ref()
149 changes: 146 additions & 3 deletions python/tvm/_ffi/base.py
@@ -17,10 +17,17 @@
 # coding: utf-8
 # pylint: disable=invalid-name, import-outside-toplevel
 """Base library for TVM FFI."""
-import sys
-import os
 import ctypes
+import functools
+import os
+import re
+import sys
+import types
+
+from typing import Callable, Sequence
+
 import numpy as np
+
 from . import libinfo

 # ----------------------------
@@ -333,6 +340,142 @@ def get_last_ffi_error():
     return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)


+def _append_traceback_frame(tb, func_name, filepath, lineno):
+    """Append a dummy frame to appear in the Python traceback"""
+
+    # Compile a dummy function to Python bytecode, so that with the
+    # filepath that we want to appear in the traceback. Any external
+    # debugger (e.g. pdb) that catches the exception will use the
+    # filepath to show code snippets from that FFI file.
+    code = compile(
+        "{}def dummy_func(): raise NotImplementedError()".format("\n" * (lineno - 1)),
+        filepath,
+        "exec",
+    )
+
+    # Replacing the name by updating the bytecode allows the function
+    # name to be values that would normally be forbidden by python
+    # syntax. For example, "operator()".
+    code = code.replace(co_consts=(code.co_consts[0].replace(co_name=func_name), func_name, None))
+    namespace = {}
+    exec(code, namespace) # pylint: disable=exec-used
+    dummy_func = namespace["dummy_func"]
+
+    # Execute the dummy function in order to generate a stack frame.
+    dummy_tb = None
+    try:
+        dummy_func()
+    except NotImplementedError as err:
+        dummy_tb = err.__traceback__
+
+    # Insert the dummy function into the stack trace.
+    new_frame = dummy_tb.tb_next
+    return types.TracebackType(tb, new_frame.tb_frame, new_frame.tb_lasti, new_frame.tb_lineno)
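
The frame-synthesis trick in `_append_traceback_frame` can be tried in isolation. The sketch below is a standalone variant under stated assumptions, not the code from this commit: `make_synthetic_frame` and `_stub` are hypothetical names, it assumes Python 3.8 or later (for `CodeType.replace`), and instead of indexing `co_consts` by position it renames whichever nested code object it finds there.

```python
import traceback
import types


def make_synthetic_frame(tb_next, func_name, filepath, lineno):
    """Return a traceback entry that reports (filepath, lineno, func_name)."""
    # Pad with newlines so the one-line function lands on the requested line.
    source = "\n" * (lineno - 1) + "def _stub(): raise NotImplementedError()"
    module_code = compile(source, filepath, "exec")

    # Rename the nested function's code object. co_name is the name that
    # tracebacks display, and it need not be a valid Python identifier.
    new_consts = tuple(
        const.replace(co_name=func_name) if isinstance(const, types.CodeType) else const
        for const in module_code.co_consts
    )
    module_code = module_code.replace(co_consts=new_consts)

    namespace = {}
    exec(module_code, namespace)  # pylint: disable=exec-used

    try:
        namespace["_stub"]()
    except NotImplementedError as err:
        # The first entry is this function's own frame; the next is the stub's.
        stub_tb = err.__traceback__.tb_next

    return types.TracebackType(tb_next, stub_tb.tb_frame, stub_tb.tb_lasti, stub_tb.tb_lineno)


# The file path below is invented; it does not need to exist on disk.
tb = make_synthetic_frame(None, "operator()", "src/example.cc", 42)
entry = traceback.extract_tb(tb)[0]
print(entry.filename, entry.lineno, entry.name)
```

Passing an existing traceback as `tb_next` places the synthetic entry in front of it, which is how a chain of C++ frames can be stitched onto a Python traceback one entry at a time.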


+def _filter_traceback_frames(tb, filter_funcs: Sequence[Callable[[types.CodeType], bool]]):
+    orig = tb
+    filtered_at_least_one = False
+    temp_all_frames = []
+    filtered_frames = []
+
+    while tb is not None:
+        frame_code = tb.tb_frame.f_code
+        should_remove = any(filter_func(frame_code) for filter_func in filter_funcs)
+        if not should_remove:
+            filtered_at_least_one = True
+            filtered_frames.append(tb)
+        temp_all_frames.append(tb)
+        tb = tb.tb_next
+
+    if not filtered_at_least_one:
+        return orig
+
+    def _append_frame(tb, next_tb_frame):
+        return types.TracebackType(
+            tb, next_tb_frame.tb_frame, next_tb_frame.tb_lasti, next_tb_frame.tb_lineno
+        )
+
+    new_tb = functools.reduce(_append_frame, reversed(filtered_frames))
+
+    return new_tb
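
To see what filtering a traceback chain looks like in practice, here is a small standalone sketch. It is a variant rather than the commit's code: `drop_frames` and the `inner`/`helper`/`outer` functions are hypothetical, and unlike the `functools.reduce` call above, which reuses the innermost kept entry as-is, this version rebuilds every kept entry so that each one gets a fresh `tb_next`.

```python
import traceback
import types


def drop_frames(tb, should_drop):
    """Return a copy of the traceback chain without the entries that match."""
    kept = []
    entry = tb
    while entry is not None:
        if not should_drop(entry.tb_frame.f_code):
            kept.append(entry)
        entry = entry.tb_next

    if not kept:
        # Nothing would remain, so hand back the original chain.
        return tb

    # Rebuild from the innermost entry outwards, so that every kept entry
    # points at the next kept entry and skips the dropped ones.
    new_tb = None
    for entry in reversed(kept):
        new_tb = types.TracebackType(new_tb, entry.tb_frame, entry.tb_lasti, entry.tb_lineno)
    return new_tb


def inner():
    raise ValueError("boom")


def helper():
    inner()


def outer():
    helper()


try:
    outer()
except ValueError as err:
    filtered = drop_frames(err.__traceback__, lambda code: code.co_name == "helper")
    # The "helper" entry is gone; "outer" is now followed directly by "inner".
    print([frame.name for frame in traceback.extract_tb(filtered)])
```

The filtered chain can then be attached with `err.with_traceback(filtered)` before re-raising, as `raise_last_ffi_error` does.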


+def raise_last_ffi_error():
+    """Raise the previous error from FFI
+    This should be used instead of `raise get_last_ffi_error()`, as it
+    handle propagation of errors across an FFI boundary. For example,
+    if Python passes a callback to a C++ function, and the callback
+    raises an exception, the re-thrown exception should contain the
+    full stack trace, not just the stack frames that are above the
+    outermost FFI call.
+    """
+
+    _LIB.TVMGetLastPythonError.restype = ctypes.c_void_p
+    _LIB.TVMGetLastBacktrace.restype = ctypes.c_char_p
+    py_err = _LIB.TVMGetLastPythonError()
+    if py_err is None:
+        c_err_msg = py_str(_LIB.TVMGetLastError())
+        py_err_msg, err_type = c2pyerror(c_err_msg)
+        if err_type is not None and err_type.startswith("tvm.error."):
+            err_type = err_type[10:]
+        py_err = ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
+
+    else:
+        # TVMGetLastPythonError returns a PyObject*, with NULL when
+        # there is no such value. If we annotated the restype as
+        # ctypes.py_object, we would need to return Py_None from the
+        # C++ implementation. This would require introducing a
+        # dependency on libpython that we want to avoid when not in a
+        # Python environment. Therefore, casting the resulting void*
+        # pointer to PyObject* using ctypes.
+        py_err = ctypes.cast(ctypes.c_void_p(py_err), ctypes.py_object).value
+
+    tb = py_err.__traceback__
+
+    # The py_err.__traceback__ only goes from the location thrown
+    # up to the next FFI handoff. To have the stacktrace also
+    # include the C++ side, we need to adjust the __traceback__
+    # before re-throwing.
+    backtrace = _LIB.TVMGetLastBacktrace()
+    if backtrace:
+        frames = re.split(r"\n\W+\d+:\W+", py_str(backtrace))
+        frames = frames[1:] # Skip "Stack trace: "
+
+        for frame in frames:
+            if " at " in frame:
+                func_name, frame = frame.split(" at ", 1)
+                filename, lineno = frame.rsplit(":", 1)
+                func_name = func_name.strip()
+                filename = filename.strip()
+                lineno = int(lineno.strip())
+
+                tb = _append_traceback_frame(tb, func_name, filename, lineno)
+
+    # Remove stack frames that provide little benefit to
+    # debugging. These are only removed from the stack frames
+    # contained within the exception we are re-raising, and not to
+    # the stack frames that it will continue to collect.
+    # Therefore, there may still be a single instance of these
+    # frames in the outermost Python-to-FFI call.
+    filter_funcs = [
+        lambda code: "tvm/_ffi/_ctypes/packed_func.py" in code.co_filename,
+        lambda code: "tvm/_ffi/base.py" in code.co_filename,
+    ]
+    tb = _filter_traceback_frames(tb, filter_funcs)
+
+    py_err = py_err.with_traceback(tb)
+
+    # The exception PyObject may contain a large amount of state,
+    # including all stack frames that may be inspected in a later
+    # PDB post-mortem. Therefore, we must make sure to remove the
+    # underlying PyObject* from the C++ side after we retrieve it.
+    _LIB.TVMDropLastPythonError()
+
+    raise py_err
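
The backtrace parsing inside `raise_last_ffi_error` can be exercised on its own. The sketch below reuses the same regular expression and string splits; `parse_backtrace` is a hypothetical helper, and the sample text is invented for illustration and is only assumed to resemble what `TVMGetLastBacktrace` returns.

```python
import re


def parse_backtrace(backtrace):
    """Extract (function, filename, lineno) triples from a backtrace string."""
    # Each frame begins on its own line as "<spaces><index>: ".
    pieces = re.split(r"\n\W+\d+:\W+", backtrace)
    parsed = []
    for piece in pieces[1:]:  # the first piece is the header line
        if " at " not in piece:
            continue  # no source location is available for this frame
        func_name, location = piece.split(" at ", 1)
        filename, lineno = location.rsplit(":", 1)
        parsed.append((func_name.strip(), filename.strip(), int(lineno.strip())))
    return parsed


# Hypothetical backtrace text, made up for this example.
sample = (
    "Stack trace:\n"
    "  0: example::Callback() at src/example.cc:12\n"
    "  1: operator() at src/other.cc:34\n"
    "  2: 0x00007f0000000000"
)
print(parse_backtrace(sample))
```

Frames without an ` at ` marker, such as bare addresses, carry no file or line number and are skipped, so they never become synthetic traceback entries.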


 def check_call(ret):
     """Check the return value of C API call
@@ -345,4 +488,4 @@ def check_call(ret):
         return value from API calls
     """
     if ret != 0:
-        raise get_last_ffi_error()
+        raise_last_ffi_error()