From 3330bf2641144e2dc4d2730470af4d8623001c60 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Fri, 11 Oct 2024 13:36:28 +0100 Subject: [PATCH] fix garbage collection in inheritance cases (#4563) * fix garbage collection in inheritance cases * clippy fixes * more clippy fixups * newsfragment * use `get_slot` helper for reading slots * fixup abi3 case * fix new `to_object` deprecation warnings * fix MSRV build --- newsfragments/4563.fixed.md | 1 + pyo3-macros-backend/src/pymethod.rs | 53 ++++++++++- src/impl_/pymethods.rs | 133 +++++++++++++++++++++++++++- src/internal/get_slot.rs | 74 +++++++++++++--- src/pyclass/create_type_object.rs | 27 +++++- tests/test_gc.rs | 126 ++++++++++++++++++++++++++ 6 files changed, 395 insertions(+), 19 deletions(-) create mode 100644 newsfragments/4563.fixed.md diff --git a/newsfragments/4563.fixed.md b/newsfragments/4563.fixed.md new file mode 100644 index 00000000000..c0249a81a8b --- /dev/null +++ b/newsfragments/4563.fixed.md @@ -0,0 +1 @@ +Fix `__traverse__` functions for base classes not being called by subclasses created with `#[pyclass(extends = ...)]`. 
diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 2672c6a07ac..cd4fe05562d 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -97,7 +97,6 @@ impl PyMethodKind { "__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)), "__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)), "__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)), - "__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)), // Protocols implemented through traits "__getattribute__" => { PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GETATTRIBUTE__)) @@ -146,6 +145,7 @@ impl PyMethodKind { // Some tricky protocols which don't fit the pattern of the rest "__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call), "__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse), + "__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Clear), // Not a proto _ => PyMethodKind::Fn, } @@ -156,6 +156,7 @@ enum PyMethodProtoKind { Slot(&'static SlotDef), Call, Traverse, + Clear, SlotFragment(&'static SlotFragmentDef), } @@ -218,6 +219,9 @@ pub fn gen_py_method( PyMethodProtoKind::Traverse => { GeneratedPyMethod::Proto(impl_traverse_slot(cls, spec, ctx)?) } + PyMethodProtoKind::Clear => { + GeneratedPyMethod::Proto(impl_clear_slot(cls, spec, ctx)?) + } PyMethodProtoKind::SlotFragment(slot_fragment_def) => { let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec, ctx)?; GeneratedPyMethod::SlotTraitImpl(method.method_name, proto) @@ -467,7 +471,7 @@ fn impl_traverse_slot( visit: #pyo3_path::ffi::visitproc, arg: *mut ::std::os::raw::c_void, ) -> ::std::os::raw::c_int { - #pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg) + #pyo3_path::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg, #cls::__pymethod_traverse__) } }; let slot_def = quote! 
{ @@ -482,6 +486,51 @@ fn impl_traverse_slot( }) } +fn impl_clear_slot(cls: &syn::Type, spec: &FnSpec<'_>, ctx: &Ctx) -> syn::Result { + let Ctx { pyo3_path, .. } = ctx; + let (py_arg, args) = split_off_python_arg(&spec.signature.arguments); + let self_type = match &spec.tp { + FnType::Fn(self_type) => self_type, + _ => bail_spanned!(spec.name.span() => "expected instance method for `__clear__` function"), + }; + let mut holders = Holders::new(); + let slf = self_type.receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx); + + if let [arg, ..] = args { + bail_spanned!(arg.ty().span() => "`__clear__` function expected to have no arguments"); + } + + let name = &spec.name; + let holders = holders.init_holders(ctx); + let fncall = if py_arg.is_some() { + quote!(#cls::#name(#slf, py)) + } else { + quote!(#cls::#name(#slf)) + }; + + let associated_method = quote! { + pub unsafe extern "C" fn __pymethod_clear__( + _slf: *mut #pyo3_path::ffi::PyObject, + ) -> ::std::os::raw::c_int { + #pyo3_path::impl_::pymethods::_call_clear(_slf, |py, _slf| { + #holders + let result = #fncall; + #pyo3_path::callback::convert(py, result) + }, #cls::__pymethod_clear__) + } + }; + let slot_def = quote! 
{ + #pyo3_path::ffi::PyType_Slot { + slot: #pyo3_path::ffi::Py_tp_clear, + pfunc: #cls::__pymethod_clear__ as #pyo3_path::ffi::inquiry as _ + } + }; + Ok(MethodAndSlotDef { + associated_method, + slot_def, + }) +} + fn impl_py_class_attribute( cls: &syn::Type, spec: &FnSpec<'_>, diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index 76b71a3e188..77636051a04 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -3,12 +3,14 @@ use crate::exceptions::PyStopAsyncIteration; use crate::gil::LockGIL; use crate::impl_::panic::PanicTrap; use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout}; +use crate::internal::get_slot::{get_slot, TP_BASE, TP_CLEAR, TP_TRAVERSE}; use crate::pycell::impl_::PyClassBorrowChecker as _; use crate::pycell::{PyBorrowError, PyBorrowMutError}; use crate::pyclass::boolean_struct::False; use crate::types::any::PyAnyMethods; #[cfg(feature = "gil-refs")] -use crate::types::{PyModule, PyType}; +use crate::types::PyModule; +use crate::types::PyType; use crate::{ ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef, PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python, @@ -20,6 +22,8 @@ use std::os::raw::{c_int, c_void}; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::ptr::null_mut; +use super::trampoline; + /// Python 3.8 and up - __ipow__ has modulo argument correctly populated. 
#[cfg(Py_3_8)] #[repr(transparent)] @@ -277,6 +281,7 @@ pub unsafe fn _call_traverse( impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>, visit: ffi::visitproc, arg: *mut c_void, + current_traverse: ffi::traverseproc, ) -> c_int where T: PyClass, @@ -291,6 +296,11 @@ where let trap = PanicTrap::new("uncaught panic inside __traverse__ handler"); let lock = LockGIL::during_traverse(); + let super_retval = call_super_traverse(slf, visit, arg, current_traverse); + if super_retval != 0 { + return super_retval; + } + // SAFETY: `slf` is a valid Python object pointer to a class object of type T, and // traversal is running so no mutations can occur. let class_object: &PyClassObject = &*slf.cast(); @@ -330,6 +340,127 @@ where retval } +/// Call super-type traverse method, if necessary. +/// +/// Adapted from Cython's `__Pyx_call_next_tp_traverse` (`Cython/Utility/ExtensionTypes.c`). +/// +/// TODO: There are possible optimizations over looking up the base type in this way +/// - if the base type is known in this module, can potentially look it up directly in module state +/// (when we have it) +/// - if the base type is a Python builtin, can just call the C function directly +/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to +/// tp_alloc where we solve this at compile time +unsafe fn call_super_traverse( + obj: *mut ffi::PyObject, + visit: ffi::visitproc, + arg: *mut c_void, + current_traverse: ffi::traverseproc, +) -> c_int { + // SAFETY: in this function here it's ok to work with raw type objects `ffi::Py_TYPE` + // because the GC is running and so + // - (a) we cannot do refcounting and + // - (b) the type of the object cannot change. + let mut ty = ffi::Py_TYPE(obj); + let mut traverse: Option; + + // First find the current type by the current_traverse function + loop { + traverse = get_slot(ty, TP_TRAVERSE); + if traverse == Some(current_traverse) { + break; + } + ty = get_slot(ty, TP_BASE); + if ty.is_null() { + // FIXME: return an error if current type not in the MRO? 
Should be impossible. + return 0; + } + } + + // Get first base which has a different traverse function + while traverse == Some(current_traverse) { + ty = get_slot(ty, TP_BASE); + if ty.is_null() { + break; + } + traverse = get_slot(ty, TP_TRAVERSE); + } + + // If we found a type with a different traverse function, call it + if let Some(traverse) = traverse { + return traverse(obj, visit, arg); + } + + // FIXME same question as cython: what if the current type is not in the MRO? + 0 +} + +/// Calls an implementation of __clear__ for tp_clear +pub unsafe fn _call_clear( + slf: *mut ffi::PyObject, + impl_: for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult<()>, + current_clear: ffi::inquiry, +) -> c_int { + trampoline::trampoline(move |py| { + let super_retval = call_super_clear(py, slf, current_clear); + if super_retval != 0 { + return Err(PyErr::fetch(py)); + } + impl_(py, slf)?; + Ok(0) + }) +} + +/// Call super-type clear method, if necessary. +/// +/// Adapted from Cython's `__Pyx_call_next_tp_clear` (`Cython/Utility/ExtensionTypes.c`). +/// +/// TODO: There are possible optimizations over looking up the base type in this way +/// - if the base type is known in this module, can potentially look it up directly in module state +/// (when we have it) +/// - if the base type is a Python builtin, can just call the C function directly +/// - if the base type is a PyO3 type defined in the same module, can potentially do similar to +/// tp_alloc where we solve this at compile time +unsafe fn call_super_clear( + py: Python<'_>, + obj: *mut ffi::PyObject, + current_clear: ffi::inquiry, +) -> c_int { + let mut ty = PyType::from_borrowed_type_ptr(py, ffi::Py_TYPE(obj)); + let mut clear: Option; + + // First find the current type by the current_clear function + loop { + clear = ty.get_slot(TP_CLEAR); + if clear == Some(current_clear) { + break; + } + let base = ty.get_slot(TP_BASE); + if base.is_null() { + // FIXME: return an error if current type not in the MRO? Should be impossible. 
+ return 0; + } + ty = PyType::from_borrowed_type_ptr(py, base); + } + + // Get first base which has a different clear function + while clear == Some(current_clear) { + let base = ty.get_slot(TP_BASE); + if base.is_null() { + break; + } + ty = PyType::from_borrowed_type_ptr(py, base); + clear = ty.get_slot(TP_CLEAR); + } + + // If we found a type with a different clear function, call it + if let Some(clear) = clear { + return clear(obj); + } + + // FIXME same question as cython: what if the current type is not in the MRO? + 0 +} + // Autoref-based specialization for handling `__next__` returning `Option` pub struct IterBaseTag; diff --git a/src/internal/get_slot.rs b/src/internal/get_slot.rs index c151e855a14..260893d4204 100644 --- a/src/internal/get_slot.rs +++ b/src/internal/get_slot.rs @@ -11,7 +11,14 @@ impl Bound<'_, PyType> { where Slot: GetSlotImpl, { - slot.get_slot(self.as_borrowed()) + // SAFETY: `self` is a valid type object. + unsafe { + slot.get_slot( + self.as_type_ptr(), + #[cfg(all(Py_LIMITED_API, not(Py_3_10)))] + is_runtime_3_10(self.py()), + ) + } } } @@ -21,13 +28,50 @@ impl Borrowed<'_, '_, PyType> { where Slot: GetSlotImpl, { - slot.get_slot(self) + // SAFETY: `self` is a valid type object. + unsafe { + slot.get_slot( + self.as_type_ptr(), + #[cfg(all(Py_LIMITED_API, not(Py_3_10)))] + is_runtime_3_10(self.py()), + ) + } } } +/// Gets a slot from a raw FFI pointer. +/// +/// Safety: +/// - `ty` must be a valid non-null pointer to a `PyTypeObject`. 
+/// - The Python runtime must be initialized +pub(crate) unsafe fn get_slot( + ty: *mut ffi::PyTypeObject, + slot: Slot, +) -> as GetSlotImpl>::Type +where + Slot: GetSlotImpl, +{ + slot.get_slot( + ty, + // SAFETY: the Python runtime is initialized + #[cfg(all(Py_LIMITED_API, not(Py_3_10)))] + is_runtime_3_10(crate::Python::assume_gil_acquired()), + ) +} + pub(crate) trait GetSlotImpl { type Type; - fn get_slot(self, tp: Borrowed<'_, '_, PyType>) -> Self::Type; + + /// Gets the requested slot from a type object. + /// + /// Safety: + /// - `ty` must be a valid non-null pointer to a `PyTypeObject`. + /// - `is_runtime_3_10` must be `false` if the runtime is not Python 3.10 or later. + unsafe fn get_slot( + self, + ty: *mut ffi::PyTypeObject, + #[cfg(all(Py_LIMITED_API, not(Py_3_10)))] is_runtime_3_10: bool, + ) -> Self::Type; } #[derive(Copy, Clone)] @@ -42,12 +86,14 @@ macro_rules! impl_slots { type Type = $tp; #[inline] - fn get_slot(self, tp: Borrowed<'_, '_, PyType>) -> Self::Type { - let ptr = tp.as_type_ptr(); - + unsafe fn get_slot( + self, + ty: *mut ffi::PyTypeObject, + #[cfg(all(Py_LIMITED_API, not(Py_3_10)))] is_runtime_3_10: bool + ) -> Self::Type { #[cfg(not(Py_LIMITED_API))] - unsafe { - (*ptr).$field + { + (*ty).$field } #[cfg(Py_LIMITED_API)] @@ -59,15 +105,14 @@ macro_rules! impl_slots { // (3.7, 3.8, 3.9) and then look in the type object anyway. This is only ok // because we know that the interpreter is not going to change the size // of the type objects for these historical versions. 
- if !is_runtime_3_10(tp.py()) - && unsafe { ffi::PyType_HasFeature(ptr, ffi::Py_TPFLAGS_HEAPTYPE) } == 0 + if !is_runtime_3_10 && ffi::PyType_HasFeature(ty, ffi::Py_TPFLAGS_HEAPTYPE) == 0 { - return unsafe { (*ptr.cast::()).$field }; + return (*ty.cast::()).$field; } } // SAFETY: slot type is set carefully to be valid - unsafe { std::mem::transmute(ffi::PyType_GetSlot(ptr, ffi::$slot)) } + std::mem::transmute(ffi::PyType_GetSlot(ty, ffi::$slot)) } } } @@ -75,11 +120,14 @@ macro_rules! impl_slots { }; } -// Slots are implemented on-demand as needed. +// Slots are implemented on-demand as needed. impl_slots! { TP_ALLOC: (Py_tp_alloc, tp_alloc) -> Option, + TP_BASE: (Py_tp_base, tp_base) -> *mut ffi::PyTypeObject, + TP_CLEAR: (Py_tp_clear, tp_clear) -> Option, TP_DESCR_GET: (Py_tp_descr_get, tp_descr_get) -> Option, TP_FREE: (Py_tp_free, tp_free) -> Option, + TP_TRAVERSE: (Py_tp_traverse, tp_traverse) -> Option, } #[cfg(all(Py_LIMITED_API, not(Py_3_10)))] diff --git a/src/pyclass/create_type_object.rs b/src/pyclass/create_type_object.rs index c365b70fb9f..8a02baa8ad1 100644 --- a/src/pyclass/create_type_object.rs +++ b/src/pyclass/create_type_object.rs @@ -7,12 +7,12 @@ use crate::{ assign_sequence_item_from_mapping, get_sequence_item_from_mapping, tp_dealloc, tp_dealloc_with_gc, MaybeRuntimePyMethodDef, PyClassItemsIter, }, - pymethods::{Getter, Setter}, + pymethods::{Getter, PyGetterDef, PyMethodDefType, PySetterDef, Setter, _call_clear}, trampoline::trampoline, }, internal_tricks::ptr_from_ref, types::{typeobject::PyTypeMethods, PyType}, - Py, PyClass, PyGetterDef, PyMethodDefType, PyResult, PySetterDef, PyTypeInfo, Python, + Py, PyClass, PyResult, PyTypeInfo, Python, }; use std::{ collections::HashMap, @@ -432,7 +432,8 @@ impl PyTypeBuilder { unsafe { self.push_slot(ffi::Py_tp_new, no_constructor_defined as *mut c_void) } } - let tp_dealloc = if self.has_traverse || unsafe { 
ffi::PyType_IS_GC(self.tp_base) == 1 }; + let tp_dealloc = if self.has_traverse || base_is_gc { self.tp_dealloc_with_gc } else { self.tp_dealloc @@ -446,6 +447,22 @@ impl PyTypeBuilder { ))); } + // If this type is a GC type, and the base also is, we may need to add + // `tp_traverse` / `tp_clear` implementations to call the base, if this type didn't + // define `__traverse__` or `__clear__`. + // + // This is because when Py_TPFLAGS_HAVE_GC is set, then `tp_traverse` and + // `tp_clear` are not inherited. + if ((self.class_flags & ffi::Py_TPFLAGS_HAVE_GC) != 0) && base_is_gc { + // If this assertion breaks, need to consider doing the same for __traverse__. + assert!(self.has_traverse); // Py_TPFLAGS_HAVE_GC is set when a `__traverse__` method is found + + if !self.has_clear { + // Safety: This is the correct slot type for Py_tp_clear + unsafe { self.push_slot(ffi::Py_tp_clear, call_super_clear as *mut c_void) } + } + } + // For sequences, implement sq_length instead of mp_length if self.is_sequence { for slot in &mut self.slots { @@ -540,6 +557,10 @@ unsafe extern "C" fn no_constructor_defined( }) } +unsafe extern "C" fn call_super_clear(slf: *mut ffi::PyObject) -> c_int { + _call_clear(slf, |_, _| Ok(()), call_super_clear) +} + #[derive(Default)] struct GetSetDefBuilder { doc: Option<&'static CStr>, diff --git a/tests/test_gc.rs b/tests/test_gc.rs index b95abd4adea..01ca2c8270c 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -565,6 +565,132 @@ fn unsendable_are_not_traversed_on_foreign_thread() { }); } +#[test] +fn test_traverse_subclass() { + #[pyclass(subclass)] + struct Base { + cycle: Option, + drop_called: Arc, + } + + #[pymethods] + impl Base { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.cycle)?; + Ok(()) + } + + fn __clear__(&mut self) { + self.cycle = None; + } + } + + impl Drop for Base { + fn drop(&mut self) { + self.drop_called.store(true, Ordering::Relaxed); + } + } + + #[pyclass(extends = 
Base)] + struct Sub {} + + #[pymethods] + impl Sub { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + // subclass traverse overrides the base class traverse + Ok(()) + } + } + + let drop_called = Arc::new(AtomicBool::new(false)); + + Python::with_gil(|py| { + let base = Base { + cycle: None, + drop_called: drop_called.clone(), + }; + let obj = Bound::new(py, PyClassInitializer::from(base).add_subclass(Sub {})).unwrap(); + obj.borrow_mut().as_super().cycle = Some(obj.clone().into_any().unbind()); + + drop(obj); + assert!(!drop_called.load(Ordering::Relaxed)); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + py.run_bound("import gc; gc.collect()", None, None).unwrap(); + } + + assert!(drop_called.load(Ordering::Relaxed)); + }); +} + +#[test] +fn test_traverse_subclass_override_clear() { + #[pyclass(subclass)] + struct Base { + cycle: Option, + drop_called: Arc, + } + + #[pymethods] + impl Base { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.cycle)?; + Ok(()) + } + + fn __clear__(&mut self) { + self.cycle = None; + } + } + + impl Drop for Base { + fn drop(&mut self) { + self.drop_called.store(true, Ordering::Relaxed); + } + } + + #[pyclass(extends = Base)] + struct Sub {} + + #[pymethods] + impl Sub { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + // subclass traverse overrides the base class traverse + Ok(()) + } + + fn __clear__(&self) { + // subclass clear overrides the base class clear + } + } + + let drop_called = Arc::new(AtomicBool::new(false)); + + Python::with_gil(|py| { + let base = Base { + cycle: None, + drop_called: drop_called.clone(), + }; + let obj = Bound::new(py, PyClassInitializer::from(base).add_subclass(Sub {})).unwrap(); + obj.borrow_mut().as_super().cycle 
= Some(obj.clone().into_any().unbind()); + + drop(obj); + assert!(!drop_called.load(Ordering::Relaxed)); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + py.run_bound("import gc; gc.collect()", None, None).unwrap(); + } + + assert!(drop_called.load(Ordering::Relaxed)); + }); +} + // Manual traversal utilities unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option {