Skip to content

Commit d2848a0

Browse files
ngoldbaumdavidhewitt
authored andcommitted
Update dict.get_item binding to use PyDict_GetItemRef (#4355)
* Update dict.get_item binding to use PyDict_GetItemRef Refs #4265 * test: add test for dict.get_item error path * test: add test for dict.get_item error path * test: add test for dict.get_item error path * fix: fix logic error in dict.get_item bindings * update: apply david's review suggestions for dict.get_item bindings * update: create ffi::compat to store compatibility shims * update: move PyDict_GetItemRef bindings to spot in order from dictobject.h * build: fix build warning with --no-default-features * doc: expand release note fragments * fix: fix clippy warnings * respond to review comments * Apply suggestion from @mejrs * refactor so cfg is applied to functions * properly set cfgs * fix clippy lints * Apply @davidhewitt's suggestion * deal with upstream deprecation of new_bound
1 parent 6f62b50 commit d2848a0

File tree

6 files changed

+106
-5
lines changed

6 files changed

+106
-5
lines changed

newsfragments/4355.added.md

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
* Added an `ffi::compat` namespace to store compatibility shims for C API
2+
functions added in recent versions of Python.
3+
4+
* Added bindings for `PyDict_GetItemRef` on Python 3.13 and newer. Also added
5+
`ffi::compat::PyDict_GetItemRef` which re-exports the FFI binding on Python
6+
3.13 or newer and defines a compatibility version on older versions of
7+
Python. This function is inherently safer to use than `PyDict_GetItem` and has
8+
an API that is easier to use than `PyDict_GetItemWithError`. It returns a
9+
strong reference to value, as opposed to the two older functions which return
10+
a possibly unsafe borrowed reference.

newsfragments/4355.fixed.md

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Avoid creating temporary borrowed reference in dict.get_item bindings. Borrowed
2+
references like this are unsafe in the free-threading build.

pyo3-ffi/src/compat.rs

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//! C API Compatibility Shims
2+
//!
3+
//! Some CPython C API functions added in recent versions of Python are
4+
//! inherently safer to use than older C API constructs. This module
5+
//! exposes functions available on all Python versions that wrap the
6+
//! old C API on old Python versions and wrap the function directly
7+
//! on newer Python versions.
8+
9+
// Unless otherwise noted, the compatibility shims are adapted from
10+
// the pythoncapi-compat project: https://github.com/python/pythoncapi-compat
11+
12+
#[cfg(not(Py_3_13))]
13+
use crate::object::PyObject;
14+
#[cfg(not(Py_3_13))]
15+
use std::os::raw::c_int;
16+
17+
#[cfg_attr(docsrs, doc(cfg(all)))]
18+
#[cfg(Py_3_13)]
19+
pub use crate::dictobject::PyDict_GetItemRef;
20+
21+
#[cfg_attr(docsrs, doc(cfg(all)))]
22+
#[cfg(not(Py_3_13))]
23+
pub unsafe fn PyDict_GetItemRef(
24+
dp: *mut PyObject,
25+
key: *mut PyObject,
26+
result: *mut *mut PyObject,
27+
) -> c_int {
28+
{
29+
use crate::dictobject::PyDict_GetItemWithError;
30+
use crate::object::_Py_NewRef;
31+
use crate::pyerrors::PyErr_Occurred;
32+
33+
let item: *mut PyObject = PyDict_GetItemWithError(dp, key);
34+
if !item.is_null() {
35+
*result = _Py_NewRef(item);
36+
return 1; // found
37+
}
38+
*result = std::ptr::null_mut();
39+
if PyErr_Occurred().is_null() {
40+
return 0; // not found
41+
}
42+
-1
43+
}
44+
}

pyo3-ffi/src/dictobject.rs

+6
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ extern "C" {
6666
) -> c_int;
6767
#[cfg_attr(PyPy, link_name = "PyPyDict_DelItemString")]
6868
pub fn PyDict_DelItemString(dp: *mut PyObject, key: *const c_char) -> c_int;
69+
#[cfg(Py_3_13)]
70+
pub fn PyDict_GetItemRef(
71+
dp: *mut PyObject,
72+
key: *mut PyObject,
73+
result: *mut *mut PyObject,
74+
) -> c_int;
6975
// skipped 3.10 / ex-non-limited PyObject_GenericGetDict
7076
}
7177

pyo3-ffi/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
12
//! Raw FFI declarations for Python's C API.
23
//!
34
//! PyO3 can be used to write native Python modules or run Python code and modules from Rust.
@@ -290,6 +291,8 @@ pub const fn _cstr_from_utf8_with_nul_checked(s: &str) -> &CStr {
290291

291292
use std::ffi::CStr;
292293

294+
pub mod compat;
295+
293296
pub use self::abstract_::*;
294297
pub use self::bltinmodule::*;
295298
pub use self::boolobject::*;

src/types/dict.rs

+41-5
Original file line numberDiff line numberDiff line change
@@ -436,13 +436,13 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> {
436436
key: Bound<'_, PyAny>,
437437
) -> PyResult<Option<Bound<'py, PyAny>>> {
438438
let py = dict.py();
439+
let mut result: *mut ffi::PyObject = std::ptr::null_mut();
439440
match unsafe {
440-
ffi::PyDict_GetItemWithError(dict.as_ptr(), key.as_ptr())
441-
.assume_borrowed_or_opt(py)
442-
.map(Borrowed::to_owned)
441+
ffi::compat::PyDict_GetItemRef(dict.as_ptr(), key.as_ptr(), &mut result)
443442
} {
444-
some @ Some(_) => Ok(some),
445-
None => PyErr::take(py).map(Err).transpose(),
443+
std::os::raw::c_int::MIN..=-1 => Err(PyErr::fetch(py)),
444+
0 => Ok(None),
445+
1..=std::os::raw::c_int::MAX => Ok(Some(unsafe { result.assume_owned(py) })),
446446
}
447447
}
448448

@@ -957,6 +957,42 @@ mod tests {
957957
});
958958
}
959959

960+
#[cfg(feature = "macros")]
961+
#[test]
962+
fn test_get_item_error_path() {
963+
use crate::exceptions::PyTypeError;
964+
965+
#[crate::pyclass(crate = "crate")]
966+
struct HashErrors;
967+
968+
#[crate::pymethods(crate = "crate")]
969+
impl HashErrors {
970+
#[new]
971+
fn new() -> Self {
972+
HashErrors {}
973+
}
974+
975+
fn __hash__(&self) -> PyResult<isize> {
976+
Err(PyTypeError::new_err("Error from __hash__"))
977+
}
978+
}
979+
980+
Python::with_gil(|py| {
981+
let class = py.get_type_bound::<HashErrors>();
982+
let instance = class.call0().unwrap();
983+
let d = PyDict::new_bound(py);
984+
match d.get_item(instance) {
985+
Ok(_) => {
986+
panic!("this get_item call should always error")
987+
}
988+
Err(err) => {
989+
assert!(err.is_instance_of::<PyTypeError>(py));
990+
assert_eq!(err.value_bound(py).to_string(), "Error from __hash__")
991+
}
992+
}
993+
})
994+
}
995+
960996
#[test]
961997
#[allow(deprecated)]
962998
#[cfg(all(not(any(PyPy, GraalPy)), feature = "gil-refs"))]

0 commit comments

Comments
 (0)