Skip to content

Commit 965af15

Browse files
Merge pull request #32 from BetsyMcPhail/remove-old-objects-from-cache
Remove deregistered objects from the inactive overload cache
2 parents 69a5d92 + 579cd5a commit 965af15

File tree

4 files changed

+77
-2
lines changed

4 files changed

+77
-2
lines changed

include/pybind11/detail/class.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,19 @@ inline bool deregister_instance_impl(void *ptr, instance *self) {
225225
for (auto it = range.first; it != range.second; ++it) {
226226
if (self == it->second && Py_TYPE(self) == Py_TYPE(it->second)) {
227227
registered_instances.erase(it);
228+
// TODO(eric.cousineau): This section can be removed if we can use
229+
// a different key, as mentioned in `get_type_overload` for
230+
// pybind11#1922.
231+
PyObject *to_remove = (PyObject *)it->second;
232+
auto &cache = detail::get_internals().inactive_overload_cache;
233+
for (auto key = cache.begin(); key != cache.end();) {
234+
if (key->first == to_remove) {
235+
key = cache.erase(key);
236+
}
237+
else {
238+
++key;
239+
}
240+
}
228241
return true;
229242
}
230243
}

include/pybind11/pybind11.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,8 +2410,15 @@ inline function get_type_overload(const void *this_ptr, const detail::type_info
24102410
handle self = detail::get_object_handle(this_ptr, this_type);
24112411
if (!self)
24122412
return function();
2413-
handle type = self.get_type();
2414-
auto key = std::make_pair(type.ptr(), name);
2413+
// N.B. This uses `self.ptr()` instead of `type.ptr()` to resolve
2414+
// pybind11#1922.
2415+
// TODO(eric.cousineau): Consider using something like the tuple
2416+
// (cls, f"{cls.__module__}.{cls.__qualname__}", name). However, since
2417+
// internals::inactive_overload_cache only supports
2418+
// `(PyObject*, const char*)` (and thus avoids memory management), it may
2419+
// require some hand-crafting and really bumping the internals version. For
2420+
// now, leave as-is.
2421+
auto key = std::make_pair(self.ptr(), name);
24152422

24162423
/* Cache functions that aren't overloaded in Python to avoid
24172424
many costly Python dictionary lookups below */

tests/test_class.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,24 @@ TEST_SUBMODULE(class_, m) {
367367
.def(py::init<>())
368368
.def("ptr", &Aligned::ptr);
369369
#endif
370+
371+
// Test #1922 (drake#11424).
372+
class ExampleVirt2 {
373+
public:
374+
virtual ~ExampleVirt2() {}
375+
virtual std::string get_name() const { return "ExampleVirt2"; }
376+
};
377+
class PyExampleVirt2 : public ExampleVirt2 {
378+
public:
379+
std::string get_name() const override {
380+
PYBIND11_OVERLOAD(std::string, ExampleVirt2, get_name, );
381+
}
382+
};
383+
py::class_<ExampleVirt2, PyExampleVirt2>(m, "ExampleVirt2")
384+
.def(py::init())
385+
.def("get_name", &ExampleVirt2::get_name);
386+
m.def("example_virt2_get_name",
387+
[](const ExampleVirt2& obj) { return obj.get_name(); });
370388
}
371389

372390
template <int N> class BreaksBase { public: virtual ~BreaksBase() = default; };

tests/test_class.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import weakref
23

34
from pybind11_tests import class_ as m
45
from pybind11_tests import UserType, ConstructorStats
@@ -289,3 +290,39 @@ def test_aligned():
289290
if hasattr(m, "Aligned"):
290291
p = m.Aligned().ptr()
291292
assert p % 1024 == 0
293+
294+
295+
@pytest.mark.skip(
296+
reason="Generally reproducible in CPython, Python 3, non-debug, on Linux. "
297+
"However, hard to pin this down for CI.")
298+
def test_1922():
299+
# Test #1922 (drake#11424).
300+
# Define a derived class which *does not* overload the method.
301+
# WARNING: The reproduction of this failure may be platform-specific, and
302+
# seems to depend on the order of definition and/or the name of the classes
303+
# defined. For example, trying to place this and the C++ code in
304+
# `test_virtual_functions` makes `assert id_1 == id_2` below fail.
305+
class Child1(m.ExampleVirt2):
306+
pass
307+
308+
id_1 = id(Child1)
309+
assert m.example_virt2_get_name(m.ExampleVirt2()) == "ExampleVirt2"
310+
assert m.example_virt2_get_name(Child1()) == "ExampleVirt2"
311+
312+
# Now delete everything (and ensure it's deleted).
313+
wref = weakref.ref(Child1)
314+
del Child1
315+
pytest.gc_collect()
316+
assert wref() is None
317+
318+
# Define a derived class which *does* define an overload.
319+
class Child2(m.ExampleVirt2):
320+
def get_name(self):
321+
return "Child2"
322+
323+
id_2 = id(Child2)
324+
assert id_1 == id_2 # This happens in CPython; not sure about PyPy.
325+
assert m.example_virt2_get_name(m.ExampleVirt2()) == "ExampleVirt2"
326+
# THIS WILL FAIL: This is using the cached `ExampleVirt2.get_name`, rather
327+
# than re-inspect the Python dictionary.
328+
assert m.example_virt2_get_name(Child2()) == "Child2"

0 commit comments

Comments
 (0)