Skip to content

Commit

Permalink
fix garbage collection in inheritance cases (#4563)
Browse files Browse the repository at this point in the history
* fix garbage collection in inheritance cases

* clippy fixes

* more clippy fixups

* newsfragment

* use `get_slot` helper for reading slots

* fixup abi3 case

* fix new `to_object` deprecation warnings

* fix MSRV build
  • Loading branch information
davidhewitt committed Oct 12, 2024
1 parent 8b23397 commit 3330bf2
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 19 deletions.
1 change: 1 addition & 0 deletions newsfragments/4563.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `__traverse__` functions for base classes not being called by subclasses created with `#[pyclass(extends = ...)]`.
53 changes: 51 additions & 2 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ impl PyMethodKind {
"__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)),
"__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)),
"__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)),
// Protocols implemented through traits
"__getattribute__" => {
PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GETATTRIBUTE__))
Expand Down Expand Up @@ -146,6 +145,7 @@ impl PyMethodKind {
// Some tricky protocols which don't fit the pattern of the rest
"__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
"__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
"__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Clear),
// Not a proto
_ => PyMethodKind::Fn,
}
Expand All @@ -156,6 +156,7 @@ enum PyMethodProtoKind {
Slot(&'static SlotDef),
Call,
Traverse,
Clear,
SlotFragment(&'static SlotFragmentDef),
}

Expand Down Expand Up @@ -218,6 +219,9 @@ pub fn gen_py_method(
PyMethodProtoKind::Traverse => {
GeneratedPyMethod::Proto(impl_traverse_slot(cls, spec, ctx)?)
}
PyMethodProtoKind::Clear => {
GeneratedPyMethod::Proto(impl_clear_slot(cls, spec, ctx)?)
}
PyMethodProtoKind::SlotFragment(slot_fragment_def) => {
let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec, ctx)?;
GeneratedPyMethod::SlotTraitImpl(method.method_name, proto)
Expand Down Expand Up @@ -467,7 +471,7 @@ fn impl_traverse_slot(
visit: #pyo3_path::ffi::visitproc,
arg: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int {
#pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg)
#pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg, #cls::__pymethod_traverse__)
}
};
let slot_def = quote! {
Expand All @@ -482,6 +486,51 @@ fn impl_traverse_slot(
})
}

fn impl_clear_slot(cls: &syn::Type, spec: &FnSpec<'_>, ctx: &Ctx) -> syn::Result<MethodAndSlotDef> {
let Ctx { pyo3_path, .. } = ctx;
let (py_arg, args) = split_off_python_arg(&spec.signature.arguments);
let self_type = match &spec.tp {
FnType::Fn(self_type) => self_type,
_ => bail_spanned!(spec.name.span() => "expected instance method for `__clear__` function"),
};
let mut holders = Holders::new();
let slf = self_type.receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx);

if let [arg, ..] = args {
bail_spanned!(arg.ty().span() => "`__clear__` function expected to have no arguments");
}

let name = &spec.name;
let holders = holders.init_holders(ctx);
let fncall = if py_arg.is_some() {
quote!(#cls::#name(#slf, py))
} else {
quote!(#cls::#name(#slf))
};

let associated_method = quote! {
pub unsafe extern "C" fn __pymethod_clear__(
_slf: *mut #pyo3_path::ffi::PyObject,
) -> ::std::os::raw::c_int {
#pyo3_path::impl_::pymethods::_call_clear(_slf, |py, _slf| {
#holders
let result = #fncall;
#pyo3_path::callback::convert(py, result)
}, #cls::__pymethod_clear__)
}
};
let slot_def = quote! {
#pyo3_path::ffi::PyType_Slot {
slot: #pyo3_path::ffi::Py_tp_clear,
pfunc: #cls::__pymethod_clear__ as #pyo3_path::ffi::inquiry as _
}
};
Ok(MethodAndSlotDef {
associated_method,
slot_def,
})
}

fn impl_py_class_attribute(
cls: &syn::Type,
spec: &FnSpec<'_>,
Expand Down
133 changes: 132 additions & 1 deletion src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ use crate::exceptions::PyStopAsyncIteration;
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout};
use crate::internal::get_slot::{get_slot, TP_BASE, TP_CLEAR, TP_TRAVERSE};
use crate::pycell::impl_::PyClassBorrowChecker as _;
use crate::pycell::{PyBorrowError, PyBorrowMutError};
use crate::pyclass::boolean_struct::False;
use crate::types::any::PyAnyMethods;
#[cfg(feature = "gil-refs")]
use crate::types::{PyModule, PyType};
use crate::types::PyModule;
use crate::types::PyType;
use crate::{
ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef,
PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
Expand All @@ -20,6 +22,8 @@ use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr::null_mut;

use super::trampoline;

/// Python 3.8 and up - __ipow__ has modulo argument correctly populated.
#[cfg(Py_3_8)]
#[repr(transparent)]
Expand Down Expand Up @@ -277,6 +281,7 @@ pub unsafe fn _call_traverse<T>(
impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>,
visit: ffi::visitproc,
arg: *mut c_void,
current_traverse: ffi::traverseproc,
) -> c_int
where
T: PyClass,
Expand All @@ -291,6 +296,11 @@ where
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");
let lock = LockGIL::during_traverse();

let super_retval = call_super_traverse(slf, visit, arg, current_traverse);
if super_retval != 0 {
return super_retval;
}

// SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
// traversal is running so no mutations can occur.
let class_object: &PyClassObject<T> = &*slf.cast();
Expand Down Expand Up @@ -330,6 +340,127 @@ where
retval
}

/// Call super-type traverse method, if necessary.
///
/// Adapted from <https://github.com/cython/cython/blob/7acfb375fb54a033f021b0982a3cd40c34fb22ac/Cython/Utility/ExtensionTypes.c#L386>
///
/// TODO: There are possible optimizations over looking up the base type in this way
/// - if the base type is known in this module, can potentially look it up directly in module state
/// (when we have it)
/// - if the base type is a Python builtin, can jut call the C function directly
/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to
/// tp_alloc where we solve this at compile time
unsafe fn call_super_traverse(
obj: *mut ffi::PyObject,
visit: ffi::visitproc,
arg: *mut c_void,
current_traverse: ffi::traverseproc,
) -> c_int {
// SAFETY: in this function here it's ok to work with raw type objects `ffi::Py_TYPE`
// because the GC is running and so
// - (a) we cannot do refcounting and
// - (b) the type of the object cannot change.
let mut ty = ffi::Py_TYPE(obj);
let mut traverse: Option<ffi::traverseproc>;

// First find the current type by the current_traverse function
loop {
traverse = get_slot(ty, TP_TRAVERSE);
if traverse == Some(current_traverse) {
break;
}
ty = get_slot(ty, TP_BASE);
if ty.is_null() {
// FIXME: return an error if current type not in the MRO? Should be impossible.
return 0;
}
}

// Get first base which has a different traverse function
while traverse == Some(current_traverse) {
ty = get_slot(ty, TP_BASE);
if ty.is_null() {
break;
}
traverse = get_slot(ty, TP_TRAVERSE);
}

// If we found a type with a different traverse function, call it
if let Some(traverse) = traverse {
return traverse(obj, visit, arg);
}

// FIXME same question as cython: what if the current type is not in the MRO?
0
}

/// Calls an implementation of __clear__ for tp_clear
pub unsafe fn _call_clear(
slf: *mut ffi::PyObject,
impl_: for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult<()>,
current_clear: ffi::inquiry,
) -> c_int {
trampoline::trampoline(move |py| {
let super_retval = call_super_clear(py, slf, current_clear);
if super_retval != 0 {
return Err(PyErr::fetch(py));
}
impl_(py, slf)?;
Ok(0)
})
}

/// Call super-type traverse method, if necessary.
///
/// Adapted from <https://github.com/cython/cython/blob/7acfb375fb54a033f021b0982a3cd40c34fb22ac/Cython/Utility/ExtensionTypes.c#L386>
///
/// TODO: There are possible optimizations over looking up the base type in this way
/// - if the base type is known in this module, can potentially look it up directly in module state
/// (when we have it)
/// - if the base type is a Python builtin, can jut call the C function directly
/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to
/// tp_alloc where we solve this at compile time
unsafe fn call_super_clear(
py: Python<'_>,
obj: *mut ffi::PyObject,
current_clear: ffi::inquiry,
) -> c_int {
let mut ty = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(obj));
let mut clear: Option<ffi::inquiry>;

// First find the current type by the current_clear function
loop {
clear = ty.get_slot(TP_CLEAR);
if clear == Some(current_clear) {
break;
}
let base = ty.get_slot(TP_BASE);
if base.is_null() {
// FIXME: return an error if current type not in the MRO? Should be impossible.
return 0;
}
ty = PyType::from_borrowed_type_ptr(py, base);
}

// Get first base which has a different clear function
while clear == Some(current_clear) {
let base = ty.get_slot(TP_BASE);
if base.is_null() {
break;
}
ty = PyType::from_borrowed_type_ptr(py, base);
clear = ty.get_slot(TP_CLEAR);
}

// If we found a type with a different clear function, call it
if let Some(clear) = clear {
return clear(obj);
}

// FIXME same question as cython: what if the current type is not in the MRO?
0
}

// Autoref-based specialization for handling `__next__` returning `Option`

pub struct IterBaseTag;
Expand Down
74 changes: 61 additions & 13 deletions src/internal/get_slot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ impl Bound<'_, PyType> {
where
Slot<S>: GetSlotImpl,
{
slot.get_slot(self.as_borrowed())
// SAFETY: `self` is a valid type object.
unsafe {
slot.get_slot(
self.as_type_ptr(),
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
is_runtime_3_10(self.py()),
)
}
}
}

Expand All @@ -21,13 +28,50 @@ impl Borrowed<'_, '_, PyType> {
where
Slot<S>: GetSlotImpl,
{
slot.get_slot(self)
// SAFETY: `self` is a valid type object.
unsafe {
slot.get_slot(
self.as_type_ptr(),
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
is_runtime_3_10(self.py()),
)
}
}
}

/// Gets a slot from a raw FFI pointer.
///
/// Safety:
/// - `ty` must be a valid non-null pointer to a `PyTypeObject`.
/// - The Python runtime must be initialized
pub(crate) unsafe fn get_slot<const S: c_int>(
ty: *mut ffi::PyTypeObject,
slot: Slot<S>,
) -> <Slot<S> as GetSlotImpl>::Type
where
Slot<S>: GetSlotImpl,
{
slot.get_slot(
ty,
// SAFETY: the Python runtime is initialized
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
is_runtime_3_10(crate::Python::assume_gil_acquired()),
)
}

pub(crate) trait GetSlotImpl {
type Type;
fn get_slot(self, tp: Borrowed<'_, '_, PyType>) -> Self::Type;

/// Gets the requested slot from a type object.
///
/// Safety:
/// - `ty` must be a valid non-null pointer to a `PyTypeObject`.
/// - `is_runtime_3_10` must be `false` if the runtime is not Python 3.10 or later.
unsafe fn get_slot(
self,
ty: *mut ffi::PyTypeObject,
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))] is_runtime_3_10: bool,
) -> Self::Type;
}

#[derive(Copy, Clone)]
Expand All @@ -42,12 +86,14 @@ macro_rules! impl_slots {
type Type = $tp;

#[inline]
fn get_slot(self, tp: Borrowed<'_, '_, PyType>) -> Self::Type {
let ptr = tp.as_type_ptr();

unsafe fn get_slot(
self,
ty: *mut ffi::PyTypeObject,
#[cfg(all(Py_LIMITED_API, not(Py_3_10)))] is_runtime_3_10: bool
) -> Self::Type {
#[cfg(not(Py_LIMITED_API))]
unsafe {
(*ptr).$field
{
(*ty).$field
}

#[cfg(Py_LIMITED_API)]
Expand All @@ -59,27 +105,29 @@ macro_rules! impl_slots {
// (3.7, 3.8, 3.9) and then look in the type object anyway. This is only ok
// because we know that the interpreter is not going to change the size
// of the type objects for these historical versions.
if !is_runtime_3_10(tp.py())
&& unsafe { ffi::PyType_HasFeature(ptr, ffi::Py_TPFLAGS_HEAPTYPE) } == 0
if !is_runtime_3_10 && ffi::PyType_HasFeature(ty, ffi::Py_TPFLAGS_HEAPTYPE) == 0
{
return unsafe { (*ptr.cast::<PyTypeObject39Snapshot>()).$field };
return (*ty.cast::<PyTypeObject39Snapshot>()).$field;
}
}

// SAFETY: slot type is set carefully to be valid
unsafe { std::mem::transmute(ffi::PyType_GetSlot(ptr, ffi::$slot)) }
std::mem::transmute(ffi::PyType_GetSlot(ty, ffi::$slot))
}
}
}
)*
};
}

// Slots are implemented on-demand as needed.
// Slots are implemented on-demand as needed.)
impl_slots! {
TP_ALLOC: (Py_tp_alloc, tp_alloc) -> Option<ffi::allocfunc>,
TP_BASE: (Py_tp_base, tp_base) -> *mut ffi::PyTypeObject,
TP_CLEAR: (Py_tp_clear, tp_clear) -> Option<ffi::inquiry>,
TP_DESCR_GET: (Py_tp_descr_get, tp_descr_get) -> Option<ffi::descrgetfunc>,
TP_FREE: (Py_tp_free, tp_free) -> Option<ffi::freefunc>,
TP_TRAVERSE: (Py_tp_traverse, tp_traverse) -> Option<ffi::traverseproc>,
}

#[cfg(all(Py_LIMITED_API, not(Py_3_10)))]
Expand Down
Loading

0 comments on commit 3330bf2

Please sign in to comment.