Skip to content

Commit e0a2c6f

Browse files
committed
Apply __bool__ conversion only to numpy.bool_ to avoid false positives.
1 parent 45afbc2 commit e0a2c6f

File tree

3 files changed

+49
-62
lines changed

3 files changed

+49
-62
lines changed

pytests/src/misc.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@ fn issue_219() {
66
Python::with_gil(|_| {});
77
}
88

9+
#[pyfunction]
10+
fn accepts_bool(val: bool) -> bool {
11+
val
12+
}
13+
914
#[pymodule]
1015
pub fn misc(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
1116
m.add_function(wrap_pyfunction!(issue_219, m)?)?;
17+
m.add_function(wrap_pyfunction!(accepts_bool, m)?)?;
1218
Ok(())
1319
}

pytests/tests/test_misc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,13 @@ def test_import_in_subinterpreter_forbidden():
4848
)
4949

5050
_xxsubinterpreters.destroy(sub_interpreter)
51+
52+
53+
def test_accepts_numpy_bool():
54+
# binary numpy wheel not available on all platforms
55+
numpy = pytest.importorskip("numpy")
56+
57+
assert pyo3_pytests.misc.accepts_bool(True) is True
58+
assert pyo3_pytests.misc.accepts_bool(False) is False
59+
assert pyo3_pytests.misc.accepts_bool(numpy.bool_(True)) is True
60+
assert pyo3_pytests.misc.accepts_bool(numpy.bool_(False)) is False

src/types/boolobject.rs

Lines changed: 33 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ impl IntoPy<PyObject> for bool {
7777
/// Fails with `TypeError` if the input is not a Python `bool`.
7878
impl<'source> FromPyObject<'source> for bool {
7979
fn extract(obj: &'source PyAny) -> PyResult<Self> {
80-
if let Ok(obj) = obj.downcast::<PyBool>() {
81-
return Ok(obj.is_true());
82-
}
80+
let err = match obj.downcast::<PyBool>() {
81+
Ok(obj) => return Ok(obj.is_true()),
82+
Err(err) => err,
83+
};
8384

8485
let missing_conversion = |obj: &PyAny| {
8586
PyTypeError::new_err(format!(
@@ -92,28 +93,42 @@ impl<'source> FromPyObject<'source> for bool {
9293
unsafe {
9394
let ptr = obj.as_ptr();
9495

95-
if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() {
96-
if let Some(nb_bool) = tp_as_number.nb_bool {
97-
match (nb_bool)(ptr) {
98-
0 => return Ok(false),
99-
1 => return Ok(true),
100-
_ => return Err(crate::PyErr::fetch(obj.py())),
96+
if libc::strcmp(
97+
(*ffi::Py_TYPE(ptr)).tp_name,
98+
b"numpy.bool_\0".as_ptr().cast(),
99+
) == 0
100+
{
101+
if let Some(tp_as_number) = (*(*ptr).ob_type).tp_as_number.as_ref() {
102+
if let Some(nb_bool) = tp_as_number.nb_bool {
103+
match (nb_bool)(ptr) {
104+
0 => return Ok(false),
105+
1 => return Ok(true),
106+
_ => return Err(crate::PyErr::fetch(obj.py())),
107+
}
101108
}
102109
}
103-
}
104110

105-
Err(missing_conversion(obj))
111+
return Err(missing_conversion(obj));
112+
}
106113
}
107114

108115
#[cfg(any(Py_LIMITED_API, PyPy))]
109116
{
110-
let meth = obj
111-
.lookup_special(crate::intern!(obj.py(), "__bool__"))?
112-
.ok_or_else(|| missing_conversion(obj))?;
113-
114-
let obj = meth.call0()?.downcast::<PyBool>()?;
115-
Ok(obj.is_true())
117+
if obj
118+
.get_type()
119+
.name()
120+
.map_or(false, |name| name == "numpy.bool_")
121+
{
122+
let meth = obj
123+
.lookup_special(crate::intern!(obj.py(), "__bool__"))?
124+
.ok_or_else(|| missing_conversion(obj))?;
125+
126+
let obj = meth.call0()?.downcast::<PyBool>()?;
127+
return Ok(obj.is_true());
128+
}
116129
}
130+
131+
Err(err.into())
117132
}
118133

119134
#[cfg(feature = "experimental-inspect")]
@@ -124,7 +139,7 @@ impl<'source> FromPyObject<'source> for bool {
124139

125140
#[cfg(test)]
126141
mod tests {
127-
use crate::types::{PyAny, PyBool, PyModule};
142+
use crate::types::{PyAny, PyBool};
128143
use crate::Python;
129144
use crate::ToPyObject;
130145

@@ -147,48 +162,4 @@ mod tests {
147162
assert!(false.to_object(py).is(PyBool::new(py, false)));
148163
});
149164
}
150-
151-
#[test]
152-
fn test_magic_method() {
153-
Python::with_gil(|py| {
154-
let module = PyModule::from_code(
155-
py,
156-
r#"
157-
class A:
158-
def __bool__(self): return True
159-
class B:
160-
def __bool__(self): return "not a bool"
161-
class C:
162-
def __len__(self): return 23
163-
class D:
164-
pass
165-
"#,
166-
"test.py",
167-
"test",
168-
)
169-
.unwrap();
170-
171-
let a = module.getattr("A").unwrap().call0().unwrap();
172-
assert!(a.extract::<bool>().unwrap());
173-
174-
let b = module.getattr("B").unwrap().call0().unwrap();
175-
assert!(matches!(
176-
&*b.extract::<bool>().unwrap_err().to_string(),
177-
"TypeError: 'str' object cannot be converted to 'PyBool'"
178-
| "TypeError: __bool__ should return bool, returned str"
179-
));
180-
181-
let c = module.getattr("C").unwrap().call0().unwrap();
182-
assert_eq!(
183-
c.extract::<bool>().unwrap_err().to_string(),
184-
"TypeError: object of type '<class 'test.C'>' does not define a '__bool__' conversion",
185-
);
186-
187-
let d = module.getattr("D").unwrap().call0().unwrap();
188-
assert_eq!(
189-
d.extract::<bool>().unwrap_err().to_string(),
190-
"TypeError: object of type '<class 'test.D'>' does not define a '__bool__' conversion",
191-
);
192-
});
193-
}
194165
}

0 commit comments

Comments
 (0)